Skip to content

Commit

Permalink
UPDATE fixs
Browse files Browse the repository at this point in the history
  • Loading branch information
YuyangXueEd committed Apr 8, 2024
1 parent 13839b8 commit edb155b
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
29 changes: 19 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
<div align="center">

# ReconHydra
# MuLTI Hydra

[![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10%7C_3.11-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![pytorch](https://img.shields.io/badge/PyTorch_2.0+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
[![lightning](https://img.shields.io/badge/-Lightning_2.0+-792ee5?logo=pytorchlightning&logoColor=white)](https://pytorchlightning.ai/)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3-89b8cd)](https://hydra.cc/)
[![hydra](https://img.shields.io/badge/Config-Hydra_1.3+-89b8cd)](https://hydra.cc/)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/ashleve/lightning-hydra-template#license)

Expand All @@ -15,7 +15,15 @@

## 📌  Introduction

MRI reconstruction using diffusion, build using PyTorch Lightning and Hydra.
A multi-purpose deep learning template, build using PyTorch Lightning and Hydra.

### Pre-made Template

- For general purpose, use the `main` branch.
- For MRI Reconstruction with FastMRI as an example, checkout the `Recon` branch
- For Huggingface Transformers, checkout the `Transformers` branch
- For huggingface Diffuser, checkout the `Diffuser` branch
- More to go ...

## Project Structure

Expand Down Expand Up @@ -79,12 +87,12 @@ The directory structure of new project looks like this:

```bash
# clone project
git clone https://github.com/YuyangXueEd/MRIfussion
cd MRIfussion
git clone https://github.com/YuyangXueEd/MuLTIHydra
cd MuLTIHydra

# [OPTIONAL] create conda environment
conda create -n MRIfussion python=3.10
conda activate MRIfussion
conda create -n MuLTIHydra python=3.10
conda activate MuLTIHydra

# install pytorch according to instructions
# https://pytorch.org/get-started/
Expand Down Expand Up @@ -1084,6 +1092,7 @@ hydra:

This template was inspired by:

- [lightning-hydra-template](https://github.com/ashleve/lightning-hydra-template)
- [PyTorchLightning/deep-learning-project-template](https://github.com/PyTorchLightning/deep-learning-project-template)
- [drivendata/cookiecutter-data-science](https://github.com/drivendata/cookiecutter-data-science)
- [lucmos/nn-template](https://github.com/lucmos/nn-template)
Expand All @@ -1098,12 +1107,12 @@ Other useful repositories:

## License

Lightning-Hydra-Template is licensed under the MIT License.
MuLTI-Hydra-Template is licensed under the MIT License.

```
MIT License

Copyright (c) 2021 ashleve
Copyright (c) 2021 YuyangXueEd

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
1 change: 1 addition & 0 deletions configs/callbacks/model_checkpoint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ model_checkpoint:
train_time_interval: null # checkpoints are monitored at the specified time interval
every_n_epochs: null # number of epochs between checkpoints
save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation
enable_version_counter: True # enables versioning for checkpoint names
3 changes: 3 additions & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ name: myenv

channels:
- pytorch
- nvidia
- conda-forge
- defaults

Expand All @@ -23,6 +24,8 @@ channels:
dependencies:
- python=3.10
- pytorch=2.*
- pytorch-cuda=11.8
# - cpuonly
- torchvision=0.*
- lightning=2.*
- torchmetrics=0.*
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# --------- pytorch --------- #
--extra-index-url https://download.pytorch.org/whl/cu118
torch>=2.0.0
torchvision>=0.15.0
lightning>=2.0.0
Expand Down
4 changes: 2 additions & 2 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
:return: A tuple with metrics and dict with all instantiated objects.
"""
# set seed for random number generators in pytorch, numpy and python.random
if cfg.get("seed"):
if cfg.get("seed") is not None:
L.seed_everything(cfg.seed, workers=True)

log.info(f"Instantiating mask and transforms.")
Expand Down Expand Up @@ -98,7 +98,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
if cfg.get("test"):
log.info("Starting testing!")
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
if ckpt_path:
log.warning("Best ckpt not found! Using current weights for testing...")
ckpt_path = None
trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
Expand Down

0 comments on commit edb155b

Please sign in to comment.