# Source code for linear_operator_learning.nn.modules.mlp

# TODO: Refactor the models, add docstrings, etc...
"""PyTorch Models."""

import torch
from torch.nn import Conv2d, Dropout, Linear, MaxPool2d, Module, ReLU, Sequential


class _MLPBlock(Module):
    """One hidden block of an MLP: a linear map, then dropout, then an activation."""

    def __init__(self, input_size, output_size, dropout=0.0, activation=ReLU, bias=True):
        """Create the three sub-modules of the block.

        Args:
            input_size (int): Number of input features of the linear layer.
            output_size (int): Number of output features of the linear layer.
            dropout (float): Probability handed to ``Dropout``. Defaults to 0.0.
            activation (callable): Called with no arguments to build the activation
                module. Defaults to ``ReLU``.
            bias (bool): Whether the linear layer has a bias term. Defaults to True.
        """
        super().__init__()
        # The attribute names are also the state-dict prefixes, so they stay as they are.
        self.linear = Linear(in_features=input_size, out_features=output_size, bias=bias)
        self.dropout = Dropout(p=dropout)
        self.activation = activation()

    def forward(self, x):
        """Return ``activation(dropout(linear(x)))``."""
        projected = self.linear(x)
        return self.activation(self.dropout(projected))


class MLP(Module):
    """Multi Layer Perceptron.

    The network is a stack of ``n_hidden`` hidden blocks (each one a linear layer followed
    by dropout and then the activation) and a final linear projection to ``output_shape``.
    The final projection never has a bias term, whatever the value of ``bias``. With
    ``n_hidden == 0`` the network is that single bias-free projection.

    Args:
        input_shape (int): Input shape of the MLP.
        n_hidden (int): Number of hidden layers. Must be non-negative.
        layer_size (int or list of ints): Number of neurons in each layer. If an int is
            provided, it is used as the number of neurons for all hidden layers. Otherwise,
            the list of int is used to define the number of neurons for each layer; it must
            hold at least ``n_hidden`` entries and only the first ``n_hidden`` are used.
        output_shape (int): Output shape of the MLP.
        dropout (float): Dropout probability between layers. Defaults to 0.0.
        activation (callable): Activation factory, called with no arguments once per hidden
            block to build its activation module. Defaults to ReLU.
        iterative_whitening (bool): Whether to add an IterNorm layer at the end of the
            network. Not implemented yet: passing True raises. Defaults to False.
        bias (bool): Whether to include bias in the hidden layers. The output projection is
            always created without bias. Defaults to False.

    Raises:
        NotImplementedError: If ``iterative_whitening`` is True.
        ValueError: If ``n_hidden`` is negative, or if ``layer_size`` is a sequence with
            fewer than ``n_hidden`` entries.
    """

    def __init__(
        self,
        input_shape,
        n_hidden,
        layer_size,
        output_shape,
        dropout=0.0,
        activation=ReLU,
        iterative_whitening=False,
        bias=False,
    ):
        super().__init__()

        if iterative_whitening:
            # Fail before any layer is allocated; IterNorm would be appended last.
            raise NotImplementedError("IterNorm isn't implemented")

        if n_hidden < 0:
            raise ValueError(f"n_hidden must be non-negative, got {n_hidden}")

        if isinstance(layer_size, int):
            hidden_sizes = [layer_size] * n_hidden
        else:
            hidden_sizes = list(layer_size)
            if len(hidden_sizes) < n_hidden:
                raise ValueError(
                    f"layer_size has {len(hidden_sizes)} entries "
                    f"but n_hidden={n_hidden} hidden layers were requested"
                )
            # Drop surplus entries so the output layer is sized from the last hidden
            # block that is actually built, not from the last element of the list.
            hidden_sizes = hidden_sizes[:n_hidden]

        # widths[i] -> widths[i + 1] is the i-th hidden block; widths[-1] feeds the output.
        widths = [input_shape, *hidden_sizes]
        layers = [
            _MLPBlock(fan_in, fan_out, dropout, activation, bias=bias)
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        layers.append(Linear(widths[-1], output_shape, bias=False))

        self.model = Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the hidden blocks and the output projection."""
        return self.model(x)