Source code for mldft.ml.models.components.toy_net

"""Toy architecture for testing purposes."""

import torch
from torch import Tensor, nn
from torch_geometric.nn import MessagePassing, global_add_pool, knn_graph
from torchvision.ops import MLP

from mldft.ml.data.components.basis_info import BasisInfo


class ToyNet(nn.Module):
    """Toy net module for testing purposes."""
    def __init__(
        self,
        basis_info: BasisInfo,
        k_neighbors: int,
        hidden_dimension_energy: tuple[int, ...] = (128, 128, 128),
        hidden_dimension_coeffs: tuple[int, ...] = (128, 128, 128),
    ) -> None:
        """Initializes the ToyNet module.

        Args:
            basis_info (BasisInfo): BasisInfo object containing information about the basis
                for the dataset.
            k_neighbors (int): Number of neighbors to consider in the graph creation.
            hidden_dimension_energy (tuple[int, ...], optional): Dimensions of the hidden
                layers in the MLP calculating the energy. Defaults to (128, 128, 128).
            hidden_dimension_coeffs (tuple[int, ...], optional): Dimensions of the hidden
                layers in the MLP calculating the coefficients. Defaults to (128, 128, 128).
        """
        super().__init__()
        # calculate the maximum number of basis functions on a single atom
        max_number_basis = basis_info.basis_dim_per_atom.max()
        self.max_number_basis = max_number_basis
        self.k_neighbors = k_neighbors
        # set up the two sub-models: one for the initial-guess delta, one for the energy
        self.ToyInitialGuessDeltaModel = ToyInitialGuessDelta(
            max_number_basis, list(hidden_dimension_coeffs)
        )
        self.ToyEnergyModel = ToyEnergy(max_number_basis, list(hidden_dimension_energy))
    def forward(self, batch) -> tuple[Tensor, Tensor]:
        """Calculates the forward pass of the ToyNet module.

        The energy and the predicted delta to the ground-state coefficients (used for the
        learned initial guess in OF-DFT) are calculated independently.

        Args:
            batch (Batch): Batch object containing the data.

        Returns:
            tuple[Tensor, Tensor]: The predicted energy and the predicted coefficient delta
                (in that order).
        """
        # calculate the edge_index as a knn graph based on the positions of the nodes,
        # and add self loops (see the illustrative sketch after this class)
        edge_index = knn_graph(x=batch.pos, k=self.k_neighbors, batch=batch.batch, loop=True)
        # split the coeffs by atom; this is needed to be able to use the message passing
        # framework of pytorch geometric
        coeffs_list = list(batch.split_field_by_atom(batch.coeffs))
        # zero pad the coefficients to the maximum number of basis functions per atom
        for i in range(len(coeffs_list)):
            coeffs_list[i] = torch.nn.functional.pad(
                coeffs_list[i], (0, self.max_number_basis - coeffs_list[i].shape[0])
            )
        padded_coeffs = torch.stack(coeffs_list, dim=0)
        # calculate the forward pass of the energy and the coefficient differences
        initial_guess_delta = self.ToyInitialGuessDeltaModel.forward(
            batch, edge_index, padded_coeffs
        )
        energy = self.ToyEnergyModel.forward(batch, edge_index, padded_coeffs)
        return energy, initial_guess_delta
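# --------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal, self-contained
# example of the two preprocessing steps in ``ToyNet.forward``: building the k-NN edge
# index with self loops and zero-padding the per-atom coefficient vectors. All sizes
# below (4 atoms in one molecule, k=2, basis dimensions 3/5/2/5) are hypothetical.
def _sketch_preprocessing() -> tuple[Tensor, Tensor]:
    pos = torch.randn(4, 3)  # hypothetical atom positions
    batch_index = torch.zeros(4, dtype=torch.long)  # all atoms belong to molecule 0
    edge_index = knn_graph(x=pos, k=2, batch=batch_index, loop=True)

    max_number_basis = 5
    coeffs_list = [torch.randn(3), torch.randn(5), torch.randn(2), torch.randn(5)]
    padded = [
        torch.nn.functional.pad(c, (0, max_number_basis - c.shape[0])) for c in coeffs_list
    ]
    padded_coeffs = torch.stack(padded, dim=0)  # shape: (num_atoms, max_number_basis)
    return edge_index, padded_coeffs
# --------------------------------------------------------------------------------------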
class ToyInitialGuessDelta(MessagePassing):
    """MessagePassing module for calculating the initial_guess_delta."""
    def __init__(self, max_number_basis: int, hidden_dimensions: list[int]) -> None:
        """Initializes the ToyInitialGuessDelta module.

        Args:
            max_number_basis (int): Maximal number of basis functions per atom.
            hidden_dimensions (list[int]): Dimensions of the hidden layers in the MLP.
        """
        super().__init__()
        self.max_number_basis = max_number_basis
        # append the output layer to the hidden layers
        hidden_dimensions.append(max_number_basis)
        # need to add 3 to the input dimension for the positions
        self.mlp = MLP(max_number_basis + 3, hidden_dimensions)
    def forward(self, batch, edge_index: Tensor, padded_coeffs: Tensor) -> Tensor:
        """Calculates the forward pass of the ToyInitialGuessDelta module.

        Args:
            batch (Batch): Batch object containing the data.
            edge_index (Tensor): Edge index of the graph.
            padded_coeffs (Tensor): Zero-padded coefficients.

        Returns:
            Tensor: Coefficient initial_guess_delta.
        """
        # propagate the messages through the graph
        output = self.propagate(edge_index, pos=batch.pos, coeffs=padded_coeffs)
        # slice the tensor to the size of the coeffs (remove what was added by the zero padding)
        tmp_list = []
        for i in range(batch.pos.shape[0]):
            node_coeffs = output[i, : batch.n_basis_per_atom[i].int()]
            tmp_list.append(node_coeffs)
        # and concatenate the list of tensors to a single tensor
        out = torch.cat(tmp_list, dim=0)
        return out
    def message(self, pos_i: Tensor, pos_j: Tensor, coeffs_j: Tensor) -> Tensor:
        """Message function to calculate the message from node j to node i.

        Args:
            pos_i (Tensor): Position of the receiving node.
            pos_j (Tensor): Position of the sending node.
            coeffs_j (Tensor): Coefficients of the sending node.

        Returns:
            Tensor: Message.
        """
        # concatenate the difference of positions and the coefficients
        concatenated = torch.cat((pos_j - pos_i, coeffs_j), dim=-1)
        # and pass it through the mlp
        message = self.mlp(concatenated)
        return message
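# --------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what a single batch of messages
# looks like. ``torchvision.ops.MLP`` takes the input width and a list of layer widths,
# so appending ``max_number_basis`` to the hidden dimensions makes the final layer emit
# one value per basis function. The sizes (max_number_basis=4, hidden layers [8, 8],
# 10 edges) are hypothetical.
def _sketch_message() -> Tensor:
    max_number_basis = 4
    mlp = MLP(max_number_basis + 3, [8, 8, max_number_basis])  # +3 for relative positions
    pos_i = torch.randn(10, 3)  # receiving-node positions, one row per edge
    pos_j = torch.randn(10, 3)  # sending-node positions, one row per edge
    coeffs_j = torch.randn(10, max_number_basis)  # padded sender coefficients
    concatenated = torch.cat((pos_j - pos_i, coeffs_j), dim=-1)
    return mlp(concatenated)  # shape: (10, max_number_basis)
# --------------------------------------------------------------------------------------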
class ToyEnergy(MessagePassing):
    """Toy energy module for testing purposes."""
    def __init__(self, max_number_basis: int, hidden_dimensions: list[int]) -> None:
        """Initializes the ToyEnergy module.

        Args:
            max_number_basis (int): Maximal number of basis functions per atom.
            hidden_dimensions (list[int]): Dimensions of the hidden layers in the MLP.
        """
        super().__init__()
        self.max_number_basis = max_number_basis
        # append the output layer onto the hidden layers
        hidden_dimensions.append(1)
        # need to add 3 to the input dimension for the positions
        self.mlp = MLP(max_number_basis + 3, hidden_dimensions)
    def forward(self, batch, edge_index: Tensor, padded_coeffs: Tensor) -> Tensor:
        """Calculates the forward pass of the ToyEnergy module.

        Args:
            batch (Batch): Batch object containing the data.
            edge_index (Tensor): Edge index of the graph.
            padded_coeffs (Tensor): Zero-padded coefficients.

        Returns:
            Tensor: Predicted energy.
        """
        # propagate the messages through the graph
        output = self.propagate(edge_index, pos=batch.pos, coeffs=padded_coeffs)
        # for each molecule in the batch, sum over all nodes to get the energy
        # (see the illustrative sketch after this class)
        pooled_output = global_add_pool(output, batch.batch)
        # squeeze the last dimension because of how the labels are defined in the dataset
        return pooled_output.squeeze(dim=-1)
    def message(self, pos_i: Tensor, pos_j: Tensor, coeffs_j: Tensor) -> Tensor:
        """Message function to calculate the message from node j to node i.

        Args:
            pos_i (Tensor): Position of the receiving node.
            pos_j (Tensor): Position of the sending node.
            coeffs_j (Tensor): Coefficients of the sending node.

        Returns:
            Tensor: Message.
        """
        # concatenate the difference of positions and the coefficients
        concatenated = torch.cat((pos_j - pos_i, coeffs_j), dim=-1)
        # pass it through the mlp
        message = self.mlp(concatenated)
        return message
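# --------------------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how ``global_add_pool`` turns
# per-node outputs into per-molecule energies in ``ToyEnergy.forward``. The batch vector
# below assigns the first three nodes to molecule 0 and the last two to molecule 1; all
# numbers are hypothetical.
def _sketch_energy_pooling() -> Tensor:
    per_node_energy = torch.randn(5, 1)  # one scalar per node, as produced by the MLP
    batch_index = torch.tensor([0, 0, 0, 1, 1])
    pooled = global_add_pool(per_node_energy, batch_index)  # shape: (2, 1)
    return pooled.squeeze(dim=-1)  # shape: (2,), one energy per molecule
# --------------------------------------------------------------------------------------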