From b05e14246ed936e31afd7a1a468da52d50c15f5b Mon Sep 17 00:00:00 2001
From: Kevin Maik Jablonka
Date: Tue, 3 Oct 2023 14:40:28 +0200
Subject: [PATCH] update docs

---
 docs/source/usage.rst                       | 11 +++++++++--
 src/chemlift/finetune/peft_transformers.py |  2 +-
 src/chemlift/finetune/peftmodels.py        |  1 +
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index 8e2e962..60e22e7 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -25,7 +25,8 @@ To handle the different model types, we provide a :code:`ChemLIFTClassifierFacto
 
 from chemlift.finetuning.classifier import ChemLIFTClassifierFactory
 
-model = ChemLIFTClassifierFactory('EleutherAI/gpt-neo-125m', load_in_8bit=False).create_model()
+model = ChemLIFTClassifierFactory('property name',
+                                  model_name='EleutherAI/pythia-1b-deduped').create_model()
 model.fit(X, y)
 model.predict(X)
 ```
@@ -55,7 +56,13 @@ train_y = [1 if y > train_median else 0 for y in train_y]
 test_y = [1 if y > train_median else 0 for y in test_y]
 
 # train
-model = ChemLIFTClassifierFactory('EleutherAI/gpt-neo-125m', load_in_8bit=False).create_model() # create the model
+model = ChemLIFTClassifierFactory('transition wavelength class', # property name
+    model_name='EleutherAI/pythia-1b-deduped', # base model
+    load_in_8bit=True, # use quantized model
+    inference_batch_size=32, # batch size for inference
+    tokenizer_kwargs={"cutoff_len": 50}, # tokenizer kwargs, cutoff_len is the most important one
+    tune_settings={'num_train_epochs': 32} # settings for the training process, see transformers docs
+    ).create_model() # create the model
 model.fit(train_names, train_y)
 
 # predict
diff --git a/src/chemlift/finetune/peft_transformers.py b/src/chemlift/finetune/peft_transformers.py
index 3e3bf99..a0662e6 100644
--- a/src/chemlift/finetune/peft_transformers.py
+++ b/src/chemlift/finetune/peft_transformers.py
@@ -50,7 +50,7 @@
     "EleutherAI/gpt-neo-1.3B": "left",
     "EleutherAI/gpt-neo-2.7B": "left",
     "EleutherAI/gpt-neox-20b": "left",
-    "EleutherAI/pythia-12b-dedupedz``": "left",
+    "EleutherAI/pythia-12b-deduped": "left",
     "EleutherAI/pythia-6.9b-deduped": "left",
     "EleutherAI/pythia-2.8b-deduped": "left",
     "EleutherAI/pythia-1.4b-deduped": "left",
diff --git a/src/chemlift/finetune/peftmodels.py b/src/chemlift/finetune/peftmodels.py
index 426e93a..f870eb3 100644
--- a/src/chemlift/finetune/peftmodels.py
+++ b/src/chemlift/finetune/peftmodels.py
@@ -20,6 +20,7 @@
 from transformers.utils import logging
 from functools import partial
 from peft.utils.save_and_load import set_peft_model_state_dict
+from fastcore.basics import basic_repr
 
 
 class ChemLIFTClassifierFactory:
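
For reference, the sketch below consolidates the example this patch adds to docs/source/usage.rst into one self-contained script. The factory arguments are taken verbatim from the patch; the compound names and binary labels are illustrative placeholders, not data from the source.

```python
from chemlift.finetuning.classifier import ChemLIFTClassifierFactory

# toy data: compound names and a binary label (placeholder values)
train_names = ["benzene", "toluene", "pyridine", "furan"]
train_y = [0, 1, 1, 0]

# arguments mirror the usage.rst example added in this patch
model = ChemLIFTClassifierFactory(
    "transition wavelength class",                # property name used in the prompt
    model_name="EleutherAI/pythia-1b-deduped",    # base model
    load_in_8bit=True,                            # quantized weights to reduce memory
    inference_batch_size=32,                      # batch size for inference
    tokenizer_kwargs={"cutoff_len": 50},          # cutoff_len is the most important tokenizer kwarg
    tune_settings={"num_train_epochs": 32},       # passed to the transformers training loop
).create_model()

model.fit(train_names, train_y)   # fine-tune on names + labels
predictions = model.predict(train_names)
```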