Source code for deeprank2.neuralnets.gnn.alignmentnet

import torch
from torch import nn

# ruff: noqa: ANN001, ANN201

# Module-level authorship metadata.
__author__ = "Daniel-Tobias Rademaker"


class GNNLayer(nn.Module):
    """One message-passing layer of a graph neural network.

    The layer owns an edge MLP that turns (edge attributes, source node,
    target node) triples into message vectors, and a node MLP that turns the
    per-node summed messages into a residual update of the node features.
    When ``is_last_layer`` is true it additionally owns an output MLP and an
    attention MLP (ending in a sigmoid) used by :meth:`output`.

    Args:
        nmb_edge_projection: Number of features in the edge projection.
        nmb_hidden_attr: Number of features in the hidden attributes.
        nmb_output_features: Number of output features.
        message_vector_length: Length of the message vector.
        nmb_mlp_neurons: Number of neurons in the MLP.
        act_fn: Activation function. Defaults to nn.SiLU().
        is_last_layer: Whether this is the last layer of the GNN. Defaults to True.
    """

    def __init__(
        self,
        nmb_edge_projection,
        nmb_hidden_attr,
        nmb_output_features,
        message_vector_length,
        nmb_mlp_neurons,
        act_fn=nn.SiLU(),  # noqa: B008
        is_last_layer=True,
    ):
        super().__init__()

        # Widths of the two concatenated inputs built in edge_model / node_model.
        pair_width = nmb_edge_projection + nmb_hidden_attr * 2
        update_width = message_vector_length + nmb_hidden_attr

        # Edge MLP: maps an atom pair (plus its edge features) to a message.
        edge_stack = [
            nn.Linear(pair_width, nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, message_vector_length),
            act_fn,
        ]
        self.edge_mlp = nn.Sequential(*edge_stack)

        # Node MLP: maps (summed message, current features) to a feature update.
        node_stack = [
            nn.BatchNorm1d(update_width),
            nn.Linear(update_width, nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, nmb_hidden_attr),
        ]
        self.node_mlp = nn.Sequential(*node_stack)

        # The attention and output heads exist on the last layer only.
        if not is_last_layer:
            return

        # Attention head: one value in (0, 1) per node, weighting its output.
        attention_stack = [
            nn.Linear(nmb_hidden_attr, nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, 1),
            nn.Sigmoid(),
        ]
        self.attention_mlp = nn.Sequential(*attention_stack)

        # Output head: the per-node output vector.
        output_stack = [
            nn.Linear(nmb_hidden_attr, nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, nmb_output_features),
        ]
        self.output_mlp = nn.Sequential(*output_stack)

    def edge_model(self, edge_attr, hidden_features_source, hidden_features_target):
        """Compute one message vector per edge.

        The edge attributes and the hidden features of both end points are
        concatenated along dimension 1 and passed through the edge MLP.
        """
        pair_input = torch.cat(
            (edge_attr, hidden_features_source, hidden_features_target),
            dim=1,
        )
        return self.edge_mlp(pair_input)

    def node_model(self, summed_edge_message, hidden_features):
        """Update the node features from already-summed messages.

        The node MLP output is added to ``hidden_features`` (residual update).
        """
        node_input = torch.cat((summed_edge_message, hidden_features), dim=1)
        update = self.node_mlp(node_input)
        return hidden_features + update

    def sum_messages(self, edges, messages, nmb_nodes):
        """Sum the per-edge messages into a single message vector per node.

        Each message is accumulated into the row selected by the first index
        of its edge. Nodes that receive no message keep an all-zero vector.
        """
        row, _ = edges
        width = messages.size(1)
        # Same dtype/device as `messages`, one row per node.
        totals = messages.new_zeros((nmb_nodes, width))
        # scatter_add_ needs an index tensor shaped like the source rows.
        target_rows = row.unsqueeze(-1).expand(-1, width)
        totals.scatter_add_(0, target_rows, messages)
        return totals

    def update_nodes(self, edges, edge_attr, hidden_features, steps=1):
        """Run message passing and return the updated node features.

        Args:
            edges: Pair ``(row, col)`` of node indices; one edge is the index
                of atom 1 together with the index of atom 2.
            edge_attr: Edge features fed to the edge MLP.
            hidden_features: Current node features.
            steps: Number of times information is exchanged with neighbours
                by re-applying this same layer.
        """
        row, col = edges
        current = hidden_features
        for _ in range(steps):
            # Messages for every atom pair ...
            pair_messages = self.edge_model(edge_attr, current[row], current[col])
            # ... summed into one message vector per node ...
            node_messages = self.sum_messages(edges, pair_messages, current.shape[0])
            # ... and used to update the node attributes.
            current = self.node_model(node_messages, current)
        return current

    def output(self, hidden_features, get_attention=True):
        """Return the per-node prediction, optionally with its attention.

        Relies on the output and attention heads, which are only created when
        the layer was built with ``is_last_layer=True``.
        """
        prediction = self.output_mlp(hidden_features)
        if not get_attention:
            return prediction
        return prediction, self.attention_mlp(hidden_features)
class SuperGNN(nn.Module):
    """SuperGNN is a class that defines multiple GNN layers.

    In particular, the `preproc_edge_mlp` and `preproc_node_mlp` are meant to
    preprocess the edge and node attributes, respectively. The `modlist` is a
    list of GNNLayer objects; only the last one is built with
    ``is_last_layer=True``.

    Args:
        nmb_edge_attr: Number of edge features.
        nmb_node_attr: Number of node features.
        nmb_hidden_attr: Number of hidden features.
        nmb_mlp_neurons: Number of neurons in the MLP.
        nmb_edge_projection: Number of edge projections.
        nmb_gnn_layers: Number of GNN layers.
        nmb_output_features: Number of output features.
        message_vector_length: Length of the message vector.
        act_fn: Activation function, used by the preprocessing MLPs and
            forwarded to every GNN layer. Defaults to nn.SiLU().
    """

    def __init__(
        self,
        nmb_edge_attr,
        nmb_node_attr,
        nmb_hidden_attr,
        nmb_mlp_neurons,
        nmb_edge_projection,
        nmb_gnn_layers,
        nmb_output_features,
        message_vector_length,
        act_fn=nn.SiLU(),  # noqa: B008
    ):
        super().__init__()

        # Since edge attributes go into every layer, it might be better to
        # learn a better/smarter representation of them first.
        self.preproc_edge_mlp = nn.Sequential(
            nn.BatchNorm1d(nmb_edge_attr),
            nn.Linear(nmb_edge_attr, nmb_mlp_neurons),
            nn.BatchNorm1d(nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, nmb_edge_projection),
            act_fn,
        )

        # Project the node attributes to the same size as the hidden vector.
        self.preproc_node_mlp = nn.Sequential(
            nn.BatchNorm1d(nmb_node_attr),
            nn.Linear(nmb_node_attr, nmb_mlp_neurons),
            nn.BatchNorm1d(nmb_mlp_neurons),
            act_fn,
            nn.Linear(nmb_mlp_neurons, nmb_hidden_attr),
            act_fn,
        )

        last_index = nmb_gnn_layers - 1
        self.modlist = nn.ModuleList(
            [
                GNNLayer(
                    nmb_edge_projection,
                    nmb_hidden_attr,
                    nmb_output_features,
                    message_vector_length,
                    nmb_mlp_neurons,
                    # Forward the requested activation; without this the
                    # layers would fall back to their own default.
                    act_fn=act_fn,
                    is_last_layer=(gnn_layer == last_index),
                )
                for gnn_layer in range(nmb_gnn_layers)
            ],
        )

    def preprocess(self, edge_attr, node_attr):
        """Preprocess raw edge and node attributes.

        Always use this before running the GNN layers.

        Returns:
            Tuple ``(edge_attr, hidden_features)``: the output of
            `preproc_edge_mlp` and of `preproc_node_mlp`, respectively.
        """
        edge_attr = self.preproc_edge_mlp(edge_attr)
        hidden_features = self.preproc_node_mlp(node_attr)
        return edge_attr, hidden_features

    def run_through_network(self, edges, edge_attr, node_attr, with_output_attention=False):
        """Run the data through all layers and return the output.

        Args:
            edges: Edge indices, passed unchanged to every layer's `update_nodes`.
            edge_attr: Raw edge attributes (preprocessed here).
            node_attr: Raw node attributes (preprocessed here).
            with_output_attention: If True, also return the attention computed
                by the last layer. Defaults to False.

        Returns:
            The last layer's output, or the tuple ``(output, attention)`` when
            ``with_output_attention`` is True.
        """
        edge_attr, node_attr = self.preprocess(edge_attr, node_attr)
        for layer in self.modlist:
            node_attr = layer.update_nodes(edges, edge_attr, node_attr)
        # Honour the flag: attention is only requested (and returned) when the
        # caller asked for it.
        return self.modlist[-1].output(node_attr, get_attention=with_output_attention)
class AlignmentGNN(SuperGNN):
    """Architecture based on multiple :class:`GNNLayer` layers, suited for both regression and classification tasks.

    It applies different layers to the nodes and edges of a graph
    (`preproc_edge_mlp` and `preproc_node_mlp`), and then applies multiple GNN
    layers (`modlist`). The constructor only re-orders its arguments for
    :class:`SuperGNN`; all modules are created by the parent class.

    Args:
        nmb_edge_attr: Number of edge features.
        nmb_node_attr: Number of node features.
        nmb_output_features: Number of output features.
        nmb_hidden_attr: Number of hidden features.
        message_vector_length: Length of the message vector.
        nmb_mlp_neurons: Number of neurons in the MLP.
        nmb_gnn_layers: Number of GNN layers.
        nmb_edge_projection: Number of edge projections.
        act_fn: Activation function. Defaults to nn.SiLU().
    """

    def __init__(
        self,
        nmb_edge_attr,
        nmb_node_attr,
        nmb_output_features,
        nmb_hidden_attr,
        message_vector_length,
        nmb_mlp_neurons,
        nmb_gnn_layers,
        nmb_edge_projection,
        act_fn=nn.SiLU(),  # noqa: B008
    ):
        # Keyword arguments make the re-ordering relative to the parent
        # signature explicit.
        super().__init__(
            nmb_edge_attr=nmb_edge_attr,
            nmb_node_attr=nmb_node_attr,
            nmb_hidden_attr=nmb_hidden_attr,
            nmb_mlp_neurons=nmb_mlp_neurons,
            nmb_edge_projection=nmb_edge_projection,
            nmb_gnn_layers=nmb_gnn_layers,
            nmb_output_features=nmb_output_features,
            message_vector_length=message_vector_length,
            act_fn=act_fn,
        )

    def forward(self, edges, edge_attr, node_attr):
        """Run over all layers.

        Returns:
            Whatever ``run_through_network`` returns when called with its
            default arguments.
        """
        result = self.run_through_network(edges, edge_attr, node_attr)
        return result