Skip to content

Commit

Permalink
fix row_wise qr
Browse files Browse the repository at this point in the history
  • Loading branch information
Szkered committed Aug 14, 2023
1 parent 2149d95 commit ae7710c
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions d4ft/hamiltonian/ortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,28 @@ def qr_factor(
) -> Array:
"""Get a orthongal matrix parametrized with qr factor.
NOTE: QR decomposition is done column-wise, and is only unique up
to a column-wise phase shift. In the case of real coefficients, this
means given an input matrix :math:`A`, the output of QR decomposition
:math:`A=QR` is only unique up to a sign flip of each column of :math:`Q`.
In D4FT the convention for MO coefficients is that each row represents a MO,
so row_wise must be true.
Another case for setting row_wise is even when the input is non-square.
For non-square matrix of size (a,b) where a<b, QR returns orthogonal column
vectors of shape (a,a). To get row-wise orthogonality transposition is needed:
first transpose the input matrix to (b,a), then QR returns orthogonal columns
of size (b,a), which are rows in the original space.
Args:
batch_dim: if provided vmap over this dim
row_wise: if true return row-wise orthogonal matrix. For non-square matrix
of size (a,b) where a<b, QR returns orthogonal column vectors of shape (a,a).
To get row-wise orthogonality transposition is needed.
row_wise: if true return row-wise orthogonal matrix.
"""
qr_fn = lambda p: jnp.linalg.qr(p)[0]
if batch_dim:
qr_fn = jax.vmap(qr_fn, batch_dim, batch_dim)
orthogonal = qr_fn(params)
if row_wise:
transpose_axis = (0,) + tuple(range(1, len(params.shape)))
transpose_axis = (0,) + tuple(reversed(range(1, len(params.shape))))
orthogonal = jnp.transpose(orthogonal, transpose_axis)
return orthogonal

0 comments on commit ae7710c

Please sign in to comment.