from typing import Callable, List, Optional
import torch
from torch_geometric.nn.norm import LayerNorm
class MLP(torch.nn.Sequential):
    """Multi-layer perceptron (MLP) implemented as a :class:`torch.nn.Sequential`.

    For every entry of ``hidden_channels`` a ``torch.nn.Linear`` layer is created. It is
    followed, in this order, by an optional normalization layer, an optional activation layer
    and a ``torch.nn.Dropout`` layer. For the *last* ``Linear`` layer, each of those three
    follow-up layers can be disabled individually.

    Note:
        Adapted from :class:`torchvision.ops.MLP`. Unlike that module:

        - the normalization, activation and dropout after the last ``Linear`` layer can each
          be switched off;
        - passing ``activation_layer=None`` omits activation layers entirely.

        A ``Dropout`` layer is still inserted when ``dropout == 0.0``; it is then a no-op.

    Args:
        in_channels (int): Number of channels of the input.
        hidden_channels (List[int]): Output dimension of each ``Linear`` layer, in order. The
            last entry is the output dimension of the MLP. Must be non-empty.
        norm_layer (Callable[..., torch.nn.Module], optional): Norm layer stacked on top of
            each linear layer, constructed as ``norm_layer(dim)``. If ``None`` no norm layers
            are used. Default: ``None``
        activation_layer (Callable[..., torch.nn.Module], optional): Activation function
            stacked on top of the normalization layer (if any), otherwise on top of the linear
            layer. If ``None`` no activation layers are used. Default: ``torch.nn.ReLU``
        inplace (bool, optional): Passed as ``inplace=`` to both the activation and the
            Dropout layers. Default is ``None``, which uses the respective default values of
            the ``activation_layer`` and Dropout layer.
        bias (bool): Whether to use bias in the linear layers. Default ``True``
        dropout (float): The probability for the dropout layers, in ``[0, 1)``. Default: 0.0
        disable_norm_last_layer (bool): If ``True``, no norm layer follows the last linear
            layer. Default: ``False``
        disable_activation_last_layer (bool): If ``True``, no activation layer follows the
            last linear layer. Default: ``False``
        disable_dropout_last_layer (bool): If ``True``, no dropout layer follows the last
            linear layer. Default: ``False``

    Raises:
        AssertionError: If ``dropout`` is not in ``[0, 1)``.
        ValueError: If ``hidden_channels`` is empty.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: List[int],
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        inplace: Optional[bool] = None,
        bias: bool = True,
        dropout: float = 0.0,
        disable_norm_last_layer: bool = False,
        disable_activation_last_layer: bool = False,
        disable_dropout_last_layer: bool = False,
    ):
        """Initializes the MLP module."""
        # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
        # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
        params = {} if inplace is None else {"inplace": inplace}
        assert 0 <= dropout < 1, f"Dropout probability must be in [0,1), got {dropout=}."
        if not hidden_channels:
            # Without this check the failure would be an opaque IndexError on hidden_channels[-1].
            raise ValueError("hidden_channels must contain at least one dimension.")

        layers: List[torch.nn.Module] = []
        in_dim = in_channels
        last_index = len(hidden_channels) - 1
        for index, hidden_dim in enumerate(hidden_channels):
            # The per-layer "disable" switches only apply to the final Linear layer.
            is_last = index == last_index
            layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
            if norm_layer is not None and not (is_last and disable_norm_last_layer):
                layers.append(norm_layer(hidden_dim))
            if activation_layer is not None and not (is_last and disable_activation_last_layer):
                layers.append(activation_layer(**params))
            # NOTE(review): Dropout is appended even when dropout == 0.0. Skipping it would shift
            # the Sequential indices (and therefore state_dict keys), so it is kept as-is.
            if not (is_last and disable_dropout_last_layer):
                layers.append(torch.nn.Dropout(dropout, **params))
            in_dim = hidden_dim
        super().__init__(*layers)

    def forward(self, x: torch.Tensor, batch: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Calculates the forward pass of the module by applying every layer in order.

        PyG :class:`~torch_geometric.nn.norm.LayerNorm` layers whose ``mode`` is ``"graph"``
        are called as ``layer(x, batch)``. Every other layer is called as ``layer(x)``.

        Args:
            x (torch.Tensor): The input tensor.
            batch (torch.Tensor, optional): Batch tensor forwarded to graph-mode ``LayerNorm``
                layers (presumably PyG's node-to-graph assignment vector). It is required only
                if such a layer is present.

        Returns:
            torch.Tensor: The output tensor.

        Raises:
            AssertionError: If a graph-mode ``LayerNorm`` is present and ``batch`` is ``None``.
        """
        for layer in self:
            if isinstance(layer, LayerNorm) and layer.mode == "graph":
                assert batch is not None, "Batch tensor must be provided for LayerNorm."
                x = layer(x, batch)
            else:
                x = layer(x)
        return x