Skip to content

Commit

Permalink
test with path provided
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Hofmanninger committed Jun 15, 2023
1 parent 4d4ff42 commit b285ecc
Showing 1 changed file with 28 additions and 3 deletions.
31 changes: 28 additions & 3 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
import shutil

import numpy as np
import pydicom as pyd
import pytest
import torch

from lungmask.mask import LMInferer, apply, apply_fused
from lungmask.mask import MODEL_URLS, LMInferer
from lungmask.utils import read_dicoms


Expand All @@ -13,8 +14,32 @@ def fixture_testvol():
return read_dicoms(os.path.join(os.path.dirname(__file__), "testdata"))[0]


def test_LMInferer(fixture_testvol):
@pytest.fixture(scope="session")
def fixture_weights_path_R231(tmpdir_factory):
# we make sure the model is there
torch.hub.load_state_dict_from_url(
MODEL_URLS["R231"][0], progress=True, map_location=torch.device("cpu")
)
modelbasename = os.path.basename(MODEL_URLS["R231"][0])
modelpath = os.path.join(torch.hub.get_dir(), "checkpoints", modelbasename)
tmppath = str(tmpdir_factory.mktemp("weights").join(modelbasename))
shutil.copy(modelpath, tmppath)
return tmppath


def test_LMInferer(fixture_testvol, fixture_weights_path_R231):
inferer = LMInferer(
force_cpu=True,
tqdm_disable=True,
)
res = inferer.apply(fixture_testvol)
assert np.all(np.unique(res, return_counts=True)[1] == [423000, 64752, 36536])

# here, we provide a path to the R231 weights but specify LTRCLobes (6 channel) as modelname
# The modelname should be ignored and a 3 channel output should be generated
inferer = LMInferer(
modelname="LTRCLobes",
modelpath=fixture_weights_path_R231,
force_cpu=True,
tqdm_disable=True,
)
Expand Down

0 comments on commit b285ecc

Please sign in to comment.