Source code for deeprank2.neuralnets.gnn.ginet

import torch
from torch import nn
from torch.nn.functional import dropout, leaky_relu, relu, softmax
from torch_geometric.nn import max_pool_x
from torch_geometric.nn.inits import uniform
from torch_scatter import scatter_mean, scatter_sum

from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster

# ruff: noqa: ANN001, ANN201


[docs]class GINetConvLayer(nn.Module): """GiNet convolutional layer for graph neural networks. Args: in_channels: Number of input features. out_channels: Number of output features. number_edge_features: Number of edge features. Defaults to 1. bias: If set to `False`, the layer will not learn an additive bias. Defaults to False. """ def __init__(self, in_channels, out_channels, number_edge_features=1, bias=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.fc = nn.Linear(self.in_channels, self.out_channels, bias=bias) self.fc_edge_attr = nn.Linear(number_edge_features, number_edge_features, bias=bias) self.fc_attention = nn.Linear(2 * self.out_channels + number_edge_features, 1, bias=bias) self.reset_parameters()
[docs] def reset_parameters(self) -> None: size = self.in_channels uniform(size, self.fc.weight) uniform(size, self.fc_attention.weight) uniform(size, self.fc_edge_attr.weight)
[docs] def forward(self, x, edge_index, edge_attr): row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr xcol = self.fc(x[col]) xrow = self.fc(x[row]) ed = self.fc_edge_attr(edge_attr) # create edge feature by concatenating node feature alpha = torch.cat([xrow, xcol, ed], dim=1) alpha = self.fc_attention(alpha) alpha = leaky_relu(alpha) alpha = softmax(alpha, dim=1) h = alpha * xcol out = torch.zeros(num_node, self.out_channels).to(alpha.device) z = scatter_sum(h, row, dim=0, out=out) return z # noqa:RET504 (unnecessary-assign)
def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
[docs]class GINet(nn.Module): """Architecture based on the GiNet convolutional layer, suited for both regression and classification tasks. It uses community pooling to reduce the number of nodes. Args: input_shape: Number of input features. output_shape: Number of output value per graph. Defaults to 1. input_shape_edge: Number of edge input features. Defaults to 1. """ def __init__(self, input_shape, output_shape=1, input_shape_edge=1): super().__init__() self.conv1 = GINetConvLayer(input_shape, 16, input_shape_edge) self.conv2 = GINetConvLayer(16, 32, input_shape_edge) self.conv1_ext = GINetConvLayer(input_shape, 16, input_shape_edge) self.conv2_ext = GINetConvLayer(16, 32, input_shape_edge) self.fc1 = nn.Linear(2 * 32, 128) self.fc2 = nn.Linear(128, output_shape) self.clustering = "mcl" self.dropout = 0.4
[docs] def forward(self, data): act = relu data_ext = data.clone() # EXTERNAL INTERACTION GRAPH # first conv block data.x = act(self.conv1(data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster0, data.batch) data = community_pooling(cluster, data) # second conv block data.x = act(self.conv2(data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster1, data.batch) x, batch = max_pool_x(cluster, data.x, data.batch) # INTERNAL INTERACTION GRAPH # first conv block data_ext.x = act(self.conv1_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster(data_ext.cluster0, data_ext.batch) data_ext = community_pooling(cluster, data_ext) # second conv block data_ext.x = act(self.conv2_ext(data_ext.x, data_ext.edge_index, data_ext.edge_attr)) cluster = get_preloaded_cluster(data_ext.cluster1, data_ext.batch) x_ext, batch_ext = max_pool_x(cluster, data_ext.x, data_ext.batch) # FC x = scatter_mean(x, batch, dim=0) x_ext = scatter_mean(x_ext, batch_ext, dim=0) x = torch.cat([x, x_ext], dim=1) x = act(self.fc1(x)) x = dropout(x, self.dropout, training=self.training) x = self.fc2(x) return x # noqa:RET504 (unnecessary-assign)