Skip to content

Commit

Permalink
regressor and do not return std by default
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Maik Jablonka committed Oct 3, 2023
1 parent eba0871 commit 31080a9
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/chemlift/finetune/peftmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def predict(
temperature=0.7,
do_sample=False,
formatted: Optional[pd.DataFrame] = None,
return_std: bool = True,
return_std: bool = False,
):
predictions = self._predict(
X=X, temperature=temperature, do_sample=do_sample, formatted=formatted
Expand Down Expand Up @@ -334,7 +334,7 @@ def _query(self, formatted_df, temperature, do_sample):


class PEFTRegressor(PEFTClassifier):
def __init__(
def __init__(
self,
property_name: str,
extractor: RegressionExtractor = RegressionExtractor(),
Expand Down Expand Up @@ -376,8 +376,31 @@ def __init__(

self.tune_settings["per_device_train_batch_size"] = self.batch_size

__repr__ = basic_repr("property_name", "_base_model", 'num_digits')
__repr__ = basic_repr("property_name", "_base_model", "num_digits")

def predict(
self,
X: Optional[ArrayLike] = None,
temperature=0.7,
do_sample=False,
formatted: Optional[pd.DataFrame] = None,
return_std: bool = False,
):
predictions = self._predict(
X=X, temperature=temperature, do_sample=do_sample, formatted=formatted
)

predictions = np.array(predictions).T

# nan values make issues here
predictions_mean = np.array(
[try_exccept_nan(np.mean, pred) for pred in predictions.astype(int)]
)

if return_std:
predictions_std = np.array([np.std(pred) for pred in predictions.astype(int)])
return predictions_mean, predictions_std
return predictions_mean


class SMILESAugmentedPEFTClassifier(PEFTClassifier):
Expand Down

0 comments on commit 31080a9

Please sign in to comment.