# Source code for dgmc.models.gin

import torch
from torch.nn import Linear as Lin
from torch_geometric.nn import GINConv

from .mlp import MLP


class GIN(torch.nn.Module):
    r"""Stack of :class:`GINConv` layers with optional skip concatenation
    and an optional final linear projection.

    Args:
        in_channels (int): Size of each input feature vector.
        out_channels (int): Output size of every GIN layer, and of the final
            linear layer when ``lin=True``.
        num_layers (int): Number of stacked :class:`GINConv` layers.
        batch_norm (bool, optional): Forwarded to each layer's :class:`MLP`.
            (default: :obj:`False`)
        cat (bool, optional): If :obj:`True`, the input features and the
            outputs of all layers are concatenated along the last dimension;
            otherwise only the last layer's output (or the raw input when
            ``num_layers == 0``) is kept. (default: :obj:`True`)
        lin (bool, optional): If :obj:`True`, a final :class:`Linear` layer
            maps that representation to ``out_channels``.
            (default: :obj:`True`)

    After construction, ``self.out_channels`` holds the feature width of the
    tensor returned by :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels, num_layers,
                 batch_norm=False, cat=True, lin=True):
        super().__init__()

        self.in_channels = in_channels
        self.num_layers = num_layers
        self.batch_norm = batch_norm
        self.cat = cat
        self.lin = lin

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            # NOTE(review): the positional ``2`` is presumably the MLP's
            # number of layers — confirm against ``.mlp.MLP``.
            mlp = MLP(in_channels, out_channels, 2, batch_norm, dropout=0.0)
            # train_eps=True lets GINConv learn its epsilon weighting.
            self.convs.append(GINConv(mlp, train_eps=True))
            # Every layer after the first consumes the previous layer's output.
            in_channels = out_channels

        hidden = self._hidden_channels(self.in_channels, out_channels,
                                       num_layers, cat)
        if self.lin:
            self.out_channels = out_channels
            self.final = Lin(hidden, out_channels)
        else:
            self.out_channels = hidden

        self.reset_parameters()

    @staticmethod
    def _hidden_channels(in_channels, out_channels, num_layers, cat):
        """Return the feature width produced by the conv stack, i.e. the
        width fed to ``final`` (or returned directly when ``lin=False``).

        With ``cat=True`` the input and all layer outputs are concatenated.
        With ``cat=False`` it is the last layer's width, except that with no
        layers at all the input passes through unchanged, so the width is
        ``in_channels`` rather than ``out_channels``.
        """
        if cat:
            return in_channels + num_layers * out_channels
        return out_channels if num_layers > 0 else in_channels

    def reset_parameters(self):
        """Re-initialize every GIN layer and, if present, the final layer."""
        for conv in self.convs:
            conv.reset_parameters()
        if self.lin:
            self.final.reset_parameters()

    def forward(self, x, edge_index, *args):
        """Run the GIN stack.

        Args:
            x (Tensor): Input features; the last dimension must equal
                ``in_channels`` (presumably shape ``[num_nodes, in_channels]``
                — confirm with callers).
            edge_index: Passed unchanged to every :class:`GINConv` layer
                (presumably PyG COO connectivity of shape ``[2, num_edges]``).
            *args: Accepted for call-signature compatibility and ignored.

        Returns:
            Tensor whose last dimension is ``self.out_channels``.
        """
        xs = [x]
        for conv in self.convs:
            xs.append(conv(xs[-1], edge_index))
        x = torch.cat(xs, dim=-1) if self.cat else xs[-1]
        return self.final(x) if self.lin else x

    def __repr__(self):
        return ('{}({}, {}, num_layers={}, batch_norm={}, cat={}, '
                'lin={})').format(self.__class__.__name__, self.in_channels,
                                  self.out_channels, self.num_layers,
                                  self.batch_norm, self.cat, self.lin)