Anndata #68

Merged · 12 commits · Aug 12, 2024
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
@@ -54,8 +54,13 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+      - name: Install harissa dependencies
+        run: |
+          pip install pip-tools
+          pip-compile --strip-extras --extra extra -o requirements.txt pyproject.toml
+          pip install -r requirements.txt
       - name: Install harissa
-        run: pip install dist/*.whl
+        run: pip install --no-index -f dist harissa
       - name: Install pytest
         run: pip install pytest pytest-cov
       - name: Run tests and coverage
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -36,7 +36,7 @@ dependencies = [
 dynamic = ["version"]
 
 [project.optional-dependencies]
-extra = ["alive-progress>=3.0", "umap-learn"]
+extra = ["alive-progress>=3.0", "umap-learn", "anndata"]
 
 [project.urls]
 Repository = "https://github.com/harissa-framework/harissa"
@@ -64,7 +64,7 @@ raw-options = { local_scheme = "no-local-version" }
 # Ruff
 
 [tool.ruff]
-select = ["E", "F"]
+select = ["E", "F", "W"]
 line-length = 79
 
 [tool.ruff.format]
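Because anndata is pulled in only through the extra group rather than as a hard dependency, harissa has to keep working when it is absent. A minimal sketch of the guarded-import pattern this implies, mirroring the try/except ImportError approach used in src/harissa/core/dataset.py below (require_anndata is an illustrative helper, not part of this PR):

def require_anndata():
    # Import anndata lazily; fail with a clear message when the
    # optional extra is not installed.
    try:
        import anndata
        return anndata
    except ImportError as err:
        raise RuntimeError(
            'Install the package anndata to use this function.'
        ) from err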
6 changes: 3 additions & 3 deletions src/harissa/benchmark/benchmark.py
@@ -106,7 +106,7 @@ def _load_value(self, key: K) -> V:
             raise KeyError(f'{key} is invalid. {path} does not exist.')
 
         result = inf[0].Result.load(path / 'result.npz', load_extra=True)
-        runtime_in_sec = float(np.load(path / 'runtime.npy')[0])
+        runtime_in_sec = np.load(path / 'runtime.npy').item()
 
         return network, inf, dataset, result, runtime_in_sec
 
@@ -200,7 +200,7 @@ def _save_item(self, path: Path, item: Tuple[K, V]):
         output.mkdir(parents=True, exist_ok=True)
 
         result.save(output / 'result', True)
-        np.save(output / 'runtime.npy', np.array([runtime]))
+        np.save(output / 'runtime.npy', np.array(runtime))
 
         keys = [(key[0], key[2]), key[1]]
         values = [(network, dataset), inf]
@@ -310,7 +310,7 @@ def save_reports(self,
 
 if __name__ == '__main__':
     benchmark = Benchmark()
-    # benchmark.path = 'test_benchmark.zip'
+    # benchmark.datasets.path = 'test_datasets2'
     # benchmark.networks.include = ['BN8']
     print(benchmark.save_reports('test_benchmark', None, True, True))
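The runtime change above replaces a one-element array with a zero-dimensional one; ndarray.item() then unwraps it to a plain Python float, so the old [0] indexing is no longer needed. A standalone illustration (the file name is arbitrary):

import numpy as np

np.save('runtime.npy', np.array(1.23))    # 0-d array on disk
runtime = np.load('runtime.npy').item()   # back to a plain float
assert isinstance(runtime, float) and runtime == 1.23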
27 changes: 19 additions & 8 deletions src/harissa/benchmark/generators/datasets.py
@@ -4,6 +4,7 @@
     List,
     Union,
     Optional,
+    Literal,
     TypeAlias
 )
@@ -46,7 +47,8 @@ def __init__(self,
         include: List[K] = [('*', '*')],
         exclude: List[K] = [],
         path: Optional[Union[str, Path]] = None,
-        verbose: bool = False
+        verbose: bool = False,
+        save_format: Literal['.npz', '.h5ad'] = '.npz'
     ) -> None:
         self.networks = NetworksGenerator(verbose=verbose)
@@ -57,6 +59,7 @@ def __init__(self,
         self._model = NetworkModel(
             simulation=BurstyPDMP(use_numba=True)
         )
+        self.save_format = save_format
 
     def _set_path(self, path: Path):
         """
@@ -95,10 +98,14 @@ def _load_value(self, key: K) -> V:
         network = self.networks[key[0]]
         path = self._to_path(key).with_suffix('.npz')
 
-        if not path.exists():
-            raise KeyError(f'{key} is invalid. {path} does not exist.')
-
-        dataset = Dataset.load(path)
+        if path.exists():
+            dataset = Dataset.load(path)
+        else:
+            path = path.with_suffix('.h5ad')
+            if path.exists():
+                dataset = Dataset.load_h5ad(path)
+            else:
+                raise KeyError(f'{key} is invalid. {path} does not exist.')
 
         return network, dataset
 
@@ -202,18 +209,22 @@ def _save_item(self, path: Path, item: Tuple[K, V]):
 
         """
         key, (network, dataset) = item
-        output = self._to_path(key, path).with_suffix('.npz')
+        output = self._to_path(key, path)
         output.parent.mkdir(parents=True, exist_ok=True)
 
-        dataset.save(output)
+        if self.save_format == '.npz':
+            dataset.save(output.with_suffix('.npz'))
+        else:
+            dataset.save_h5ad(output.with_suffix('.h5ad'))
         self.networks.save_item(path, key[0], network)
 
 
 if __name__ == '__main__':
     n_datasets = {'BN8': 2, 'CN5': 5, 'FN4': 10, 'FN8': 1}
     gen = DatasetsGenerator(
         n_datasets=n_datasets,
-        verbose=True
+        verbose=True,
+        save_format='.h5ad'
     )
     gen.networks.include = list(n_datasets.keys())
     gen.save('test_datasets')
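Since datasets can now live on disk as either .npz or .h5ad, _load_value probes the .npz suffix first and only then falls back to .h5ad. A standalone sketch of that lookup order (find_and_load is an illustrative name; Dataset.load and Dataset.load_h5ad come from the next file in this diff, and the import path is assumed from the repo layout):

from pathlib import Path
from harissa.core.dataset import Dataset

def find_and_load(base: Path) -> Dataset:
    # Prefer the native .npz payload, then fall back to .h5ad.
    npz = base.with_suffix('.npz')
    if npz.exists():
        return Dataset.load(npz)
    h5ad = base.with_suffix('.h5ad')
    if h5ad.exists():
        return Dataset.load_h5ad(h5ad)
    raise KeyError(f'{base} has no .npz or .h5ad file.')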
122 changes: 96 additions & 26 deletions src/harissa/core/dataset.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
-from typing import Dict, Union
+from typing import Dict, Union, Optional, Literal
 import numpy as np
+from scipy.sparse import issparse
 from dataclasses import dataclass, asdict
 from pathlib import Path
 from typing import ClassVar
@@ -11,51 +12,53 @@
     load_npz,
     save_dir,
     save_npz
-    )
+)
 
+_anndata_msg_error = 'Install the package anndata to use this function.'
+
 @dataclass(frozen=True, init=False)
 class Dataset:
     param_names: ClassVar[Dict[str, ParamInfos]] = {
-        'time_points': ParamInfos(True, np.float64, 1),
+        'time_points': ParamInfos(True, np.float64, 1),
         'count_matrix': ParamInfos(True, np.uint, 2),
         'gene_names': ParamInfos(False, np.str_, 1)
     }
     time_points: np.ndarray
     count_matrix: np.ndarray
-    gene_names: np.ndarray
+    gene_names: Optional[np.ndarray]
 
-    def __init__(self,
-                 time_points: np.ndarray,
-                 count_matrix: np.ndarray,
-                 gene_names=None) -> None:
-
-        if not (time_points.ndim == 1 and
+    def __init__(self,
+        time_points: np.ndarray,
+        count_matrix: np.ndarray,
+        gene_names=None
+    ) -> None:
+        if not (time_points.ndim == 1 and
                 time_points.dtype == self.param_names['time_points'].dtype):
             raise TypeError('time_points must be a float 1D ndarray.')
 
-        if not (count_matrix.ndim == 2 and
+        if not (count_matrix.ndim == 2 and
                 count_matrix.dtype == self.param_names['count_matrix'].dtype):
             raise TypeError('count_matrix must be an uint 2D ndarray.')
 
         if time_points.shape[0] != count_matrix.shape[0]:
             raise TypeError(
-                'time_points must have the same number of elements'
-                ' than the rows of count_matrix.'
+                'time_points must have the same number of elements'
+                ' than the rows of count_matrix.'
                 f'({time_points.shape[0]} != {count_matrix.shape[0]})'
             )
 
         if count_matrix.shape[1] <= 1:
             raise TypeError('count_matrix must have at least 2 columns.')
 
         if gene_names is not None:
             if (gene_names.ndim != 1 or
                 gene_names.dtype.type is not np.str_):
                 raise TypeError('gene_names must be a str 1D ndarray.')
 
             if gene_names.shape[0] != count_matrix.shape[1]:
                 raise TypeError(
-                    'genes_names must have the same number of elements'
-                    ' than the columns of count_matrix.'
+                    'genes_names must have the same number of elements'
+                    ' than the columns of count_matrix.'
                    f'({gene_names.shape[0]} != {count_matrix.shape[1]})'
                 )
 
@@ -71,7 +74,7 @@ def load_txt(cls, path: Union[str, Path]) -> Dataset:
         if not path.exists():
             raise RuntimeError(f"{path} doesn't exist.")
         # Backward compatibility, dataset inside a txt file.
-        # It assumes that the 1st column is the time points (arr_list[0])
+        # It assumes that the 1st column is the time points (arr_list[0])
         # and the rest is the count matrix (arr_list[1])
         arr = np.loadtxt(path)
         data_list = [
@@ -93,17 +96,84 @@ def load_txt(cls, path: Union[str, Path]) -> Dataset:
     @classmethod
     def load(cls, path: Union[str, Path]) -> Dataset:
         return cls(**load_npz(path, cls.param_names))
-
-    def as_dict(self) -> Dict[str, np.ndarray]:
-        return asdict(
-            self,
-            dict_factory=lambda x: {k:v for (k, v) in x if v is not None}
-        )
-
-    # Add a "save" methods
 
     def save_txt(self, path: Union[str, Path]) -> Path:
         return save_dir(path, self.as_dict())
 
     def save(self, path: Union[str, Path]) -> Path:
         return save_npz(path, self.as_dict())
+
+    def as_dict(self) -> Dict[str, np.ndarray]:
+        return asdict(
+            self,
+            dict_factory=lambda x: {k:v for (k, v) in x if v is not None}
+        )
+
+    @classmethod
+    def from_annData(cls, adata) -> Dataset:
+        try:
+            from pandas import DataFrame
+            from anndata import AnnData
+
+            if not isinstance(adata, AnnData):
+                raise TypeError('adata must be an AnnData object.')
+
+            if isinstance(adata.X, DataFrame):
+                count_matrix = adata.X.to_numpy()
+            elif issparse(adata.X):
+                count_matrix = adata.X.toarray()
+            else:
+                count_matrix = adata.X
+
+            time_points = adata.obs.get('time_points', None)
+            if time_points is None:
+                raise RuntimeError(
+                    'adata must have a time_points field in its obs.'
+                )
+
+            return cls(
+                np.array(time_points, dtype=np.float64),
+                count_matrix.astype(np.uint),
+                np.array(adata.var_names, dtype=np.str_)
+            )
+        except ImportError:
+            raise RuntimeError(_anndata_msg_error)
+
+    @classmethod
+    def load_h5ad(cls, path: Union[str, Path]) -> Dataset:
+        try:
+            from anndata import read_h5ad
+            return cls.from_annData(read_h5ad(path))
+        except ImportError:
+            raise RuntimeError(_anndata_msg_error)
+
+    def as_annData(self):
+        try:
+            from anndata import AnnData
+            adata = AnnData(
+                self.count_matrix,
+                {'time_points': self.time_points}
+            )
+            adata.obs_names = np.array(
+                [f'Cell_{i+1}' for i in range(adata.n_obs)]
+            )
+            if self.gene_names is not None:
+                adata.var_names = self.gene_names
+            else:
+                adata.var_names = np.array(
+                    [f'Gene_{i}' for i in range(adata.n_vars)]
+                )
+
+            return adata
+        except ImportError:
+            raise RuntimeError(_anndata_msg_error)
+
+    def save_h5ad(self,
+        path: Union[str, Path],
+        compression: Optional[Literal['gzip', 'lzf']] = None,
+        compression_opts = None
+    ) -> Path:
+        path = Path(path)
+        adata = self.as_annData()
+        adata.write_h5ad(path, compression, compression_opts)
+        return path
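Taken together, the new methods give a full round trip between Dataset and AnnData. A hedged usage sketch (requires the anndata extra; the import path is assumed from the repo layout, and the file name, shapes, and values are illustrative):

import numpy as np
from harissa.core.dataset import Dataset

# Three cells, two genes; dtypes must satisfy Dataset's validation rules.
time_points = np.array([0.0, 1.0, 1.0], dtype=np.float64)
count_matrix = np.array([[1, 0], [2, 3], [0, 5]], dtype=np.uint)
dataset = Dataset(time_points, count_matrix)

adata = dataset.as_annData()                  # AnnData with obs['time_points']
path = dataset.save_h5ad('example.h5ad', compression='gzip')
restored = Dataset.load_h5ad(path)            # back to a Dataset
assert np.array_equal(restored.count_matrix, dataset.count_matrix)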