Deep Graph Matching Consensus
- class dgmc.DGMC(psi_1, psi_2, num_steps, k=-1, detach=False)
The Deep Graph Matching Consensus module which first matches nodes locally via a graph neural network \(\Psi_{\theta_1}\), and then updates correspondence scores iteratively by reaching for neighborhood consensus via a second graph neural network \(\Psi_{\theta_2}\).
Note
See the PyTorch Geometric introductory tutorial for a detailed overview of the used GNN modules and the respective data format.
- Parameters
psi_1 (torch.nn.Module) – The first GNN \(\Psi_{\theta_1}\), which takes in node features x, edge connectivity edge_index, and optional edge features edge_attr, and computes node embeddings.
psi_2 (torch.nn.Module) – The second GNN \(\Psi_{\theta_2}\), which takes in node features x, edge connectivity edge_index, and optional edge features edge_attr, and validates for neighborhood consensus. psi_2 needs to hold the attributes in_channels and out_channels, which indicate the dimensionality of randomly drawn node indicator functions and the output dimensionality of psi_2, respectively.
num_steps (int) – Number of consensus iterations.
k (int, optional) – Sparsity parameter. If set to -1, will not sparsify initial correspondence rankings. (default: -1)
detach (bool, optional) – If set to True, will detach the computation of \(\Psi_{\theta_1}\) from the current computation graph. (default: False)
- acc(S, y, reduction='mean')
Computes the accuracy of correspondence predictions.
- Parameters
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'mean'|'sum'. (default: 'mean')
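As an illustration of what this metric measures, the following sketch computes accuracy on a small dense correspondence matrix using plain Python lists. It assumes that y[0] holds source node indices and y[1] the matching target node indices, and that a prediction is the row-wise argmax of S. The helper name accuracy is hypothetical and this is not the library's implementation.

```python
def accuracy(S, y, reduction="mean"):
    """Share (or count) of ground-truth pairs whose row-wise argmax is correct.

    Assumption: y[0] lists source node indices, y[1] the matching targets.
    """
    sources, targets = y
    correct = 0
    for src, tgt in zip(sources, targets):
        row = S[src]
        # Predicted target = column with the highest correspondence score.
        pred = max(range(len(row)), key=row.__getitem__)
        correct += int(pred == tgt)
    if reduction == "mean":
        return correct / len(sources)
    if reduction == "sum":
        return correct
    raise ValueError("reduction must be 'mean' or 'sum'")


# Toy dense correspondence matrix for 3 source and 3 target nodes.
S = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
]
# Ground truth: source 0 -> target 0, source 1 -> target 2, source 2 -> target 1.
y = [[0, 1, 2], [0, 2, 1]]

acc_mean = accuracy(S, y)        # two of the three rows are predicted correctly
acc_sum = accuracy(S, y, "sum")  # number of correct predictions
```

With 'sum', the result is a count of correct matches, which can be convenient when accumulating over several batches before dividing by the total number of ground truths.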
- forward(x_s, edge_index_s, edge_attr_s, batch_s, x_t, edge_index_t, edge_attr_t, batch_t, y=None)
- Parameters
x_s (Tensor) – Source graph node features of shape [batch_size * num_nodes, C_in].
edge_index_s (LongTensor) – Source graph edge connectivity of shape [2, num_edges].
edge_attr_s (Tensor) – Source graph edge features of shape [num_edges, D]. Set to None if the GNNs are not taking edge features into account.
batch_s (LongTensor) – Source graph batch vector of shape [batch_size * num_nodes] indicating node to graph assignment. Set to None if operating on single graphs.
x_t (Tensor) – Target graph node features of shape [batch_size * num_nodes, C_in].
edge_index_t (LongTensor) – Target graph edge connectivity of shape [2, num_edges].
edge_attr_t (Tensor) – Target graph edge features of shape [num_edges, D]. Set to None if the GNNs are not taking edge features into account.
batch_t (LongTensor) – Target graph batch vector of shape [batch_size * num_nodes] indicating node to graph assignment. Set to None if operating on single graphs.
y (LongTensor, optional) – Ground-truth matchings of shape [2, num_ground_truths] to include ground-truth values when training against sparse correspondences. Ground-truths are only used in case the model is in training mode. (default: None)
- Returns
Initial and refined correspondence matrices (S_0, S_L) of shapes [batch_size * num_nodes, num_nodes]. The correspondence matrices are given as either dense or sparse matrices.
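To make the expected input layout more concrete, the sketch below builds a batch vector and an edge_index for two equally sized graphs using plain Python lists. It assumes the usual convention that batched graphs are stored as one disjoint graph, with node indices of later graphs shifted by an offset. The helper batch_graphs is hypothetical and only illustrates the shapes described above; real inputs would be tensors.

```python
def batch_graphs(edge_lists, num_nodes):
    """Concatenate equally sized graphs into one disjoint batch.

    Returns a batch vector of length batch_size * num_nodes and an
    edge_index in [2, num_edges] layout (illustrative helper only).
    """
    batch, rows, cols = [], [], []
    for graph_id, edges in enumerate(edge_lists):
        offset = graph_id * num_nodes  # shift node indices of this graph
        batch.extend([graph_id] * num_nodes)
        for u, v in edges:
            rows.append(u + offset)
            cols.append(v + offset)
    return batch, [rows, cols]


# Two source graphs with 3 nodes each; edges use per-graph node indices.
edges = [[(0, 1), (1, 2)], [(0, 2)]]
batch_s, edge_index_s = batch_graphs(edges, num_nodes=3)
# batch_s assigns each of the 6 nodes to graph 0 or graph 1.
# edge_index_s stores edge sources in row 0 and edge destinations in row 1.
```

When operating on a single pair of graphs, no such vector is needed and batch_s and batch_t can be set to None, as described in the parameter list.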
- hits_at_k(k, S, y, reduction='mean')
Computes the hits@k of correspondence predictions.
- Parameters
k (int) – The \(\mathrm{top}_k\) predictions to consider.
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'mean'|'sum'. (default: 'mean')
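A minimal sketch of the hits@k idea on a dense matrix, again with plain Python lists: a ground-truth pair counts as a hit if its target is among the k highest-scoring columns of the source row. The layout of y (sources in y[0], targets in y[1]) is an assumption and the function below is a hypothetical stand-in, not the library code.

```python
def hits_at_k(k, S, y, reduction="mean"):
    """Share (or count) of ground-truth targets ranked within the top k."""
    sources, targets = y
    hits = 0
    for src, tgt in zip(sources, targets):
        row = S[src]
        # Column indices sorted by descending score, truncated to the top k.
        top_k = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        hits += int(tgt in top_k)
    if reduction == "mean":
        return hits / len(sources)
    if reduction == "sum":
        return hits
    raise ValueError("reduction must be 'mean' or 'sum'")


S = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
]
y = [[0, 1, 2], [0, 2, 1]]

hits_1 = hits_at_k(1, S, y)  # same as accuracy: only the top prediction counts
hits_2 = hits_at_k(2, S, y)  # source 2's true target is ranked second, so it now counts
```

For k = 1 this sketch coincides with plain accuracy; larger k values give credit for near misses in the ranking.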
- loss(S, y, reduction='mean')
Computes the negative log-likelihood loss on the correspondence matrix.
- Parameters
S (Tensor) – Sparse or dense correspondence matrix of shape [batch_size * num_nodes, num_nodes].
y (LongTensor) – Ground-truth matchings of shape [2, num_ground_truths].
reduction (string, optional) – Specifies the reduction to apply to the output: 'none'|'mean'|'sum'. (default: 'mean')
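The negative log-likelihood can be sketched for the dense case as follows: for every ground-truth pair, take the score that S assigns to the true target and average (or sum) its negative logarithm. The small constant EPS and the helper name nll_loss are assumptions made for this illustration, as is the layout of y; the library's own implementation may differ in detail.

```python
import math

EPS = 1e-8  # assumed stabiliser to avoid log(0); not taken from the source


def nll_loss(S, y, reduction="mean"):
    """Negative log-likelihood of the ground-truth entries of S."""
    sources, targets = y
    # One loss term per ground-truth pair (source, target).
    terms = [-math.log(S[src][tgt] + EPS) for src, tgt in zip(sources, targets)]
    if reduction == "none":
        return terms
    if reduction == "sum":
        return sum(terms)
    if reduction == "mean":
        return sum(terms) / len(terms)
    raise ValueError("reduction must be 'none', 'mean' or 'sum'")


S = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
]
y = [[0, 1, 2], [0, 2, 1]]

per_pair = nll_loss(S, y, "none")  # one value per ground-truth matching
total = nll_loss(S, y, "sum")
mean = nll_loss(S, y)
```

The per-pair terms shrink towards zero as the score of the true target approaches one, so the worst-ranked ground truth (here source 2) contributes the largest term.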