import torch
from torch.nn import Linear as Lin, BatchNorm1d as BN
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
class RelConv(MessagePassing):
    """Bidirectional relational graph convolution.

    Aggregates neighbor features separately along each edge direction with
    two direction-specific linear maps (``lin1``, ``lin2``), adds a root
    (self) transform, and mean-pools messages (``aggr='mean'``).

    Args:
        in_channels (int): Size of each input node feature.
        out_channels (int): Size of each output node feature.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='mean')

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Direction-specific transforms carry no bias; only the root
        # transform has one, so the sum has a single bias term overall.
        self.lin1 = Lin(in_channels, out_channels, bias=False)
        self.lin2 = Lin(in_channels, out_channels, bias=False)
        self.root = Lin(in_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all learnable parameters."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.root.reset_parameters()

    def forward(self, x, edge_index):
        """Return ``root(x) + mean_in(lin1(x)) + mean_out(lin2(x))``.

        The same ``edge_index`` is traversed in both directions by flipping
        the message-passing ``flow`` attribute between the two propagate
        calls, so directed edges contribute in both orientations.
        """
        self.flow = 'source_to_target'
        out1 = self.propagate(edge_index, x=self.lin1(x))
        self.flow = 'target_to_source'
        out2 = self.propagate(edge_index, x=self.lin2(x))
        return self.root(x) + out1 + out2

    def message(self, x_j):
        # Messages are the already-transformed source-node features.
        return x_j

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)
class RelCNN(torch.nn.Module):
    """Stack of :class:`RelConv` layers with optional batch norm, dropout,
    jumping-knowledge-style concatenation, and a final linear projection.

    Args:
        in_channels (int): Size of each input node feature.
        out_channels (int): Hidden (and, if ``lin=True``, output) size.
        num_layers (int): Number of stacked ``RelConv`` layers.
        batch_norm (bool, optional): Apply ``BatchNorm1d`` after each ReLU.
        cat (bool, optional): Concatenate the input and every layer's output
            instead of returning only the last layer's output.
        lin (bool, optional): Apply a final linear layer mapping to
            ``out_channels``.
        dropout (float, optional): Dropout probability applied after each
            layer (only active in training mode).
    """
    def __init__(self, in_channels, out_channels, num_layers, batch_norm=False,
                 cat=True, lin=True, dropout=0.0):
        super().__init__()

        self.in_channels = in_channels
        self.num_layers = num_layers
        self.batch_norm = batch_norm
        self.cat = cat
        self.lin = lin
        self.dropout = dropout

        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(RelConv(in_channels, out_channels))
            self.batch_norms.append(BN(out_channels))
            # Layers after the first consume the previous layer's output.
            in_channels = out_channels

        if self.cat:
            # Concatenation output = raw input + every layer's output.
            in_channels = self.in_channels + num_layers * out_channels
        else:
            in_channels = out_channels

        if self.lin:
            self.out_channels = out_channels
            self.final = Lin(in_channels, out_channels)
        else:
            # Without a projection the (possibly concatenated) width is the
            # module's output width.
            self.out_channels = in_channels

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all learnable parameters of all sub-modules."""
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            conv.reset_parameters()
            batch_norm.reset_parameters()
        if self.lin:
            self.final.reset_parameters()

    def forward(self, x, edge_index, *args):
        """Run all layers; extra positional args are accepted and ignored
        (keeps the signature interchangeable with other encoders)."""
        xs = [x]
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = conv(xs[-1], edge_index)
            x = batch_norm(F.relu(x)) if self.batch_norm else F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(x)

        x = torch.cat(xs, dim=-1) if self.cat else xs[-1]
        x = self.final(x) if self.lin else x
        return x

    def __repr__(self):
        return ('{}({}, {}, num_layers={}, batch_norm={}, cat={}, lin={}, '
                'dropout={})').format(self.__class__.__name__,
                                      self.in_channels, self.out_channels,
                                      self.num_layers, self.batch_norm,
                                      self.cat, self.lin, self.dropout)