Skip to content

Commit

Permalink
fix dimensionality init
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Jul 11, 2024
1 parent df980d2 commit c7c1290
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# v1.4.4rc

- remove `numpy` as direct dependency
- bugfix DimensionalityEstimator dimensionality initialization

# v1.4.3

Expand Down
3 changes: 3 additions & 0 deletions mellon/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ndarray,
ones,
zeros,
full,
)
from jax.numpy import sum as arraysum
from jax.numpy import any as arrayany
Expand Down Expand Up @@ -846,6 +847,8 @@ def compute_initial_dimensionalities(x, mu_dim, mu_dens, L, nn_distances, d):
:rtype: array-like
"""
target = log(d) - mu_dim
if asarray(target).size == 1:
target = full(L.shape[0], target)
initial_dims = Ridge(fit_intercept=False).fit(L, target).coef_
initial_dens = compute_initial_value(nn_distances, d, mu_dens, L)
initial_value = stack([initial_dims, initial_dens])
Expand Down

0 comments on commit c7c1290

Please sign in to comment.