Deep Graph Matching Consensus
Package Reference
- class dgmc.DGMC(psi_1, psi_2, num_steps, k=-1, detach=False)
The Deep Graph Matching Consensus module which first matches nodes locally via a graph neural network \(\Psi_{\theta_1}\), and then updates correspondence scores iteratively by reaching for neighborhood consensus via a second graph neural network \(\Psi_{\theta_2}\).
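The two-stage idea can be pictured with a minimal pure-Python toy of the consensus loop. It rests on stated assumptions and is not the dgmc implementation. Correspondence scores are kept dense and normalized with a row-wise softmax. The hypothetical toy_psi_2 stands in for \(\Psi_{\theta_2}\) as a parameter-free neighbor sum. The learned update network of the original method is swapped for a negative squared distance. What it does show is the mechanism: random indicator functions are drawn on the source graph, mapped to the target graph through the current soft correspondences, passed through the same GNN on both sides, and compared node pair by node pair.

```python
import math
import random


def softmax(row):
    # Numerically stable softmax over one row of scores.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def toy_psi_2(x, edges):
    # Hypothetical stand-in for psi_2: every node adds the features of its
    # incoming neighbors to its own (one parameter-free message-passing step).
    # edges is a list of (src, dst) tuples, not a [2, num_edges] tensor.
    out = [row[:] for row in x]
    for src, dst in edges:
        for c, value in enumerate(x[src]):
            out[dst][c] += value
    return out


def consensus_step(s_hat, edges_s, edges_t, in_channels, rng):
    n_s, n_t = len(s_hat), len(s_hat[0])
    s = [softmax(row) for row in s_hat]
    # Random node indicator functions on the source graph: n_s x in_channels.
    r_s = [[rng.gauss(0.0, 1.0) for _ in range(in_channels)] for _ in range(n_s)]
    # Map them onto the target graph through the soft correspondences (S^T r_s).
    r_t = [[sum(s[i][j] * r_s[i][c] for i in range(n_s)) for c in range(in_channels)]
           for j in range(n_t)]
    o_s = toy_psi_2(r_s, edges_s)
    o_t = toy_psi_2(r_t, edges_t)
    # Toy update: subtract the squared distance between o_s[i] and o_t[j]
    # (the original method applies a learned MLP to the difference instead).
    return [[s_hat[i][j] - sum((a - b) ** 2 for a, b in zip(o_s[i], o_t[j]))
             for j in range(n_t)] for i in range(n_s)]


def refine(s_hat_0, edges_s, edges_t, num_steps, in_channels=4, seed=0):
    # Run num_steps consensus iterations and return the normalized S_L.
    rng = random.Random(seed)
    s_hat = s_hat_0
    for _ in range(num_steps):
        s_hat = consensus_step(s_hat, edges_s, edges_t, in_channels, rng)
    return [softmax(row) for row in s_hat]


# A 3-node path graph matched against itself, starting from a confident guess.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
s_hat_0 = [[10.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
s_l = refine(s_hat_0, edges, edges, num_steps=3)
print([max(range(3), key=lambda j: row[j]) for row in s_l])  # [0, 1, 2]
```

Because both graphs are identical and the initial guess is already confident, the mapped indicators agree on matched pairs. The update therefore barely touches the diagonal and can only lower the other scores, so the identity matching is kept.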
Note
See the PyTorch Geometric introductory tutorial for a detailed overview of the GNN modules used here and their data format.
- Parameters
psi_1 (torch.nn.Module) – The first GNN \(\Psi_{\theta_1}\), which takes in node features x, edge connectivity edge_index, and optional edge features edge_attr, and computes node embeddings.
psi_2 (torch.nn.Module) – The second GNN \(\Psi_{\theta_2}\), which takes in node features x, edge connectivity edge_index, and optional edge features edge_attr, and validates for neighborhood consensus. psi_2 needs to hold the attributes in_channels and out_channels, which indicate the dimensionality of the randomly drawn node indicator functions and the output dimensionality of psi_2, respectively.
num_steps (int) – Number of consensus iterations.
k (int, optional) – Sparsity parameter. If set to -1, initial correspondence rankings are not sparsified; otherwise, only the k highest-ranked target candidates per source node are kept. (default: -1)
detach (bool, optional) – If set to True, will detach the computation of \(\Psi_{\theta_1}\) from the current computation graph. (default: False)
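The role of k can be sketched as follows. The sketch assumes, as in the DGMC paper, that initial scores are dot products between \(\Psi_{\theta_1}\) embeddings followed by a row-wise softmax. initial_correspondences is a hypothetical helper that works on plain Python lists rather than tensors. With k = -1 every target stays a candidate; otherwise each source node keeps only its k best-scoring targets, and the softmax runs over those alone.

```python
import math


def initial_correspondences(h_s, h_t, k=-1):
    # Hypothetical helper: h_s / h_t are lists of node embeddings (from psi_1).
    # Returns, per source node, candidate target indices and normalized scores.
    idx, val = [], []
    for hs in h_s:
        # Dot-product similarity to every target node.
        scores = [sum(a * b for a, b in zip(hs, ht)) for ht in h_t]
        if k < 1:
            keep = list(range(len(h_t)))  # dense: every target is a candidate
        else:
            # sparse: only the k highest-scoring targets survive
            keep = sorted(range(len(h_t)), key=lambda j: scores[j], reverse=True)[:k]
        # Row-wise softmax over the retained candidates only.
        m = max(scores[j] for j in keep)
        exps = [math.exp(scores[j] - m) for j in keep]
        total = sum(exps)
        idx.append(keep)
        val.append([e / total for e in exps])
    return idx, val


h_s = [[1.0, 0.0], [0.0, 1.0]]
h_t = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]]
idx, val = initial_correspondences(h_s, h_t, k=2)
print(idx)  # [[1, 2], [0, 2]]
```

Storing (index, value) pairs per row keeps memory linear in k, which is the point of sparsifying when target graphs are large.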
- acc(S, y, reduction='mean')
Computes the accuracy of correspondence predictions.
- Parameters
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'mean' | 'sum'. (default: 'mean')
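For a dense S, this accuracy amounts to checking, for each ground-truth pair, whether the source node's highest-scoring target is the annotated one. The sketch assumes that reading. toy_acc is a hypothetical stand-in on nested lists, not the library method, and it ignores the sparse case.

```python
def toy_acc(S, y, reduction="mean"):
    # S: dense correspondence matrix as a list of rows (one per source node).
    # y: [source_indices, target_indices], i.e. shape [2, num_ground_truths].
    assert reduction in ("mean", "sum")
    correct = 0
    for i, j in zip(y[0], y[1]):
        row = S[i]
        pred = max(range(len(row)), key=lambda c: row[c])  # argmax over targets
        correct += int(pred == j)
    return correct / len(y[0]) if reduction == "mean" else correct


S = [[0.7, 0.2, 0.1],
     [0.1, 0.3, 0.6],
     [0.2, 0.5, 0.3]]
y = [[0, 1, 2], [0, 1, 1]]
print(toy_acc(S, y, reduction="sum"))  # 2
```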
- forward(x_s, edge_index_s, edge_attr_s, batch_s, x_t, edge_index_t, edge_attr_t, batch_t, y=None)
- Parameters
x_s (Tensor) – Source graph node features of shape [batch_size * num_nodes, C_in].
edge_index_s (LongTensor) – Source graph edge connectivity of shape [2, num_edges].
edge_attr_s (Tensor) – Source graph edge features of shape [num_edges, D]. Set to None if the GNNs do not take edge features into account.
batch_s (LongTensor) – Source graph batch vector of shape [batch_size * num_nodes] indicating node-to-graph assignment. Set to None if operating on single graphs.
x_t (Tensor) – Target graph node features of shape [batch_size * num_nodes, C_in].
edge_index_t (LongTensor) – Target graph edge connectivity of shape [2, num_edges].
edge_attr_t (Tensor) – Target graph edge features of shape [num_edges, D]. Set to None if the GNNs do not take edge features into account.
batch_t (LongTensor) – Target graph batch vector of shape [batch_size * num_nodes] indicating node-to-graph assignment. Set to None if operating on single graphs.
y (LongTensor, optional) – Ground-truth matchings of shape [2, num_ground_truths], used to include ground-truth values when training against sparse correspondences. Ground truths are only used when the model is in training mode. (default: None)
- Returns
Initial and refined correspondence matrices (S_0, S_L) of shapes [batch_size * num_nodes, num_nodes]. The correspondence matrices are returned as dense matrices if k = -1 and as sparse matrices otherwise.
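The batched inputs follow the PyTorch Geometric mini-batching convention. Node features of all graphs are stacked row-wise, edge indices are shifted by the number of preceding nodes, and the batch vector records which graph each node belongs to. In practice PyTorch Geometric's data loaders build these tensors. batch_graphs below is a hypothetical pure-Python helper that only illustrates the layout.

```python
def batch_graphs(graphs):
    # Hypothetical helper: each graph is (x, edge_index), where x is a list of
    # node feature rows and edge_index is [sources, targets].
    x, src, dst, batch = [], [], [], []
    offset = 0
    for graph_id, (feats, (s, t)) in enumerate(graphs):
        x.extend(feats)
        # Shift node ids so they index into the concatenated feature list.
        src.extend(v + offset for v in s)
        dst.extend(v + offset for v in t)
        batch.extend([graph_id] * len(feats))
        offset += len(feats)
    return x, [src, dst], batch


g1 = ([[1.0], [2.0]], [[0, 1], [1, 0]])
g2 = ([[3.0], [4.0], [5.0]], [[0, 1], [1, 2]])
x_s, edge_index_s, batch_s = batch_graphs([g1, g2])
print(edge_index_s)  # [[0, 1, 2, 3], [1, 0, 3, 4]]
print(batch_s)       # [0, 0, 1, 1, 1]
```

For a single source and a single target graph, no batching is needed and batch_s and batch_t can be passed as None.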
- hits_at_k(k, S, y, reduction='mean')
Computes the hits@k of correspondence predictions.
- Parameters
k (int) – The \(\mathrm{top}_k\) predictions to consider.
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'mean' | 'sum'. (default: 'mean')
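Hits@k relaxes accuracy by accepting the ground-truth target anywhere among the k highest-scoring candidates of its row. The following is a dense-only sketch. toy_hits_at_k is hypothetical and mirrors the parameter order documented above.

```python
def toy_hits_at_k(k, S, y, reduction="mean"):
    # S: dense correspondence matrix as nested lists; y = [sources, targets].
    assert reduction in ("mean", "sum")
    hits = 0
    for i, j in zip(y[0], y[1]):
        row = S[i]
        # Indices of the k highest scores in this row.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += int(j in top_k)
    return hits / len(y[0]) if reduction == "mean" else hits


S = [[0.7, 0.2, 0.1],
     [0.1, 0.3, 0.6],
     [0.2, 0.5, 0.3]]
y = [[0, 1, 2], [0, 1, 1]]
print(toy_hits_at_k(1, S, y, reduction="sum"))  # 2
print(toy_hits_at_k(2, S, y, reduction="sum"))  # 3
```

With k = 1 the metric reduces to plain accuracy, up to tie-breaking.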
- loss(S, y, reduction='mean')
Computes the negative log-likelihood loss on the correspondence matrix.
- Parameters
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. (default: 'mean')
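The negative log-likelihood looks up the score that S assigns to each ground-truth pair and takes its negative logarithm, returning the per-pair values ('none'), their mean, or their sum. In the sketch below, toy_nll_loss is hypothetical, S is dense, and the small EPS guarding against log(0) is an assumed value. With a sparse S, a ground-truth target outside the retained candidates would have no score to look up. This is presumably why forward can include ground truths during training.

```python
import math

EPS = 1e-8  # assumed small constant guarding against log(0)


def toy_nll_loss(S, y, reduction="mean"):
    # S: dense correspondence matrix (nested lists); y = [sources, targets].
    assert reduction in ("none", "mean", "sum")
    # Negative log of the score assigned to each ground-truth pair.
    nll = [-math.log(S[i][j] + EPS) for i, j in zip(y[0], y[1])]
    if reduction == "none":
        return nll
    total = sum(nll)
    return total / len(nll) if reduction == "mean" else total


S = [[0.5, 0.5],
     [0.25, 0.75]]
y = [[0, 1], [0, 1]]
print([round(v, 4) for v in toy_nll_loss(S, y, reduction="none")])  # [0.6931, 0.2877]
```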