diff --git a/d4ft/config.py b/d4ft/config.py
index 53527e7..36969f3 100644
--- a/d4ft/config.py
+++ b/d4ft/config.py
@@ -29,7 +29,7 @@ class GDConfig:
   """learning rate"""
   lr_decay: Literal["none", "piecewise", "cosine"] = "none"
   """learning rate schedule"""
-  optimizer: Literal["adam", "sgd", "rmsprop"] = "rmsprop"
+  optimizer: Literal["adam", "sgd", "rmsprop"] = "adam"
   """which optimizer to use"""
   epochs: int = 4000
   """number of updates/iterations"""
@@ -40,7 +40,7 @@ class GDConfig:
   which is used for gradient descent convergence checking"""
   meta_lr: float = 0.03
   """meta learning rate"""
-  meta_opt: Literal["none", "adam", "sgd", "rmsprop"] = "adam"
+  meta_opt: Literal["none", "adam", "sgd", "rmsprop"] = "none"
   """meta optimizer to use, none to disable"""