Source code for deeprank2.neuralnets.gnn.sgat

import torch
from torch import nn
from torch.nn.functional import relu
from torch_geometric.nn import max_pool_x
from torch_geometric.nn.inits import uniform
from torch_scatter import scatter_mean

from deeprank2.utils.community_pooling import community_pooling, get_preloaded_cluster

# ruff: noqa: ANN001, ANN201


[docs]class SGraphAttentionLayer(nn.Module): """SGraphAttentionLayer. This is a new layer that is similar to the graph attention network but simpler z_i = 1 / Ni \\Sum_j a_ij * [x_i || x_j] * W + b_i || is the concatenation operator: [1,2,3] || [4,5,6] = [1,2,3,4,5,6] Ni is the number of neighbor of node i \\Sum_j runs over the neighbors of node i a_ij is the edge attribute between node i and j Args: in_channels: Size of each input sample. out_channels: Size of each output sample. bias: If set to `False`, the layer will not learn an additive bias. Defaults to True. """ # noqa: D301 def __init__( self, in_channels: int, out_channels: int, bias: bool = True, undirected: bool = True, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.undirected = undirected self.weight = nn.Parameter(torch.Tensor(2 * in_channels, out_channels)) if bias: self.bias = nn.Parameter(torch.Tensor(out_channels)) else: self.register_parameter("bias", None) self.reset_parameters()
[docs] def reset_parameters(self) -> None: size = 2 * self.in_channels uniform(size, self.weight) uniform(size, self.bias)
[docs] def forward(self, x, edge_index, edge_attr): row, col = edge_index num_node = len(x) edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr # create edge feature by concatenating node feature alpha = torch.cat([x[row], x[col]], dim=-1) # multiply the edge features with the fliter alpha = torch.mm(alpha, self.weight) # multiply each edge features with the corresponding dist alpha = edge_attr * alpha # scatter the resulting edge feature to get node features out = torch.zeros(num_node, self.out_channels).to(alpha.device) out = scatter_mean(alpha, row, dim=0, out=out) # if the graph is undirected and (i,j) and (j,i) are both in # the edge_index then we do not need to have that second line # or we count everything twice if not self.undirected: out = scatter_mean(alpha, col, dim=0, out=out) # add the bias if self.bias is not None: out = out + self.bias return out
def __repr__(self): return f"{self.__class__.__name__}({self.in_channels}, {self.out_channels})"
[docs]class SGAT(nn.Module): """Simple graph attention network, suited for both regression and classification tasks. It uses two graph attention layers and a MLP to predict the output. Args: input_shape: Size of each input sample. output_shape: Size of each output sample. Defaults to 1. input_shape_edge: Size of each input edge. Defaults to None. """ def __init__( self, input_shape, output_shape=1, input_shape_edge=None, # noqa: ARG002 ): super().__init__() self.conv1 = SGraphAttentionLayer(input_shape, 16) self.conv2 = SGraphAttentionLayer(16, 32) self.fc1 = nn.Linear(32, 64) self.fc2 = nn.Linear(64, output_shape) self.clustering = "mcl"
[docs] def forward(self, data): act = nn.Tanhshrink() act = relu # first conv block data.x = act(self.conv1(data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster0, data.batch) data = community_pooling(cluster, data) # second conv block data.x = act(self.conv2(data.x, data.edge_index, data.edge_attr)) cluster = get_preloaded_cluster(data.cluster1, data.batch) x, batch = max_pool_x(cluster, data.x, data.batch) # FC x = scatter_mean(x, batch, dim=0) x = act(self.fc1(x)) x = self.fc2(x) return x # noqa:RET504 (unnecessary-assign)