Source code for linear_operator_learning.nn.modules.simnorm

"""Simplicial normalizarion."""

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


[docs] class SimNorm(nn.Module): """Simplicial normalization from :footcite:t:`lavoie2022simplicial`. Simplicial normalization splits the input into chunks of dimension :code:`dim`, applies a softmax transformation to each of the chunks separately, and concatenates them back together. Args: dim (int): Dimension of the simplicial groups. """ def __init__(self, dim: int = 8): super().__init__() self.dim = dim def forward(self, x: Tensor) -> Tensor: """Forward pass of the simplicial normalization module.""" shp = x.shape x = x.view(*shp[:-1], -1, self.dim) x = F.softmax(x, dim=-1) return x.view(*shp) def __repr__(self): """String representation of the simplicial norm module.""" return f"SimNorm(dim={self.dim})"