Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 3, 2023
1 parent 79ecb5b commit b05e142
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
11 changes: 9 additions & 2 deletions docs/source/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ To handle the different model types, we provide a :code:`ChemLIFTClassifierFacto
from chemlift.finetuning.classifier import ChemLIFTClassifierFactory
model = ChemLIFTClassifierFactory('EleutherAI/gpt-neo-125m', load_in_8bit=False).create_model()
model = ChemLIFTClassifierFactory('property name',
model_name='EleutherAI/pythia-1b-deduped').create_model()
model.fit(X, y)
model.predict(X)
```
Expand Down Expand Up @@ -55,7 +56,13 @@ train_y = [1 if y > train_median else 0 for y in train_y]
test_y = [1 if y > train_median else 0 for y in test_y]
# train
model = ChemLIFTClassifierFactory('EleutherAI/gpt-neo-125m', load_in_8bit=False).create_model() # create the model
model = ChemLIFTClassifierFactory('transition wavelength class', # property name
model_name='EleutherAI/pythia-1b-deduped', # base model
load_in_8bit=True, # use quantized model
inference_batch_size=32, # batch size for inference
tokenizer_kwargs={"cutoff_len": 50}, # tokenizer kwargs, cutoff_len is the most important one
tune_settings={'num_train_epochs': 32} # settings for the training process, see transformers docs
).create_model() # create the model
model.fit(train_names, train_y)
# predict
Expand Down
2 changes: 1 addition & 1 deletion src/chemlift/finetune/peft_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"EleutherAI/gpt-neo-1.3B": "left",
"EleutherAI/gpt-neo-2.7B": "left",
"EleutherAI/gpt-neox-20b": "left",
"EleutherAI/pythia-12b-dedupedz``": "left",
"EleutherAI/pythia-12b-deduped": "left",
"EleutherAI/pythia-6.9b-deduped": "left",
"EleutherAI/pythia-2.8b-deduped": "left",
"EleutherAI/pythia-1.4b-deduped": "left",
Expand Down
1 change: 1 addition & 0 deletions src/chemlift/finetune/peftmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers.utils import logging
from functools import partial
from peft.utils.save_and_load import set_peft_model_state_dict
from fastcore.basics import basic_repr


class ChemLIFTClassifierFactory:
Expand Down

0 comments on commit b05e142

Please sign in to comment.