Skip to content

Commit

Permalink
Merge pull request #18 from lamalab-org/kjappelbaum/issue17
Browse files Browse the repository at this point in the history
  • Loading branch information
kjappelbaum authored Oct 11, 2023
2 parents 373bda1 + a9a117c commit e46bf28
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
86 changes: 86 additions & 0 deletions experiments/scaling/pythia.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from gptchem.data import get_photoswitch_data

from chemlift.finetune.peftmodels import PEFTClassifier, ChemLIFTClassifierFactory
from sklearn.model_selection import train_test_split

from fastcore.xtras import load_pickle, save_pickle
from gptchem.evaluator import evaluate_classification
import time
import os


def get_timestr():
    """Return the current local time as a ``YYYY-MM-DD_HH-MM-SS`` string.

    The result contains only digits, dashes and one underscore.
    """
    stamp_format = "%Y-%m-%d_%H-%M-%S"
    return time.strftime(stamp_format)


# Hugging Face identifiers of the deduplicated Pythia checkpoints to sweep over,
# listed from the largest parameter count down to the smallest.
models = [
    f"EleutherAI/pythia-{size}-deduped"
    for size in ("12b", "6.9b", "2.8b", "1.4b", "1b", "410m", "160m", "70m")
]


def train_test(train_size, model_name, random_state=42):
    """Fine-tune one model on the photoswitch data and pickle its evaluation report.

    The E-isomer pi-pi* transition wavelength is turned into a binary label
    (1 if above the median, else 0). The data is split with a stratified
    ``train_test_split``, the labels are then recomputed from the *training*
    median only, a ChemLIFT classifier is fine-tuned on the training SMILES
    and evaluated on the held-out rows.

    Args:
        train_size: Forwarded to ``train_test_split`` as ``train_size``.
        model_name: Base model identifier handed to ``ChemLIFTClassifierFactory``.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        None. The report is written to
        ``results/<timestamp>_peft_<model_name>_<train_size>.pkl`` with any
        ``/`` in ``model_name`` replaced by ``_``.
    """
    target = "E isomer pi-pi* wavelength in nm"

    data = get_photoswitch_data()
    # Copy so that adding the label column never writes into a view of the
    # frame returned by the loader.
    data = data.dropna(subset=["SMILES", target]).copy()

    # Provisional labels from the full-data median; they are only used to
    # stratify the split. The median is computed once instead of per row.
    overall_median = data[target].median()
    data["binned"] = (data[target] > overall_median).astype(int)

    train, test = train_test_split(
        data, train_size=train_size, stratify=data["binned"], random_state=random_state
    )
    # Independent copies so the relabelling below cannot trigger pandas'
    # chained-assignment warning.
    train = train.copy()
    test = test.copy()

    # Final labels use the training median only, so no test-set statistic
    # enters the labelling threshold.
    train_median = train[target].median()
    train["binned"] = (train[target] > train_median).astype(int)
    test["binned"] = (test[target] > train_median).astype(int)

    model = ChemLIFTClassifierFactory(
        "transition wavelength class",
        model_name=model_name,
        load_in_8bit=True,
        inference_batch_size=32,
        tokenizer_kwargs={"cutoff_len": 50},
        tune_settings={"num_train_epochs": 32},
    ).create_model()

    model.fit(train["SMILES"].values, train["binned"].values)

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments.
    start = time.perf_counter()
    predictions = model.predict(test["SMILES"].values)
    inference_time = time.perf_counter() - start

    report = evaluate_classification(test["binned"].values, predictions)

    os.makedirs("results", exist_ok=True)

    # Model identifiers such as "EleutherAI/pythia-70m-deduped" contain a path
    # separator; left as-is the pickle would target a directory that does not
    # exist and the save would fail after the whole training run.
    safe_model_name = model_name.replace("/", "_")
    outname = f"results/{get_timestr()}_peft_{safe_model_name}_{train_size}.pkl"

    report["model_name"] = model_name
    report["train_size"] = train_size
    report["random_state"] = random_state
    report["predictions"] = predictions
    report["targets"] = test["binned"].values
    report["fine_tune_time"] = model.fine_tune_time
    report["inference_time"] = inference_time

    save_pickle(outname, report)


if __name__ == "__main__":
    # Full sweep: every seed x every model x every training-set size,
    # iterated in that nesting order.
    train_sizes = [10, 50, 100, 200, 300]
    for seed in range(5):
        for model_name in models:
            for n_train in train_sizes:
                train_test(n_train, model_name, random_state=seed)
9 changes: 9 additions & 0 deletions src/chemlift/finetune/peftmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from functools import partial
from peft.utils.save_and_load import set_peft_model_state_dict
from fastcore.basics import basic_repr
import time


class ChemLIFTClassifierFactory:
Expand Down Expand Up @@ -125,8 +126,14 @@ def __init__(

self.tune_settings["per_device_train_batch_size"] = self.batch_size

self._fine_tune_time = None

__repr__ = basic_repr(["property_name", "_base_model"])

@property
def fine_tune_time(self):
    """Duration in seconds of the ``train_model`` call in the last ``fit``.

    ``None`` until ``fit`` has completed once.
    """
    return self._fine_tune_time

def _prepare_df(self, X: ArrayLike, y: ArrayLike):
rows = []
for i in range(len(X)):
Expand Down Expand Up @@ -255,6 +262,7 @@ def fit(
dfs.append(formatted)

formatted = pd.concat(dfs)
start_time = time.time()
train_model(
self.model,
self.tokenizer,
Expand All @@ -263,6 +271,7 @@ def fit(
hub_model_name=None,
report_to=None,
)
self._fine_tune_time = time.time() - start_time

def _predict(
self,
Expand Down
6 changes: 6 additions & 0 deletions src/chemlift/icl/fewshotpredictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import enum
from typing import Union
from chemlift.icl.utils import LangChainChatModelWrapper
import time


class Strategy(enum.Enum):
Expand Down Expand Up @@ -86,6 +87,11 @@ def __init__(
self._materialclass = "molecules"
self._max_test = max_test
self._prefix = prefix
self._prediction_time = None

@property
def prediction_time(self):
    """Read-only access to the stored ``_prediction_time`` value.

    The attribute is initialised to ``None`` in ``__init__``.
    NOTE(review): the code that populates it is not visible here; presumably
    it is set while predicting. Confirm against the predict implementation.
    """
    return self._prediction_time

def _format_examples(self, examples, targets):
"""Format examples and targets into a string.
Expand Down

0 comments on commit e46bf28

Please sign in to comment.