Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-pypi-0.1.12 #32

Merged
merged 5 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ second contains the corresponding query homologous genes.
> for code testing. The resulting annotation accuracy may not be as good as
> using the full dataset as the reference.

**Suggestions**
>
> If you have sufficient GPU memory, setting the hidden-size `h_dim=512`
in "came/PARAMETERS.py" may result in a more accurate cell-type transfer.

### Test CAME's pipeline (optional)

To test the package, run the python file `test_pipeline.py`:
Expand Down Expand Up @@ -120,7 +125,7 @@ If you are having issues, please let us know. We have a mailing list located at:

### Citation

If CAME is useful for your research, consider citing our preprint:
If CAME is useful for your research, consider citing our work:

> Liu X, Shen Q, Zhang S. Cross-species cell-type assignment of single-cell RNA-seq by a heterogeneous graph neural network[J]. Genome Research, 2022: gr. 276868.122.

Expand Down
4 changes: 4 additions & 0 deletions README_CH.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ python setup.py install
> 数据文件 “raw-Baron_human.h5ad” 仅用于代码测试,是原始数据的子样本 (20%),
> 因此结果的注释精度可能不如使用完整数据集作为参考。

### **建议**

如果你有足够的GPU显存,我们建议在``came/PARAMETERS.py``中设置`h_dim=512` 来获得更好的结果。

### 测试 CAME 的分析流程 (非必要)

可以直接运行 `test_pipeline.py` 来测试 CAME 的分析流程:
Expand Down
2 changes: 1 addition & 1 deletion came/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@
from .pipeline import KET_CLUSTER, __test1__, __test2__


__version__ = "0.1.10"
__version__ = "0.1.12"
11 changes: 8 additions & 3 deletions came/model/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
from .cgc import CGCNet
# from ._minibatch import create_blocks, create_batch

try:
from dgl.dataloading import NodeDataLoader
except ImportError:
from dgl.dataloading import DataLoader as NodeDataLoader


def idx_hetero(feat_dict, id_dict):
sub_feat_dict = {}
for k, ids in id_dict.items():
if k in feat_dict:
sub_feat_dict[k] = feat_dict[k][ids]
sub_feat_dict[k] = feat_dict[k][ids.cpu()]
else:
# logging.warning(f'key "{k}" does not exist in {feat_dict.keys()}')
pass
Expand Down Expand Up @@ -158,7 +163,7 @@ def get_all_hidden_states(
sampler = dgl.dataloading.MultiLayerNeighborSampler(
sampler.fanouts[:-1])

dataloader = dgl.dataloading.NodeDataLoader(
dataloader = NodeDataLoader(
g, {'cell': g.nodes('cell'), 'gene': g.nodes('gene')},
sampler, device=device,
batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0
Expand Down Expand Up @@ -301,7 +306,7 @@ def get_model_outputs(
######################################
if sampler is None:
sampler = model.get_sampler(g.canonical_etypes, 50)
dataloader = dgl.dataloading.NodeDataLoader(
dataloader = NodeDataLoader(
g, {'cell': g.nodes('cell')},
sampler, device=device,
batch_size=batch_size,
Expand Down
35 changes: 22 additions & 13 deletions came/utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
from .plot import plot_records_for_trainer
from ._base_trainer import BaseTrainer, SUBDIR_MODEL

try:
from dgl.dataloading import NodeDataLoader
except ImportError:
from dgl.dataloading import DataLoader as NodeDataLoader


def seed_everything(seed=123):
""" not works well """
Expand All @@ -52,7 +57,7 @@ def make_class_weights(labels, astensor=True, foo=np.sqrt, n_add=0):

counts = value_counts(labels).sort_index() # sort for alignment
n_cls = len(counts) + n_add
w = counts.apply(lambda x: 1 / foo(x + 1) if x > 0 else 0)
w = counts.apply(lambda x: 1 / foo(x + 1) if x > 0 else 0)
w = (w / w.sum() * (1 - n_add / n_cls)).values
w = np.array(list(w) + [1 / n_cls] * int(n_add))

Expand Down Expand Up @@ -117,7 +122,7 @@ def prepare4train(
test_idx = LongTensor(test_idx)

g = dpair.get_whole_net(rebuild=False, )
g.nodes[node_cls_type].data[key_label] = labels # date: 211113
g.nodes[node_cls_type].data[key_label] = labels # date: 211113

ENV_VARs = dict(
classes=classes,
Expand All @@ -134,8 +139,8 @@ def prepare4train(

class Trainer(BaseTrainer):
"""


"""

def __init__(self,
Expand Down Expand Up @@ -309,7 +314,7 @@ def train(self, n_epochs=350,
**params_lossfunc
)

# prediction
# prediction
_, y_pred = torch.max(logits, dim=1)

# ========== evaluation (Acc.) ==========
Expand Down Expand Up @@ -367,22 +372,24 @@ def train_minibatch(self, n_epochs=100,
if sampler is None:
sampler = model.get_sampler(g.canonical_etypes, 50)

train_dataloader = dgl.dataloading.NodeDataLoader(
train_dataloader = NodeDataLoader(
# The following arguments are specific to NodeDataLoader.
g, {'cell': train_idx}, # The node IDs to iterate over in minibatches
sampler, device=device, # Put the sampled MFGs on CPU or GPU
g, {'cell': train_idx},
# The node IDs to iterate over in minibatches
sampler, device='cpu', # Put the sampled MFGs on CPU or GPU
# The following arguments are inherited from PyTorch DataLoader.
batch_size=batch_size, shuffle=True, drop_last=False, num_workers=0
)
test_dataloader = dgl.dataloading.NodeDataLoader(
g, {'cell': test_idx}, sampler, device=device, batch_size=batch_size,
test_dataloader = NodeDataLoader(
g, {'cell': test_idx}, sampler, device='cpu', batch_size=batch_size,
shuffle=False, drop_last=False, num_workers=0
)
print(f" start training (device='{device}') ".center(60, '='))
rcd = {}
for epoch in range(n_epochs):
model.train()
self._cur_epoch += 1

t0 = time.time()
all_train_preds = []
train_labels = []
Expand Down Expand Up @@ -443,6 +450,7 @@ def train_minibatch(self, n_epochs=100,
**rcd, print_info=self._cur_epoch % info_stride == 0 or backup)
self.log_info(**rcd, print_info=True)
self._cur_epoch_adopted = self._cur_epoch
self.save_checkpoint_record()

def get_current_outputs(self,
feat_dict=None,
Expand Down Expand Up @@ -526,7 +534,8 @@ def evaluate_metrics(
)
return metrics

def log_info(self, train_acc, test_acc, ami=None, print_info=True, # dur=0.,
def log_info(self, train_acc, test_acc, ami=None, print_info=True,
# dur=0.,
**kwargs):
dur_avg = np.average(self.dur)
ami = kwargs.get('AMI', 'NaN') if ami is None else ami
Expand Down Expand Up @@ -575,7 +584,7 @@ def infer_for_nodes(
if reorder:
return order_by_ids(all_test_preds, orig_ids)
return all_test_preds


def order_by_ids(x, ids):
"""reorder by the original ids"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
EMAIL = '544568643@qq.com'
AUTHOR = 'Xingyan Liu'
REQUIRES_PYTHON = '>=3.8.0'
VERSION = '0.1.10'
VERSION = '0.1.12'

REQUIRED = [
'scanpy',
Expand Down