Source code for mldft.ml.models.components.g3d_layer

"""Graphormer layer with scalar distance features as attention bias.

Implements the G3D layer as described in [M-OFDFT]_, based on [Graphormer]_.
"""

import math
import warnings
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.norm.layer_norm import LayerNorm
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import softmax

from mldft.ml.models.components.mlp import MLP
from mldft.ml.models.components.node_embedding import smooth_falloff
from mldft.utils.log_utils.logging_mixin import LoggingMixin


class G3DLayer(MessagePassing, LoggingMixin):
    """The G3D layer as described in [M-OFDFT]_.

    ``out_channels`` are computed by dividing the input dimension by the number of heads.
    Based on :class:`~torch_geometric.nn.conv.transformer_conv.TransformerConv` as implemented
    in :mod:`torch_geometric.nn.conv.transformer_conv`.
    """

    # Needed for torch_geometric.compile
    propagate_type = {
        "query": Tensor,
        "key": Tensor,
        "value": Tensor,
        "edge_attr": Tensor,
        "length": Tensor,
    }

    def __init__(
        self,
        in_channels: int,
        heads: int = 32,
        edge_dim: int = 1,
        dropout: float = 0.0,
        attention_weight_dropout: float = 0.0,
        mlp_hidden_dim: Optional[int] = None,
        mlp_activation: torch.nn.Module = torch.nn.GELU,
        mlp_norm_layer: torch.nn.Module = None,
        norm_layer_class: torch.nn.Module = LayerNorm,
        activation_dropout: float = 0.0,
        cutoff: float | None = None,
        cutoff_start: float = 0.0,
        **kwargs,
    ):
        """Initialize the G3DLayer.

        Args:
            in_channels (int): Size of each input sample.
            heads (int, optional): Number of multi-head-attentions. (default: `32`)
            edge_dim (int): Edge feature dimensionality (in case there are any). Edge features
                are added to the attention weights before applying the soft(arg)max.
                (default: `1`)
            dropout (float, optional): Dropout probability of the MLP. Defaults to 0.0.
            attention_weight_dropout (float, optional): Dropout probability of the attention
                weights. Defaults to 0.0.
            mlp_hidden_dim (int, optional): Hidden dimensionality of the MLP. If None, defaults
                to `in_channels`.
            mlp_activation (torch.nn.Module, optional): Activation function of the MLP.
                Defaults to `torch.nn.GELU`.
            mlp_norm_layer (torch.nn.Module, optional): Normalization layer of the MLP.
                Defaults to None.
            norm_layer_class (torch.nn.Module, optional): Normalization layer class used before
                the attention and MLP blocks. Defaults to
                :class:`~torch_geometric.nn.norm.layer_norm.LayerNorm`.
            activation_dropout (float, optional): Dropout probability of the activation
                function. Defaults to 0.0.
            cutoff (float, optional): Cutoff distance of the smooth falloff applied to the
                attention weights. If None, no falloff is applied. Defaults to None.
            cutoff_start (float, optional): Distance at which the smooth falloff starts.
                Defaults to 0.0.
            **kwargs (optional): Additional arguments of
                `torch_geometric.nn.conv.MessagePassing`.

        Raises:
            ValueError: If the number of heads does not divide the number of input channels.
        """
        kwargs.setdefault("aggr", "add")
        super().__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.heads = heads
        if in_channels % heads != 0:
            raise ValueError("Number of heads must divide in_channels.")
        self.channels_per_head = in_channels // heads
        self.edge_dim = edge_dim
        self.mlp_hidden_dim = mlp_hidden_dim if mlp_hidden_dim is not None else in_channels
        self.cutoff = cutoff
        self.cutoff_start = cutoff_start

        # heads * channels_per_head = in_channels
        # Use one large linear layer to compute query, key and value more efficiently
        self.linear_in = Linear(in_channels, 3 * in_channels)
        self.linear_out = Linear(in_channels, in_channels)

        self.norm_1 = norm_layer_class(self.in_channels)
        self.norm_2 = norm_layer_class(self.in_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.attention_weight_dropout = torch.nn.Dropout(attention_weight_dropout)

        # Two-layer MLP, to hidden_dim and back to in_channels
        self.mlp = MLP(
            in_channels=self.in_channels,
            hidden_channels=[self.mlp_hidden_dim, self.in_channels],
            activation_layer=mlp_activation,
            dropout=activation_dropout,
            disable_dropout_last_layer=True,
            disable_activation_last_layer=True,
            norm_layer=mlp_norm_layer,
        )

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize the linear layers (query, key, value, and output linear layers)
        nn.init.xavier_uniform_(self.linear_in.weight)
        nn.init.constant_(self.linear_in.bias, 0)
        nn.init.xavier_uniform_(self.linear_out.weight)
        nn.init.constant_(self.linear_out.bias, 0)
        # Initialize the MLP, if it has learnable parameters
        if hasattr(self.mlp, "reset_parameters"):
            self.mlp.reset_parameters()

    def forward(
        self, x: Tensor, edge_index: Adj, batch: Tensor, edge_attr: OptTensor = None, length=None
    ) -> Tensor:
        """Runs the forward pass of the module.

        The forward pass is defined as:

            x = MHAtt(LN(input), edge_attr) + input
            output = x + MLP(LN(x))

        where ``edge_attr`` enters as a bias on the attention weights inside the multi-head
        attention.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): Batch assigning each node to a specific graph.
            edge_attr (torch.Tensor, optional): The edge features. (default: `None`)
            length (torch.Tensor, optional): The edge lengths, used for the smooth falloff if
                ``cutoff`` is set. (default: `None`)

        Returns:
            torch.Tensor: The output node features.
        """
        if edge_attr is None:
            warnings.warn("Edge_attr is None. This is not recommended.")
            edge_attr = torch.zeros((edge_index.shape[1], self.edge_dim), device=x.device)

        # Attention block
        if isinstance(self.norm_1, LayerNorm):
            out_normalized = self.norm_1(x, batch)
        else:
            out_normalized = self.norm_1(x)

        # Chunk the linear layer output into query, key and value
        query, key, value = self.linear_in(out_normalized).chunk(3, dim=-1)
        query = query.contiguous().view(-1, self.heads, self.channels_per_head)
        key = key.contiguous().view(-1, self.heads, self.channels_per_head)
        value = value.contiguous().view(-1, self.heads, self.channels_per_head)

        attention_output = self.propagate(
            edge_index,
            query=query,
            key=key,
            value=value,
            edge_attr=edge_attr,
            size=None,
            length=length,
        )
        attention_output = attention_output.view(-1, self.heads * self.channels_per_head)
        attention_output = self.linear_out(attention_output)
        attention_output = self.dropout(attention_output)
        out_skip_links = attention_output + x

        # MLP block
        if isinstance(self.norm_2, LayerNorm):
            out_normalized = self.norm_2(out_skip_links, batch)
        else:
            out_normalized = self.norm_2(out_skip_links)
        out_mlp = self.mlp(out_normalized, batch)
        out_mlp = self.dropout(out_mlp)
        self.log(out_mlp=out_mlp, out_skip_links=out_skip_links)

        out = out_mlp + out_skip_links
        return out

    def message(
        self,
        query_i: Tensor,
        key_j: Tensor,
        value_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
        length: OptTensor = None,
    ) -> Tensor:
        """Message function of the G3D layer.

        Computes the attention weight of each edge, with the corresponding ``edge_attr`` added
        as a bias, and weights the values accordingly. If ``cutoff`` is set, the attention
        weights are scaled by a smooth falloff of the edge length.

        Args:
            query_i: Query tensor per edge of shape (E, heads, channels_per_head).
            key_j: Key tensor per edge of shape (E, heads, channels_per_head).
            value_j: Value tensor per edge of shape (E, heads, channels_per_head).
            edge_attr: Edge features.
            index: The indices of the edges' target nodes (where each edge ends).
            ptr: Pointer indicating where each graph in a batch starts and ends.
            size_i: The dimension over which the softmax normalizes.
            length: The lengths of the edges, ordered as in the edge index.

        Returns:
            torch.Tensor: The attention-weighted values of shape (E, heads, channels_per_head).
        """
        alpha = self.compute_attention(query_i, key_j, edge_attr, index, ptr, size_i)
        if self.cutoff is not None:
            alpha = alpha * smooth_falloff(length, self.cutoff, self.cutoff_start)
        alpha = self.attention_weight_dropout(alpha)

        out = value_j * alpha.view(-1, self.heads, 1)
        return out

    def compute_attention(
        self,
        query_i: Tensor,
        key_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
    ):
        """Compute the attention weights."""
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.channels_per_head) + edge_attr
        alpha = softmax(alpha, index, ptr, size_i)
        return alpha

    def __setstate__(self, state: dict) -> None:
        """This method is called during unpickling.

        If 'cutoff' or 'cutoff_start' is missing (as would be the case with an older
        checkpoint), it is added with its default value.
        """
        # Update the state dictionary first
        self.__dict__.update(state)
        # If the new attributes are missing, add them with their default values.
        if not hasattr(self, "cutoff"):
            self.cutoff = None
        if not hasattr(self, "cutoff_start"):
            self.cutoff_start = 0.0

    def __repr__(self) -> str:
        """Representation of the G3D layer."""
        return (
            f"{self.__class__.__name__}({self.in_channels}, "
            f"heads={self.heads}, edge_feature_dim={self.edge_dim})"
        )
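

# Illustrative usage sketch: the layer width, head count and the fully connected toy graph
# below are arbitrary choices for demonstration only and are not prescribed by the layer.
def _example_g3d_layer_usage() -> Tensor:
    """Run a G3DLayer on a small, fully connected toy graph (illustrative only)."""
    layer = G3DLayer(in_channels=64, heads=8, edge_dim=1)
    num_nodes = 4
    x = torch.randn(num_nodes, 64)  # node features, one row per node
    # Fully connected edge index without self-loops, shape (2, E)
    edge_index = torch.tensor(
        [[i, j] for i in range(num_nodes) for j in range(num_nodes) if i != j]
    ).t()
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes belong to a single graph
    edge_attr = torch.randn(edge_index.shape[1], 1)  # one scalar distance feature per edge
    # With the default cutoff=None no edge lengths are needed; the output has the same
    # shape as the input node features: (num_nodes, 64)
    return layer(x, edge_index, batch, edge_attr)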


class G3DLayerMul(G3DLayer):
    """G3D layer with multiplicative attention bias."""

    def compute_attention(
        self,
        query_i: Tensor,
        key_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
    ):
        """Compute the attention weights."""
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.channels_per_head) * edge_attr
        alpha = softmax(alpha, index, ptr, size_i)
        return alpha


class G3DLayerSilu(G3DLayer):
    """G3D layer with SiLU activation function."""

    def compute_attention(
        self,
        query_i: Tensor,
        key_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
    ):
        """Compute the attention weights."""
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.channels_per_head) + edge_attr
        alpha = torch.nn.functional.silu(alpha)
        return alpha


class G3DLayerMulSilu(G3DLayer):
    """G3D layer with multiplicative attention bias and SiLU activation function."""

    def compute_attention(
        self,
        query_i: Tensor,
        key_j: Tensor,
        edge_attr: Tensor,
        index: Tensor,
        ptr: OptTensor,
        size_i: Optional[int],
    ):
        """Compute the attention weights."""
        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.channels_per_head) * edge_attr
        alpha = torch.nn.functional.silu(alpha)
        return alpha
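

# Illustrative comparison sketch: contrasts how the four layer variants turn the same scaled
# dot-product scores and edge bias into attention weights. The tensor sizes and the helper
# function name below are arbitrary choices for demonstration only.
def _example_attention_variants() -> dict[str, Tensor]:
    """Compute the per-edge attention weights of each variant on the same toy inputs."""
    heads, channels_per_head = 4, 8
    in_channels = heads * channels_per_head
    num_nodes, num_edges = 3, 6
    query_i = torch.randn(num_edges, heads, channels_per_head)
    key_j = torch.randn(num_edges, heads, channels_per_head)
    edge_attr = torch.randn(num_edges, 1)  # scalar attention bias per edge
    index = torch.tensor([0, 0, 1, 1, 2, 2])  # target node of each edge
    weights = {}
    for cls in (G3DLayer, G3DLayerMul, G3DLayerSilu, G3DLayerMulSilu):
        layer = cls(in_channels=in_channels, heads=heads, edge_dim=1)
        # Additive variants add edge_attr to the scores, multiplicative ones scale them;
        # the "Silu" variants replace the softmax normalization with a SiLU activation.
        weights[cls.__name__] = layer.compute_attention(
            query_i, key_j, edge_attr, index, ptr=None, size_i=num_nodes
        )
    return weights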