diff --git a/README.md b/README.md index 6e33950..b1f4bf2 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,11 @@ second contains the corresponding query homologous genes. > for code testing. The resulting annotation accuracy may not be as good as > using the full dataset as the reference. +**Suggestions** +> +> If you have sufficient GPU memory, setting the hidden-size `h_dim=512` +in "came/PARAMETERS.py" may result in a more accurate cell-type transfer. + ### Test CAME's pipeline (optional) To test the package, run the python file `test_pipeline.py`: @@ -120,7 +125,7 @@ If you are having issues, please let us know. We have a mailing list located at: ### Citation -If CAME is useful for your research, consider citing our preprint: +If CAME is useful for your research, consider citing our work: > Liu X, Shen Q, Zhang S. Cross-species cell-type assignment of single-cell RNA-seq by a heterogeneous graph neural network[J]. Genome Research, 2022: gr. 276868.122. diff --git a/README_CH.md b/README_CH.md index 23867f5..e79eb98 100644 --- a/README_CH.md +++ b/README_CH.md @@ -73,6 +73,10 @@ python setup.py install > 数据文件 “raw-Baron_human.h5ad” 仅用于代码测试,是原始数据的子样本 (20%), > 因此结果的注释精度可能不如使用完整数据集作为参考。 +### **建议** + +如果你有足够的GPU显存,我们建议在``came/PARAMETERS.py``中设置`h_dim=512` 来获得更好的结果。 + ### 测试 CAME 的分析流程 (非必要) 可以直接运行 `test_pipeline.py` 来测试 CAME 的分析流程: diff --git a/came/__init__.py b/came/__init__.py index 4d14572..82cf95e 100644 --- a/came/__init__.py +++ b/came/__init__.py @@ -49,4 +49,4 @@ from .pipeline import KET_CLUSTER, __test1__, __test2__ -__version__ = "0.1.10" +__version__ = "0.1.12" diff --git a/came/model/_utils.py b/came/model/_utils.py index a391669..fff3c2a 100644 --- a/came/model/_utils.py +++ b/came/model/_utils.py @@ -21,12 +21,17 @@ from .cgc import CGCNet # from ._minibatch import create_blocks, create_batch +try: + from dgl.dataloading import NodeDataLoader +except ImportError: + from dgl.dataloading import DataLoader as NodeDataLoader + def idx_hetero(feat_dict, id_dict): sub_feat_dict = {} for k, ids in id_dict.items(): if k in feat_dict: - sub_feat_dict[k] = feat_dict[k][ids] + sub_feat_dict[k] = feat_dict[k][ids.cpu()] else: # logging.warning(f'key "{k}" does not exist in {feat_dict.keys()}') pass @@ -158,7 +163,7 @@ def get_all_hidden_states( sampler = dgl.dataloading.MultiLayerNeighborSampler( sampler.fanouts[:-1]) - dataloader = dgl.dataloading.NodeDataLoader( + dataloader = NodeDataLoader( g, {'cell': g.nodes('cell'), 'gene': g.nodes('gene')}, sampler, device=device, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0 @@ -301,7 +306,7 @@ def get_model_outputs( ###################################### if sampler is None: sampler = model.get_sampler(g.canonical_etypes, 50) - dataloader = dgl.dataloading.NodeDataLoader( + dataloader = NodeDataLoader( g, {'cell': g.nodes('cell')}, sampler, device=device, batch_size=batch_size, diff --git a/came/utils/train.py b/came/utils/train.py index 1b2cb61..7a8db65 100644 --- a/came/utils/train.py +++ b/came/utils/train.py @@ -32,6 +32,11 @@ from .plot import plot_records_for_trainer from ._base_trainer import BaseTrainer, SUBDIR_MODEL +try: + from dgl.dataloading import NodeDataLoader +except ImportError: + from dgl.dataloading import DataLoader as NodeDataLoader + def seed_everything(seed=123): """ not works well """ @@ -52,7 +57,7 @@ def make_class_weights(labels, astensor=True, foo=np.sqrt, n_add=0): counts = value_counts(labels).sort_index() # sort for alignment n_cls = len(counts) + n_add - w = counts.apply(lambda x: 1 / foo(x + 1) if x > 0 else 0) + w = counts.apply(lambda x: 1 / foo(x + 1) if x > 0 else 0) w = (w / w.sum() * (1 - n_add / n_cls)).values w = np.array(list(w) + [1 / n_cls] * int(n_add)) @@ -117,7 +122,7 @@ def prepare4train( test_idx = LongTensor(test_idx) g = dpair.get_whole_net(rebuild=False, ) - g.nodes[node_cls_type].data[key_label] = labels # date: 211113 + g.nodes[node_cls_type].data[key_label] = labels # date: 211113 ENV_VARs = dict( classes=classes, @@ -134,8 +139,8 @@ def prepare4train( class Trainer(BaseTrainer): """ - - + + """ def __init__(self, @@ -309,7 +314,7 @@ def train(self, n_epochs=350, **params_lossfunc ) - # prediction + # prediction _, y_pred = torch.max(logits, dim=1) # ========== evaluation (Acc.) ========== @@ -367,22 +372,24 @@ def train_minibatch(self, n_epochs=100, if sampler is None: sampler = model.get_sampler(g.canonical_etypes, 50) - train_dataloader = dgl.dataloading.NodeDataLoader( + train_dataloader = NodeDataLoader( # The following arguments are specific to NodeDataLoader. - g, {'cell': train_idx}, # The node IDs to iterate over in minibatches - sampler, device=device, # Put the sampled MFGs on CPU or GPU + g, {'cell': train_idx}, + # The node IDs to iterate over in minibatches + sampler, device='cpu', # Put the sampled MFGs on CPU or GPU # The following arguments are inherited from PyTorch DataLoader. batch_size=batch_size, shuffle=True, drop_last=False, num_workers=0 ) - test_dataloader = dgl.dataloading.NodeDataLoader( - g, {'cell': test_idx}, sampler, device=device, batch_size=batch_size, + test_dataloader = NodeDataLoader( + g, {'cell': test_idx}, sampler, device='cpu', batch_size=batch_size, shuffle=False, drop_last=False, num_workers=0 ) print(f" start training (device='{device}') ".center(60, '=')) + rcd = {} for epoch in range(n_epochs): model.train() self._cur_epoch += 1 - + t0 = time.time() all_train_preds = [] train_labels = [] @@ -443,6 +450,7 @@ def train_minibatch(self, n_epochs=100, **rcd, print_info=self._cur_epoch % info_stride == 0 or backup) self.log_info(**rcd, print_info=True) self._cur_epoch_adopted = self._cur_epoch + self.save_checkpoint_record() def get_current_outputs(self, feat_dict=None, @@ -526,7 +534,8 @@ def evaluate_metrics( ) return metrics - def log_info(self, train_acc, test_acc, ami=None, print_info=True, # dur=0., + def log_info(self, train_acc, test_acc, ami=None, print_info=True, + # dur=0., **kwargs): dur_avg = np.average(self.dur) ami = kwargs.get('AMI', 'NaN') if ami is None else ami @@ -575,7 +584,7 @@ def infer_for_nodes( if reorder: return order_by_ids(all_test_preds, orig_ids) return all_test_preds - + def order_by_ids(x, ids): """reorder by the original ids""" diff --git a/setup.py b/setup.py index 3b893e6..b8bfcac 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ EMAIL = '544568643@qq.com' AUTHOR = 'Xingyan Liu' REQUIRES_PYTHON = '>=3.8.0' -VERSION = '0.1.10' +VERSION = '0.1.12' REQUIRED = [ 'scanpy',