# Source code for dgmc.models.rel

import torch
from torch.nn import Linear as Lin, BatchNorm1d as BN
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing


class RelConv(MessagePassing):
    """Graph convolution that mean-aggregates neighbour features in both
    propagation directions.

    The output for each node is the sum of three terms:

    * ``root(x)`` -- a biased linear map of the node's own features,
    * the mean of ``lin1(x)`` over neighbours, propagated with
      ``flow='source_to_target'``,
    * the mean of ``lin2(x)`` over neighbours, propagated with
      ``flow='target_to_source'``.

    Args:
        in_channels (int): Size of each input feature vector.
        out_channels (int): Size of each output feature vector.
    """

    def __init__(self, in_channels, out_channels):
        super(RelConv, self).__init__(aggr='mean')

        self.in_channels = in_channels
        self.out_channels = out_channels

        # One bias-free projection per propagation direction ...
        self.lin1 = Lin(in_channels, out_channels, bias=False)
        self.lin2 = Lin(in_channels, out_channels, bias=False)
        # ... and a biased projection of the node's own features.
        self.root = Lin(in_channels, out_channels)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the three linear sub-layers (lin1, lin2, root)."""
        for layer in (self.lin1, self.lin2, self.root):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Compute updated node features.

        Args:
            x (Tensor): Node features; assumed to have ``in_channels``
                entries in the last dimension.
            edge_index (Tensor): Graph connectivity handed to ``propagate``.

        Note:
            ``self.flow`` is overwritten before each ``propagate`` call and is
            left set to ``'target_to_source'`` when this method returns.
        """
        passes = (('source_to_target', self.lin1),
                  ('target_to_source', self.lin2))
        aggregated = []
        for direction, projection in passes:
            # ``propagate`` reads ``self.flow`` to decide edge direction.
            self.flow = direction
            aggregated.append(self.propagate(edge_index, x=projection(x)))
        forward_agg, backward_agg = aggregated
        return self.root(x) + forward_agg + backward_agg

    def message(self, x_j):
        # Messages are the already-projected neighbour features, unchanged.
        return x_j

    def __repr__(self):
        return '{name}({inp}, {out})'.format(name=type(self).__name__,
                                             inp=self.in_channels,
                                             out=self.out_channels)


class RelCNN(torch.nn.Module):
    """Stack of ``RelConv`` layers with optional batch normalization, skip
    concatenation and a final linear projection.

    Args:
        in_channels (int): Size of each input feature vector.
        out_channels (int): Output size of every ``RelConv`` layer and of the
            final linear layer (when ``lin=True``).
        num_layers (int): Number of ``RelConv`` layers.
        batch_norm (bool, optional): If ``True``, apply batch normalization
            after the ReLU of each layer. (default: ``False``)
        cat (bool, optional): If ``True``, concatenate the input and the
            output of every layer along the last dimension; otherwise only
            the last representation is kept. (default: ``True``)
        lin (bool, optional): If ``True``, project the (possibly
            concatenated) representation to ``out_channels`` with a final
            linear layer. (default: ``True``)
        dropout (float, optional): Dropout probability applied to each
            layer's output during training. (default: ``0.0``)
    """

    def __init__(self, in_channels, out_channels, num_layers,
                 batch_norm=False, cat=True, lin=True, dropout=0.0):
        super(RelCNN, self).__init__()

        self.in_channels = in_channels
        self.num_layers = num_layers
        self.batch_norm = batch_norm
        self.cat = cat
        self.lin = lin
        self.dropout = dropout

        # Batch-norm modules are always created (and reset), but are only
        # applied in ``forward`` when ``self.batch_norm`` is True.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(RelConv(in_channels, out_channels))
            self.batch_norms.append(BN(out_channels))
            in_channels = out_channels

        # After the loop ``in_channels`` is the width of the last
        # representation: ``out_channels`` if any layer exists, otherwise the
        # original input width. Without concatenation that is exactly the
        # width fed to the final step, so it must not be overwritten with
        # ``out_channels`` (that would be wrong for ``num_layers == 0``).
        if self.cat:
            in_channels = self.in_channels + num_layers * out_channels

        if self.lin:
            self.out_channels = out_channels
            self.final = Lin(in_channels, out_channels)
        else:
            self.out_channels = in_channels

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all convolutions, batch norms and the final layer."""
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            conv.reset_parameters()
            batch_norm.reset_parameters()
        if self.lin:
            self.final.reset_parameters()

    def forward(self, x, edge_index, *args):
        """Run all layers and return the final node representation.

        Args:
            x (Tensor): Node features with ``in_channels`` entries in the last
                dimension.
            edge_index (Tensor): Graph connectivity passed to every layer.
            *args: Extra positional arguments; accepted and ignored.

        Returns:
            Tensor: Node features with ``self.out_channels`` entries in the
            last dimension.
        """
        xs = [x]
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = conv(xs[-1], edge_index)
            x = batch_norm(F.relu(x)) if self.batch_norm else F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(x)

        x = torch.cat(xs, dim=-1) if self.cat else xs[-1]
        x = self.final(x) if self.lin else x
        return x

    def __repr__(self):
        return ('{}({}, {}, num_layers={}, batch_norm={}, cat={}, lin={}, '
                'dropout={})').format(self.__class__.__name__,
                                      self.in_channels, self.out_channels,
                                      self.num_layers, self.batch_norm,
                                      self.cat, self.lin, self.dropout)