Skip to content

Commit

Permalink
revert 8435574
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Oct 17, 2024
1 parent 8435574 commit 1499b0b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 32 deletions.
8 changes: 3 additions & 5 deletions mellon/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jax.numpy import diag as diagonal
from jax.numpy.linalg import cholesky
from jax.scipy.linalg import solve_triangular
from .util import ensure_2d, stabilize, DEFAULT_JITTER, add_variance, add_projected_variance
from .util import ensure_2d, stabilize, DEFAULT_JITTER, add_variance
from .base_predictor import Predictor, ExpPredictor, PredictorTime
from .decomposition import DEFAULT_SIGMA

Expand Down Expand Up @@ -278,9 +278,9 @@ def __init__(
LLB = stabilize(LLB, jitter)
else:
logger.debug("Assuming y is not the mean of the GP.")
y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, x.shape[0])
y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, xu.shape[0])
sigma = None
LLB = add_projected_variance(LLB, A, y_cov_factor, jitter=jitter)
LLB = add_variance(LLB, y_cov_factor, jitter=jitter)

L_B = cholesky(LLB)
r = y - mu
Expand All @@ -304,8 +304,6 @@ def __init__(
self.L = L
self._state_variables.add("L")

y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, xu.shape[0])

C = solve_triangular(L_B, dot(A, y_cov_factor), lower=True)
Z = solve_triangular(L_B.T, C)
W = solve_triangular(L.T, Z)
Expand Down
27 changes: 0 additions & 27 deletions mellon/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,33 +315,6 @@ def add_variance(K, M=None, jitter=DEFAULT_JITTER):
return K


def add_projected_variance(K, A, y_cov_factor, jitter=DEFAULT_JITTER):
"""
Adds the projected observation noise covariance to K and stabilizes it.
Parameters
----------
K : array_like, shape (n_landmarks, n_landmarks)
The initial covariance matrix.
A : array_like, shape (n_landmarks, n_obs)
The projection matrix from observations to inducing points.
y_cov_factor : array_like, shape (n_obs, n_obs)
The observation noise covariance matrix.
jitter : float, optional
A small number to stabilize the covariance matrix. Defaults to 1e-6.
Returns
-------
stabilized_K : array_like, shape (n_landmarks, n_landmarks)
The stabilized covariance matrix with added projected variance.
"""
noise = A @ y_cov_factor @ A.T
noise_diag = diagonal(noise)
diff = where(noise_diag < jitter, jitter - noise_diag, 0)
K = K + noise + diagonal(diff)
return K


def mle(nn_distances, d):
R"""
Nearest Neighbor distribution maximum likelihood estimate for log density
Expand Down

0 comments on commit 1499b0b

Please sign in to comment.