import torch
from torch.nn import Linear as Lin
import torch.nn.functional as F
from torch_geometric.nn import SplineConv
class SplineCNN(torch.nn.Module):
    """Stack of ``SplineConv`` layers with optional feature concatenation
    and an optional final linear projection.

    Args:
        in_channels (int): Number of input node features.
        out_channels (int): Output width of every ``SplineConv`` layer and,
            when ``lin=True``, of the final linear layer.
        dim (int): Forwarded as the ``dim`` argument of every
            ``SplineConv`` (per PyG, the pseudo-coordinate dimensionality).
        num_layers (int): Number of ``SplineConv`` layers (may be 0).
        cat (bool): If ``True``, the input features and the outputs of all
            layers are concatenated along the last dimension before the
            final projection; otherwise only the last tensor is used.
        lin (bool): If ``True``, apply a final ``Linear`` layer mapping to
            ``out_channels``.
        dropout (float): Dropout probability applied before the final
            projection (only active in training mode).

    Attributes:
        out_channels (int): Size of the last dimension of the tensor returned
            by :meth:`forward`.
    """

    def __init__(self, in_channels, out_channels, dim, num_layers, cat=True,
                 lin=True, dropout=0.0):
        super().__init__()
        self.in_channels = in_channels
        self.dim = dim
        self.num_layers = num_layers
        self.cat = cat
        self.lin = lin
        self.dropout = dropout

        self.convs = torch.nn.ModuleList()
        channels = in_channels
        for _ in range(num_layers):
            self.convs.append(
                SplineConv(channels, out_channels, dim, kernel_size=5))
            channels = out_channels

        # Width of the tensor that reaches the final projection. Without
        # concatenation this is the width of the last tensor in the stack:
        # ``out_channels`` after at least one conv, but the untouched input
        # width when ``num_layers == 0``.
        if self.cat:
            hidden_channels = in_channels + num_layers * out_channels
        else:
            hidden_channels = channels

        if self.lin:
            self.out_channels = out_channels
            self.final = Lin(hidden_channels, out_channels)
        else:
            self.out_channels = hidden_channels

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize all convolution layers and the final linear layer."""
        for conv in self.convs:
            conv.reset_parameters()
        if self.lin:
            self.final.reset_parameters()

    def forward(self, x, edge_index, edge_attr, *args):
        """Apply the conv stack, optional concatenation, dropout and projection.

        Args:
            x (Tensor): Node features; the last dimension has
                ``in_channels`` entries.
            edge_index: Graph connectivity, forwarded unchanged to every
                ``SplineConv``.
            edge_attr: Edge attributes, forwarded unchanged to every
                ``SplineConv`` (presumably pseudo-coordinates with ``dim``
                columns; confirm against the caller).
            *args: Accepted for call-site compatibility and ignored.

        Returns:
            Tensor: Features whose last dimension has ``self.out_channels``
            entries.
        """
        xs = [x]
        for conv in self.convs:
            xs.append(F.relu(conv(xs[-1], edge_index, edge_attr)))
        x = torch.cat(xs, dim=-1) if self.cat else xs[-1]
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.final(x) if self.lin else x

    def __repr__(self):
        return ('{}({}, {}, dim={}, num_layers={}, cat={}, lin={}, '
                'dropout={})').format(self.__class__.__name__,
                                      self.in_channels, self.out_channels,
                                      self.dim, self.num_layers, self.cat,
                                      self.lin, self.dropout)