Source code for linear_operator_learning.nn.stats

"""Statistics utilities for multi-variate random variables."""

from math import sqrt

import torch
from torch import Tensor

from linear_operator_learning.nn.linalg import filter_reduced_rank_svals, sqrtmh


[docs] def covariance( X: Tensor, Y: Tensor | None = None, center: bool = True, norm: float | None = None, ) -> Tensor: """Computes the covariance of X or cross-covariance between X and Y if Y is given. Args: X (Tensor): Input features. Y (Tensor | None, optional): Output features. Defaults to None. center (bool, optional): Whether to compute centered covariances. Defaults to True. norm (float | None, optional): Normalization factor. Defaults to None. Shape: ``X``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. ``Y``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features. Output: :math:`(D, D)`, where :math:`D` is the number of features. """ assert X.ndim == 2 if norm is None: norm = sqrt(X.shape[0]) else: assert norm > 0 norm = sqrt(norm) if Y is None: X = X / norm if center: X = X - X.mean(dim=0, keepdim=True) return torch.mm(X.T, X) else: assert Y.ndim == 2 X = X / norm Y = Y / norm if center: X = X - X.mean(dim=0, keepdim=True) Y = Y - Y.mean(dim=0, keepdim=True) return torch.mm(X.T, Y)
def cross_cov_norm_squared_unbiased(x: Tensor, y: Tensor, permutation=None):
    r"""Estimate :math:`\|\mathbf{C}_{xy}\|_F^2` from a batch of samples.

    Given the cross-covariance matrix
    :math:`\mathbf{C}_{xy} = \mathbb{E}_{p(x,y)} [x y^{\top}]`, the squared
    Frobenius norm can be written as an expectation over two independent draws
    :math:`(x, y)` and :math:`(x', y')`:

    .. math::

        \begin{align}
            \|\mathbf{C}_{xy}\|_F^2
            &= \text{tr}(\mathbf{C}_{xy}^{\top} \mathbf{C}_{xy})
            = \sum_i \sum_j (\mathbb{E}_{x,y \sim p(x,y)} [x_i y_j])
              (\mathbb{E}_{x',y' \sim p(x,y)} [x'_i y'_j]) \\
            &= \mathbb{E}_{(x,y),(x',y') \sim p(x,y)} [(x^{\top} x') (y'^{\top} y)] \\
            &\approx \frac{1}{N^2} \sum_n \sum_m
              [(x_{n}^{\top} x^{\prime}_m) (y^{\prime \top}_m y_n)]
        \end{align}

    .. note::
        The random variables are assumed to be centered.

    .. note::
        The double sum runs over every :math:`(n, m)` pair, including the pairs
        in which both factors come from the same sample. Those terms contribute
        an :math:`O(1/N)` bias; when ``permutation`` is a true permutation the
        sum is also invariant to it (up to floating-point summation order).

    Args:
        x (Tensor): Centered realizations of a random variable `x` of shape (N, D_x).
        y (Tensor): Centered realizations of a random variable `y` of shape (N, D_y).
        permutation (Tensor, optional): List of integer indices of shape (n_samples,)
            used to permute the samples. A random permutation is drawn when omitted.

    Returns:
        Tensor: Scalar estimate of :math:`\|\mathbf{C}_{xy}\|_F^2`.
    """
    n_samples = x.shape[0]

    # Permute the rows (jointly for x and y) to build the second sample set (x', y').
    perm = permutation if permutation is not None else torch.randperm(n_samples)
    assert perm.shape == (n_samples,), f"Invalid permutation {perm.shape}!=({n_samples},)"
    xp = x[perm]  # Second sampling of x
    yp = y[perm]  # Second sampling of y

    # Compute 1/N^2 Σ_n Σ_m [(x_n.T x'_m) (y'_m.T y_n)].
    # x must be contracted with x' (over D_x) and y with y' (over D_y): pairing x
    # with y' instead yields tr(C_xy^2), which differs from ||C_xy||_F^2 whenever
    # C_xy is not symmetric and is not even defined when D_x != D_y.
    val = torch.einsum("nj,mj,mk,nk->", x, xp, yp, y)
    cov_fro_norm = val / (n_samples**2)
    return cov_fro_norm
def cov_norm_squared_unbiased(x: Tensor, permutation=None):
    r"""Estimate :math:`\|\mathbf{C}_x\|_F^2` from a single batch of samples.

    With :math:`\mathbf{C}_x = \mathbb{E}_{p(x)} [x x^{\top}]`, the squared
    Frobenius norm is expressed through two independent draws :math:`x, x'`:

    .. math::

        \begin{align}
            \|\mathbf{C}_x\|_F^2 &= \text{tr}(\mathbf{C}_x^{\top} \mathbf{C}_x)
            = \sum_i \sum_j (\mathbb{E}_{x} [x_i x_j]) (\mathbb{E}_{x'} [x'_j x'_i]) \\
            &= \mathbb{E}_{x,x' \sim p(x)} [(x^{\top} x')^2] \\
            &\approx \frac{1}{N^2} \sum_n \sum_m [(x_n^{\top} x'_m)^2]
        \end{align}

    This is the special case of :func:`cross_cov_norm_squared_unbiased` in which
    both arguments are the same tensor; the call is forwarded there unchanged.

    .. note::
        The random variable is assumed to be centered.

    Args:
        x (Tensor): (n_samples, r) Centered realizations of a random variable
            x = [x_1, ..., x_r].
        permutation (Tensor, optional): List of integer indices of shape
            (n_samples,) used to permute the samples.

    Returns:
        Tensor: Estimate of :math:`\|\mathbf{C}_x\|_F^2`, as returned by
        :func:`cross_cov_norm_squared_unbiased`.
    """
    # The auto-covariance case is the cross-covariance of x with itself.
    return cross_cov_norm_squared_unbiased(x, x, permutation)
def whitening(u: Tensor, v: Tensor) -> tuple:
    """Compute whitening matrices for ``u`` and ``v`` and the SVD of their whitened cross-covariance.

    The (centered, batch-normalized) covariances of ``u`` and ``v`` are passed
    through ``sqrtmh`` and pseudo-inverted. The cross-covariance of ``u`` and
    ``v`` is multiplied by these two matrices on the left and right to form
    ``M``. Singular values and left singular vectors of ``M`` are obtained from
    the eigendecomposition of ``M @ M.T`` after ``filter_reduced_rank_svals``;
    right singular vectors are recovered as ``M.T @ left / sing_val``.

    Args:
        u (Tensor): Input features.
        v (Tensor): Output features.

    Returns:
        tuple: ``(sqrt_cov_u_inv, sqrt_cov_v_inv, sing_val, sing_vec_l, sing_vec_r)``.

    Shape:
        ``u``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

        ``v``: :math:`(N, D)`, where :math:`N` is the batch size and :math:`D` is the number of features.

        ``sqrt_cov_u_inv``: :math:`(D, D)`

        ``sqrt_cov_v_inv``: :math:`(D, D)`

        ``sing_val``: :math:`(D,)`

        ``sing_vec_l``: :math:`(D, D)`

        ``sing_vec_r``: :math:`(D, D)`

        NOTE(review): the singular value/vector shapes above assume
        ``filter_reduced_rank_svals`` keeps every component - confirm against
        its implementation.
    """
    # Pseudo-inverses of the covariance square roots; sqrtmh is presumably the
    # Hermitian matrix square root - confirm in linear_operator_learning.nn.linalg.
    inv_sqrt_u = torch.linalg.pinv(sqrtmh(covariance(u)))
    inv_sqrt_v = torch.linalg.pinv(sqrtmh(covariance(v)))

    # Cross-covariance expressed in the whitened coordinates of u and v.
    whitened_cross = inv_sqrt_u @ covariance(u, v) @ inv_sqrt_v

    # Eigenpairs of M M^T give squared singular values and left singular vectors.
    eigvals, left_vecs = torch.linalg.eigh(whitened_cross @ whitened_cross.T)
    eigvals, left_vecs = filter_reduced_rank_svals(eigvals, left_vecs)
    svals = torch.sqrt(eigvals)

    # Right singular vectors follow from M^T u_i = s_i v_i.
    right_vecs = (whitened_cross.T @ left_vecs) / svals
    return inv_sqrt_u, inv_sqrt_v, svals, left_vecs, right_vecs