Source code for linear_operator_learning.nn.linalg

"""Linear Algebra."""

from typing import NamedTuple

import torch
from torch import Tensor


[docs] def sqrtmh(A: Tensor) -> Tensor: """Compute the square root of a Symmetric or Hermitian positive definite matrix or batch of matrices. Used code from `this issue <https://github.com/pytorch/pytorch/issues/25481#issuecomment-1032789228>`_. Args: A (Tensor): Symmetric or Hermitian positive definite matrix or batch of matrices. Shape: ``A``: :math:`(N, N)` Output: :math:`(N, N)` """ L, Q = torch.linalg.eigh(A) zero = torch.zeros((), device=L.device, dtype=L.dtype) threshold = L.max(-1).values * L.size(-1) * torch.finfo(L.dtype).eps L = L.where(L > threshold.unsqueeze(-1), zero) # zero out small components return (Q * L.sqrt().unsqueeze(-2)) @ Q.mH
#################################################################################################### # TODO: THIS IS JUST COPY AND PASTE FROM OLD NCP # Should topk and filter_reduced_rank_svals be in utils? They look like linalg to me, specially the # filter #################################################################################################### # Sorting and parsing class TopKReturnType(NamedTuple): # noqa: D101 values: Tensor indices: Tensor def topk(vec: Tensor, k: int): # noqa: D103 assert vec.ndim == 1, "'vec' must be a 1D array" assert k > 0, "k should be greater than 0" sort_perm = torch.flip(torch.argsort(vec), dims=[0]) # descending order indices = sort_perm[:k] values = vec[indices] return TopKReturnType(values, indices) def filter_reduced_rank_svals(values, vectors): # noqa: D103 eps = 2 * torch.finfo(torch.get_default_dtype()).eps # Filtering procedure. # Create a mask which is True when the real part of the eigenvalue is negative or the imaginary part is nonzero is_invalid = torch.logical_or( torch.real(values) <= eps, torch.imag(values) != 0 if torch.is_complex(values) else torch.zeros(len(values), device=values.device), ) # Check if any is invalid take the first occurrence of a True value in the mask and filter everything after that if torch.any(is_invalid): values = values[~is_invalid].real vectors = vectors[:, ~is_invalid] sort_perm = topk(values, len(values)).indices values = values[sort_perm] vectors = vectors[:, sort_perm] # Assert that the eigenvectors do not have any imaginary part assert torch.all( torch.imag(vectors) == 0 if torch.is_complex(values) else torch.ones(len(values)) ), "The eigenvectors should be real. Decrease the rank or increase the regularization strength." # Take the real part of the eigenvectors vectors = torch.real(vectors) values = torch.real(values) return values, vectors