Source code for mldft.ml.models.components.mlp

from typing import Callable, List, Optional

import torch
from torch_geometric.nn.norm import LayerNorm


class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP) built as a ``torch.nn.Sequential``.

    Every layer consists of ``Linear -> [norm] -> [activation] -> Dropout``.
    The norm, activation and dropout of the *last* layer can each be switched
    off individually.

    Note:
        Adapted from :class:`torchvision.ops.MLP`. The difference is that the
        last layer is built like the hidden layers by default, with flags to
        disable its norm, activation and dropout individually.
        Dropout modules are always added, even for ``dropout=0.0`` (where
        they act as the identity), so the positions of the linear layers
        inside the sequential container do not depend on the dropout
        probability.

    Args:
        in_channels (int): Number of channels of the input.
        hidden_channels (List[int]): List of the output dimensions of the
            linear layers. The last entry is the output dimension of the MLP.
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that
            will be stacked on top of the linear layer. It is called with the
            layer's output dimension. If ``None`` this layer won't be used.
            Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation
            function which will be stacked on top of the normalization layer
            (if not None), otherwise on top of the linear layer. If ``None``
            this layer won't be used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Parameter for the activation layer and the
            Dropout layer, which can optionally do the operation in-place.
            Default is ``None``, which uses the respective default values of
            the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layer. Default ``True``
        dropout (float): The probability for the dropout layer. Must be in
            ``[0, 1)``. Default: 0.0
        disable_norm_last_layer (bool): If ``True``, no norm layer is added
            after the last linear layer. Default: ``False``
        disable_activation_last_layer (bool): If ``True``, no activation is
            added after the last linear layer. Default: ``False``
        disable_dropout_last_layer (bool): If ``True``, no dropout is added
            after the last linear layer. Default: ``False``
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
        disable_norm_last_layer: bool = False,
        disable_activation_last_layer: bool = False,
        disable_dropout_last_layer: bool = False,
    ):
        """Initializes the MLP module."""
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        # Only forward `inplace` when it was given explicitly, so that the
        # activation and dropout modules keep their own defaults otherwise.
        params = {} if inplace is None else {"inplace": inplace}
        assert 0 <= dropout < 1, f"Dropout probability must be in [0,1), got {dropout=}."

        layers = []
        in_dim = in_channels
        for hidden_dim in hidden_channels[:-1]:
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None:
                layers.append(norm_layer(hidden_dim))
            # `activation_layer=None` is a documented way to disable the
            # activation; calling it unconditionally would raise a TypeError.
            if activation_layer is not None:
                layers.append(activation_layer(**params))
            layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim

        # Configure last layer
        out_dim = hidden_channels[-1]
        layers.append(torch.nn.Linear(in_dim, out_dim, bias=bias))
        if not disable_norm_last_layer and norm_layer is not None:
            layers.append(norm_layer(out_dim))
        if not disable_activation_last_layer and activation_layer is not None:
            layers.append(activation_layer(**params))
        if not disable_dropout_last_layer:
            layers.append(torch.nn.Dropout(dropout, **params))

        super().__init__(*layers)

    def forward(self, x: torch.Tensor, batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Calculates the forward pass of the module.

        Args:
            x (torch.Tensor): The input tensor.
            batch (torch.Tensor, optional): Batch tensor. It is only used by
                (and then required for) ``LayerNorm`` layers whose ``mode`` is
                ``"graph"``; all other layers are called with ``x`` alone.

        Returns:
            torch.Tensor: The output tensor.
        """
        for layer in self:
            # Graph-mode LayerNorm takes the batch tensor as a second argument,
            # which a plain Sequential forward would not pass along.
            if isinstance(layer, LayerNorm) and layer.mode == "graph":
                assert batch is not None, "Batch tensor must be provided for LayerNorm."
                x = layer(x, batch)
            else:
                x = layer(x)
        return x