diff --git a/lungmask/__main__.py b/lungmask/__main__.py index 45817ac..8206936 100644 --- a/lungmask/__main__.py +++ b/lungmask/__main__.py @@ -7,7 +7,7 @@ import pkg_resources # type: ignore import SimpleITK as sitk -from lungmask import mask, utils +from lungmask import LMInferer, utils def path(string): @@ -48,7 +48,6 @@ def main(): parser.add_argument( "--classes", help="spcifies the number of output classes of the model", - default=3, ) parser.add_argument( "--cpu", @@ -86,6 +85,11 @@ def main(): argsin = sys.argv[1:] args = parser.parse_args(argsin) + if args.classes is not None: + logging.warn( + "!!! Warning: The `classes` parameter is deprecated and will be removed in the next version !!!" + ) + batchsize = args.batchsize if args.cpu: batchsize = 1 @@ -98,29 +102,30 @@ def main(): assert ( args.modelpath is None ), "Modelpath can not be specified for LTRCLobes_R231 mode" - result = mask.apply_fused( - input_image, + inferer = LMInferer( + modelname="LTRCLobes", force_cpu=args.cpu, + fillmodel="R231", batch_size=batchsize, volume_postprocessing=not (args.nopostprocess), noHU=args.noHU, tqdm_disable=args.noprogress, ) + result = inferer.apply(input_image) else: - model = mask.get_model(args.modelname, args.modelpath, args.classes) - result = mask.apply( - input_image, - model, + inferer = LMInferer( + modelname=args.modelname, + modelpath=args.modelpath, force_cpu=args.cpu, batch_size=batchsize, volume_postprocessing=not (args.nopostprocess), noHU=args.noHU, tqdm_disable=args.noprogress, ) + result = inferer.apply(input_image) if args.noHU: file_ending = args.output.split(".")[-1] - print(file_ending) if file_ending in ["jpg", "jpeg", "png"]: result = (result / (result.max()) * 255).astype(np.uint8) result = result[0] diff --git a/lungmask/mask.py b/lungmask/mask.py index 1430b65..666bd66 100644 --- a/lungmask/mask.py +++ b/lungmask/mask.py @@ -1,4 +1,5 @@ import logging +import os import sys import warnings from typing import Optional, Union @@ -41,15 +42,12 @@ } -def get_model( - modelname: str, modelpath: Optional[str] = None, n_classes: int = 3 -) -> torch.nn.Module: +def get_model(modelname: str, modelpath: Optional[str] = None) -> torch.nn.Module: """Loads specific model and state Args: modelname (str): Modelname (e.g. R231, LTRCLobes or R231CovidWeb) modelpath (Optional[str], optional): Path to statedict, if not provided will be downloaded automatically. Modelname will be ignored if provided. Defaults to None. - n_classes (int, optional): Number of classes. Will be automatically set if modelname is provided. Defaults to 3. Returns: torch.nn.Module: Loaded model in eval state @@ -62,6 +60,8 @@ def get_model( else: state_dict = torch.load(modelpath, map_location=torch.device("cpu")) + n_classes = len(list(state_dict.values())[-1]) + model = UNet( n_classes=n_classes, padding=True, @@ -78,19 +78,23 @@ def get_model( class LMInferer: def __init__( self, - modelname="R231", + modelname: str = "R231", + modelpath: Optional[str] = None, fillmodel: Optional[str] = None, - force_cpu=False, - batch_size=20, - volume_postprocessing=True, - noHU=False, - tqdm_disable=False, + fillmodel_path: Optional[str] = None, + force_cpu: bool = False, + batch_size: int = 20, + volume_postprocessing: bool = True, + noHU: bool = False, + tqdm_disable: bool = False, ): """LungMaskInference Args: modelname (str, optional): Model to be applied. Defaults to 'R231'. + modelpath (str, optional): Path to modeleights. `modelname` parameter will be ignored if provided. Defaults to None. fillmodel (Optional[str], optional): Fillmodel to be applied. Defaults to None. + fillmodel_path (Optional[str], optional): Path to weights for fillmodel. `fillmodel` parameter will be ignored if provided. Defaults to None. force_cpu (bool, optional): Will not use GPU is `True`. Defaults to False. batch_size (int, optional): Batch size. Defaults to 20. volume_postprocessing (bool, optional): If `Fales` will not perform postprocessing (connected component analysis). Defaults to True. @@ -104,6 +108,13 @@ def __init__( assert ( fillmodel in MODEL_URLS ), "Modelname not found. Please choose from: {}".format(MODEL_URLS.keys()) + + # if paths provided, overwrite name + if modelpath is not None: + modelname = os.path.basename(modelpath) + if fillmodel_path is not None: + fillmodel = os.path.basename(fillmodel_path) + self.fillmodel = fillmodel self.modelname = modelname self.force_cpu = force_cpu @@ -112,7 +123,7 @@ def __init__( self.noHU = noHU self.tqdm_disable = tqdm_disable - self.model = get_model(self.modelname) + self.model = get_model(self.modelname, modelpath) self.device = torch.device("cpu") if not self.force_cpu: @@ -124,7 +135,7 @@ def __init__( self.fillmodelm = None if self.fillmodel is not None: - self.fillmodelm = get_model(self.fillmodel) + self.fillmodelm = get_model(self.fillmodel, fillmodel_path) self.fillmodelm.to(self.device) def _inference( @@ -250,6 +261,10 @@ def apply( noHU=False, tqdm_disable=False, ): + warnings.warn( + "The function `apply` will be removed in a future version. Please use the LMInferer class!", + DeprecationWarning, + ) inferer = LMInferer( force_cpu=force_cpu, batch_size=batch_size, @@ -272,6 +287,10 @@ def apply_fused( noHU=False, tqdm_disable=False, ): + warnings.warn( + "The function `apply_fused` will be removed in a future version. Please use the LMInferer class!", + DeprecationWarning, + ) inferer = LMInferer( modelname=basemodel, force_cpu=force_cpu, diff --git a/lungmask/utils.py b/lungmask/utils.py index b40eb4e..bf8dc02 100644 --- a/lungmask/utils.py +++ b/lungmask/utils.py @@ -252,6 +252,7 @@ def postprocessing( Returns: np.ndarray: Postprocessed volume """ + logging.info("Postprocessing") # CC analysis regionmask = skimage.measure.label(label_image) diff --git a/setup.py b/setup.py index 3441423..1757915 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="lungmask", - version="0.2.14", + version="0.2.15", author="Johannes Hofmanninger", author_email="johannes.hofmanninger@gmail.com", description="Package for automated lung segmentation in CT", diff --git a/tests/test_mask.py b/tests/test_mask.py index 48e08ff..825c1ba 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -1,10 +1,11 @@ import os +import shutil import numpy as np -import pydicom as pyd import pytest +import torch -from lungmask.mask import LMInferer, apply, apply_fused +from lungmask.mask import MODEL_URLS, LMInferer from lungmask.utils import read_dicoms @@ -13,13 +14,47 @@ def fixture_testvol(): return read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"))[0] -def test_apply(fixture_testvol): - res = apply(fixture_testvol) +@pytest.fixture(scope="session") +def fixture_weights_path_R231(tmpdir_factory): + # we make sure the model is there + torch.hub.load_state_dict_from_url( + MODEL_URLS["R231"][0], progress=True, map_location=torch.device("cpu") + ) + modelbasename = os.path.basename(MODEL_URLS["R231"][0]) + modelpath = os.path.join(torch.hub.get_dir(), "checkpoints", modelbasename) + tmppath = str(tmpdir_factory.mktemp("weights").join(modelbasename)) + shutil.copy(modelpath, tmppath) + return tmppath + + +def test_LMInferer(fixture_testvol, fixture_weights_path_R231): + inferer = LMInferer( + force_cpu=True, + tqdm_disable=True, + ) + res = inferer.apply(fixture_testvol) assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536]) + # here, we provide a path to the R231 weights but specify LTRCLobes (6 channel) as modelname + # The modelname should be ignored and a 3 channel output should be generated + inferer = LMInferer( + modelname="LTRCLobes", + modelpath=fixture_weights_path_R231, + force_cpu=True, + tqdm_disable=True, + ) + res = inferer.apply(fixture_testvol) + assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536]) -def test_apply_fused(fixture_testvol): - res = apply_fused(fixture_testvol) + +def test_LMInferer_fused(fixture_testvol): + inferer = LMInferer( + modelname="LTRCLobes", + force_cpu=True, + fillmodel="R231", + tqdm_disable=True, + ) + res = inferer.apply(fixture_testvol) assert np.all( np.unique(res, return_counts=True)[1] == [423000, 13334, 23202, 23834, 40918] ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ac97b32..f97ce3d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,7 @@ import os import numpy as np -import pydicom as pd -import pydicom as pyd import SimpleITK as sitk -from pydicom.dataset import FileMetaDataset from lungmask.utils import ( bbox_3D, @@ -18,6 +15,9 @@ ) # creating test dicom data for reference +# import pydicom as pd +# import pydicom as pyd +# from pydicom.dataset import FileMetaDataset # # studyuid = pyd.uid.generate_uid() # seriesuid = pyd.uid.generate_uid()