import torch
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
from torch_scatter import scatter_add
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn.inits import reset

    from pykeops.torch import LazyTensor
except ImportError:
    LazyTensor = None

# Small constant added to the matching probabilities before the logarithm is
# taken in `DGMC.loss`, guarding against `log(0)`.
EPS = 1e-8

def masked_softmax(src, mask, dim=-1):
    """Apply a softmax along ``dim`` that ignores masked-out entries.

    Args:
        src (Tensor): Input scores.
        mask (BoolTensor): Entries set to ``True`` take part in the softmax;
            all other entries are excluded.
        dim (int, optional): Dimension to normalize over. (default: ``-1``)

    Returns:
        Tensor: Softmax of ``src`` over the valid entries, with exactly ``0``
        at every masked-out position.
    """
    invalid = ~mask
    # Excluded entries get a score of -inf so that they receive no mass.
    scores = torch.softmax(src.masked_fill(invalid, float('-inf')), dim=dim)
    # Slices without any valid entry yield NaN above; zeroing the excluded
    # positions afterwards cleans those up as well.
    return scores.masked_fill(invalid, 0)

def to_sparse(x, mask):
    """Select the entries of ``x`` at which the boolean ``mask`` is set.

    The dimensions of ``x`` covered by ``mask`` are flattened into a single
    leading dimension that holds only the selected entries.
    """
    selected = x[mask]
    return selected

def to_dense(x, mask):
    """Scatter the rows of ``x`` into a zero-initialized dense tensor.

    The result has shape ``mask.size() + (x.size(-1), )`` and the same dtype
    and device as ``x``. Positions where ``mask`` is set receive the rows of
    ``x`` in order; every other position stays zero.
    """
    dense_shape = (*mask.size(), x.size(-1))
    dense = x.new_zeros(dense_shape)
    dense[mask] = x
    return dense

class DGMC(torch.nn.Module):
    r"""The *Deep Graph Matching Consensus* module which first matches nodes
    locally via a graph neural network :math:`\Psi_{\theta_1}`, and then
    updates correspondence scores iteratively by reaching for neighborhood
    consensus via a second graph neural network :math:`\Psi_{\theta_2}`.

    .. note::
        See the `PyTorch Geometric introductory tutorial
        <https://pytorch-geometric.readthedocs.io/en/latest/notes/
        introduction.html>`_ for a detailed overview of the used GNN modules
        and the respective data format.

    Args:
        psi_1 (torch.nn.Module): The first GNN :math:`\Psi_{\theta_1}` which
            takes in node features :obj:`x`, edge connectivity
            :obj:`edge_index`, and optional edge features :obj:`edge_attr`
            and computes node embeddings.
        psi_2 (torch.nn.Module): The second GNN :math:`\Psi_{\theta_2}` which
            takes in node features :obj:`x`, edge connectivity
            :obj:`edge_index`, and optional edge features :obj:`edge_attr`
            and validates for neighborhood consensus.
            :obj:`psi_2` needs to hold the attributes :obj:`in_channels` and
            :obj:`out_channels` which indicates the dimensionality of randomly
            drawn node indicator functions and the output dimensionality of
            :obj:`psi_2`, respectively.
        num_steps (int): Number of consensus iterations.
        k (int, optional): Sparsity parameter. If set to :obj:`-1`, will
            not sparsify initial correspondence rankings.
            (default: :obj:`-1`)
        detach (bool, optional): If set to :obj:`True`, will detach the
            computation of :math:`\Psi_{\theta_1}` from the current
            computation graph. (default: :obj:`False`)
    """
    def __init__(self, psi_1, psi_2, num_steps, k=-1, detach=False):
        super().__init__()

        self.psi_1 = psi_1
        self.psi_2 = psi_2
        self.num_steps = num_steps
        self.k = k
        self.detach = detach
        # Backend handed to the KeOps `argKmin` reduction in `__top_k__`.
        self.backend = 'auto'

        # Maps the difference of two `psi_2` outputs to a scalar update of
        # the corresponding correspondence score.
        self.mlp = Seq(
            Lin(psi_2.out_channels, psi_2.out_channels),
            ReLU(),
            Lin(psi_2.out_channels, 1),
        )

    def reset_parameters(self):
        r"""Resets the parameters of both GNNs and of the update MLP."""
        self.psi_1.reset_parameters()
        self.psi_2.reset_parameters()
        reset(self.mlp)

    def __top_k__(self, x_s, x_t):  # pragma: no cover
        r"""Memory-efficient top-k correspondence computation.

        Returns, for every source node, the indices of the :obj:`self.k`
        target nodes with the largest inner product. Uses a KeOps
        :obj:`LazyTensor` reduction when available and a dense matrix product
        followed by :obj:`topk` otherwise.
        """
        if LazyTensor is not None:
            x_s = x_s.unsqueeze(-2)  # [..., n_s, 1, d]
            x_t = x_t.unsqueeze(-3)  # [..., 1, n_t, d]
            x_s, x_t = LazyTensor(x_s), LazyTensor(x_t)
            # The similarity is negated so that the k smallest entries are
            # the k largest inner products.
            S_ij = (-x_s * x_t).sum(dim=-1)
            return S_ij.argKmin(self.k, dim=2, backend=self.backend)
        else:
            x_t = x_t.transpose(-1, -2)  # [..., d, n_t]
            S_ij = x_s @ x_t  # [..., n_s, n_t]
            return S_ij.topk(self.k, dim=2)[1]

    def __include_gt__(self, S_idx, s_mask, y):
        r"""Includes the ground-truth values in :obj:`y` to the index tensor
        :obj:`S_idx`.

        For every ground-truth pair whose target index is not yet among the
        candidates of its source node, the last candidate of that source node
        is overwritten by the ground-truth target index.
        """
        (B, N_s), (row, col), k = s_mask.size(), y, S_idx.size(-1)

        # `True` for every ground-truth pair missing from the candidates.
        gt_mask = (S_idx[s_mask][row] != col.view(-1, 1)).all(dim=-1)

        # Lift the per-pair mask to a per-node mask (flattened node layout).
        sparse_mask = gt_mask.new_zeros((int(s_mask.sum()), ))
        sparse_mask[row] = gt_mask

        # Lift the per-node mask to the padded `[B, N_s]` layout.
        dense_mask = sparse_mask.new_zeros((B, N_s))
        dense_mask[s_mask] = sparse_mask

        # Only the last of the `k` candidate slots may be overwritten.
        last_entry = torch.zeros(k, dtype=torch.bool, device=gt_mask.device)
        last_entry[-1] = 1
        dense_mask = dense_mask.view(B, N_s, 1) & last_entry.view(1, 1, k)

        return S_idx.masked_scatter(dense_mask, col[gt_mask])

    def forward(self, x_s, edge_index_s, edge_attr_s, batch_s, x_t,
                edge_index_t, edge_attr_t, batch_t, y=None):
        r"""
        Args:
            x_s (Tensor): Source graph node features of shape
                :obj:`[batch_size * num_nodes, C_in]`.
            edge_index_s (LongTensor): Source graph edge connectivity of shape
                :obj:`[2, num_edges]`.
            edge_attr_s (Tensor): Source graph edge features of shape
                :obj:`[num_edges, D]`. Set to :obj:`None` if the GNNs are not
                taking edge features into account.
            batch_s (LongTensor): Source graph batch vector of shape
                :obj:`[batch_size * num_nodes]` indicating node to graph
                assignment. Set to :obj:`None` if operating on single graphs.
            x_t (Tensor): Target graph node features of shape
                :obj:`[batch_size * num_nodes, C_in]`.
            edge_index_t (LongTensor): Target graph edge connectivity of shape
                :obj:`[2, num_edges]`.
            edge_attr_t (Tensor): Target graph edge features of shape
                :obj:`[num_edges, D]`. Set to :obj:`None` if the GNNs are not
                taking edge features into account.
            batch_t (LongTensor): Target graph batch vector of shape
                :obj:`[batch_size * num_nodes]` indicating node to graph
                assignment. Set to :obj:`None` if operating on single graphs.
            y (LongTensor, optional): Ground-truth matchings of shape
                :obj:`[2, num_ground_truths]` to include ground-truth values
                when training against sparse correspondences. Ground-truths
                are only used in case the model is in training mode.
                (default: :obj:`None`)

        Returns:
            Initial and refined correspondence matrices :obj:`(S_0, S_L)`
            of shapes :obj:`[batch_size * num_nodes, num_nodes]`. The
            correspondence matrices are either given as dense or sparse
            matrices.
        """
        h_s = self.psi_1(x_s, edge_index_s, edge_attr_s)
        h_t = self.psi_1(x_t, edge_index_t, edge_attr_t)

        h_s, h_t = (h_s.detach(), h_t.detach()) if self.detach else (h_s, h_t)

        h_s, s_mask = to_dense_batch(h_s, batch_s, fill_value=0)
        h_t, t_mask = to_dense_batch(h_t, batch_t, fill_value=0)

        assert h_s.size(0) == h_t.size(0), 'Encountered unequal batch-sizes'
        (B, N_s, C_out), N_t = h_s.size(), h_t.size(1)
        R_in, R_out = self.psi_2.in_channels, self.psi_2.out_channels

        if self.k < 1:
            # ------ Dense variant ------ #
            S_hat = h_s @ h_t.transpose(-1, -2)  # [B, N_s, N_t]
            S_mask = s_mask.view(B, N_s, 1) & t_mask.view(B, 1, N_t)
            S_0 = masked_softmax(S_hat, S_mask, dim=-1)[s_mask]

            for _ in range(self.num_steps):
                S = masked_softmax(S_hat, S_mask, dim=-1)
                # Random node indicator functions on the source graph,
                # transferred to the target graph via the current matching.
                r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype,
                                  device=h_s.device)
                r_t = S.transpose(-1, -2) @ r_s

                r_s, r_t = to_sparse(r_s, s_mask), to_sparse(r_t, t_mask)
                o_s = self.psi_2(r_s, edge_index_s, edge_attr_s)
                o_t = self.psi_2(r_t, edge_index_t, edge_attr_t)
                o_s, o_t = to_dense(o_s, s_mask), to_dense(o_t, t_mask)

                D = o_s.view(B, N_s, 1, R_out) - o_t.view(B, 1, N_t, R_out)
                S_hat = S_hat + self.mlp(D).squeeze(-1).masked_fill(~S_mask, 0)

            S_L = masked_softmax(S_hat, S_mask, dim=-1)[s_mask]

            return S_0, S_L
        else:
            # ------ Sparse variant ------ #
            S_idx = self.__top_k__(h_s, h_t)  # [B, N_s, k]

            # In addition to the top-k, randomly sample negative examples and
            # ensure that the ground-truth is included as a sparse entry.
            if self.training and y is not None:
                rnd_size = (B, N_s, min(self.k, N_t - self.k))
                S_rnd_idx = torch.randint(N_t, rnd_size, dtype=torch.long,
                                          device=S_idx.device)
                S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1)
                S_idx = self.__include_gt__(S_idx, s_mask, y)

            k = S_idx.size(-1)
            tmp_s = h_s.view(B, N_s, 1, C_out)
            idx = S_idx.view(B, N_s * k, 1).expand(-1, -1, C_out)
            tmp_t = torch.gather(h_t.view(B, N_t, C_out), -2, idx)
            S_hat = (tmp_s * tmp_t.view(B, N_s, k, C_out)).sum(dim=-1)
            S_0 = S_hat.softmax(dim=-1)[s_mask]

            for _ in range(self.num_steps):
                S = S_hat.softmax(dim=-1)
                r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype,
                                  device=h_s.device)

                # Sparse counterpart of `S^T @ r_s`.
                tmp_t = r_s.view(B, N_s, 1, R_in) * S.view(B, N_s, k, 1)
                tmp_t = tmp_t.view(B, N_s * k, R_in)
                idx = S_idx.view(B, N_s * k, 1)
                r_t = scatter_add(tmp_t, idx, dim=1, dim_size=N_t)

                r_s, r_t = to_sparse(r_s, s_mask), to_sparse(r_t, t_mask)
                o_s = self.psi_2(r_s, edge_index_s, edge_attr_s)
                o_t = self.psi_2(r_t, edge_index_t, edge_attr_t)
                o_s, o_t = to_dense(o_s, s_mask), to_dense(o_t, t_mask)

                o_s = o_s.view(B, N_s, 1, R_out).expand(-1, -1, k, -1)
                idx = S_idx.view(B, N_s * k, 1).expand(-1, -1, R_out)
                tmp_t = torch.gather(o_t.view(B, N_t, R_out), -2, idx)
                D = o_s - tmp_t.view(B, N_s, k, R_out)
                S_hat = S_hat + self.mlp(D).squeeze(-1)

            S_L = S_hat.softmax(dim=-1)[s_mask]
            S_idx = S_idx[s_mask]

            # Convert sparse layout to `torch.sparse_coo_tensor`.
            row = torch.arange(x_s.size(0), device=S_idx.device)
            row = row.view(-1, 1).repeat(1, k)
            idx = torch.stack([row.view(-1), S_idx.view(-1)], dim=0)
            size = torch.Size([x_s.size(0), N_t])

            # The raw candidate indices and values are attached as attributes
            # because `loss`, `acc` and `hits_at_k` read them directly.
            S_sparse_0 = torch.sparse_coo_tensor(
                idx, S_0.view(-1), size, requires_grad=S_0.requires_grad)
            S_sparse_0.__idx__ = S_idx
            S_sparse_0.__val__ = S_0

            S_sparse_L = torch.sparse_coo_tensor(
                idx, S_L.view(-1), size, requires_grad=S_L.requires_grad)
            S_sparse_L.__idx__ = S_idx
            S_sparse_L.__val__ = S_L

            return S_sparse_0, S_sparse_L

    def loss(self, S, y, reduction='mean'):
        r"""Computes the negative log-likelihood loss on the correspondence
        matrix.

        Args:
            S (Tensor): Sparse or dense correspondence matrix of shape
                :obj:`[batch_size * num_nodes, num_nodes]`.
            y (LongTensor): Ground-truth matchings of shape
                :obj:`[2, num_ground_truths]`.
            reduction (string, optional): Specifies the reduction to apply to
                the output: :obj:`'none'|'mean'|'sum'`.
                (default: :obj:`'mean'`)
        """
        assert reduction in ['none', 'mean', 'sum']
        if not S.is_sparse:
            val = S[y[0], y[1]]
        else:
            assert S.__idx__ is not None and S.__val__ is not None
            # Pick, per ground-truth pair, the candidate slot that holds the
            # ground-truth target index.
            mask = S.__idx__[y[0]] == y[1].view(-1, 1)
            val = S.__val__[y[0]][mask]
        nll = -torch.log(val + EPS)
        return nll if reduction == 'none' else getattr(torch, reduction)(nll)

    def acc(self, S, y, reduction='mean'):
        r"""Computes the accuracy of correspondence predictions.

        Args:
            S (Tensor): Sparse or dense correspondence matrix of shape
                :obj:`[batch_size * num_nodes, num_nodes]`.
            y (LongTensor): Ground-truth matchings of shape
                :obj:`[2, num_ground_truths]`.
            reduction (string, optional): Specifies the reduction to apply to
                the output: :obj:`'mean'|'sum'`. (default: :obj:`'mean'`)
        """
        assert reduction in ['mean', 'sum']
        if not S.is_sparse:
            pred = S[y[0]].argmax(dim=-1)
        else:
            assert S.__idx__ is not None and S.__val__ is not None
            pred = S.__idx__[y[0], S.__val__[y[0]].argmax(dim=-1)]

        correct = (pred == y[1]).sum().item()
        return correct / y.size(1) if reduction == 'mean' else correct

    def hits_at_k(self, k, S, y, reduction='mean'):
        r"""Computes the hits@k of correspondence predictions.

        Args:
            k (int): The :math:`\mathrm{top}_k` predictions to consider.
            S (Tensor): Sparse or dense correspondence matrix of shape
                :obj:`[batch_size * num_nodes, num_nodes]`.
            y (LongTensor): Ground-truth matchings of shape
                :obj:`[2, num_ground_truths]`.
            reduction (string, optional): Specifies the reduction to apply to
                the output: :obj:`'mean'|'sum'`. (default: :obj:`'mean'`)
        """
        assert reduction in ['mean', 'sum']
        if not S.is_sparse:
            pred = S[y[0]].argsort(dim=-1, descending=True)[:, :k]
        else:
            assert S.__idx__ is not None and S.__val__ is not None
            perm = S.__val__[y[0]].argsort(dim=-1, descending=True)[:, :k]
            pred = torch.gather(S.__idx__[y[0]], -1, perm)

        correct = (pred == y[1].view(-1, 1)).sum().item()
        return correct / y.size(1) if reduction == 'mean' else correct

    def __repr__(self):
        return ('{}(\n'
                ' psi_1={},\n'
                ' psi_2={},\n'
                ' num_steps={}, k={}\n)').format(self.__class__.__name__,
                                                 self.psi_1, self.psi_2,
                                                 self.num_steps, self.k)