# Source code for dgmc.utils.data

import re
from itertools import chain

import torch
import random
from torch_geometric.data import Data


class PairData(Data):  # pragma: no cover
    r"""A :class:`Data` object describing a pair of graphs: a source graph
    stored under attributes ending in ``_s`` and a target graph stored under
    attributes ending in ``_t``.
    """
    def __inc__(self, key, value, *args):
        r"""Returns the offset associated with the attribute :obj:`key`.

        Keys containing ``index_s`` are offset by the number of rows of
        :obj:`x_s`, keys containing ``index_t`` by the number of rows of
        :obj:`x_t`, and every other key by :obj:`0`.

        NOTE(review): presumably this is the hook torch_geometric calls
        while collating examples into a mini-batch -- confirm against the
        :class:`Data` documentation of the installed version.
        """
        # Checked in order, so a key matching both patterns resolves to the
        # source graph. The feature matrix is only looked up on a match.
        lookup = (('index_s', 'x_s'), ('index_t', 'x_t'))
        for pattern, feature_name in lookup:
            if re.search(pattern, key) is not None:
                return getattr(self, feature_name).size(0)
        return 0


class PairDataset(torch.utils.data.Dataset):
    r"""Combines two datasets, a source dataset and a target dataset, by
    building pairs between separate dataset examples.

    Without sampling, the dataset enumerates the full product of source and
    target examples: index :obj:`idx` maps to the source example
    :obj:`idx // len(dataset_t)` and the target example
    :obj:`idx % len(dataset_t)`.

    Args:
        dataset_s (torch.utils.data.Dataset): The source dataset.
        dataset_t (torch.utils.data.Dataset): The target dataset.
        sample (bool, optional): If set to :obj:`True`, will sample exactly
            one target example for every source example instead of holding
            the product of all source and target examples.
            (default: :obj:`False`)
    """
    def __init__(self, dataset_s, dataset_t, sample=False):
        self.dataset_s = dataset_s
        self.dataset_t = dataset_t
        self.sample = sample

    def __len__(self):
        r"""Returns the number of source examples when sampling, and the
        number of (source, target) combinations otherwise."""
        if self.sample:
            return len(self.dataset_s)
        return len(self.dataset_s) * len(self.dataset_t)

    def __getitem__(self, idx):
        r"""Returns the :class:`PairData` object of the pair at position
        :obj:`idx`.

        When sampling, :obj:`idx` selects the source example and the target
        example is drawn uniformly at random with :mod:`random`, so repeated
        calls with the same :obj:`idx` may return different pairs.
        """
        if self.sample:
            data_s = self.dataset_s[idx]
            data_t = self.dataset_t[random.randint(0, len(self.dataset_t) - 1)]
        else:
            # Row-major enumeration of the source x target product.
            data_s = self.dataset_s[idx // len(self.dataset_t)]
            data_t = self.dataset_t[idx % len(self.dataset_t)]

        return PairData(
            x_s=data_s.x,
            edge_index_s=data_s.edge_index,
            edge_attr_s=data_s.edge_attr,
            x_t=data_t.x,
            edge_index_t=data_t.edge_index,
            edge_attr_t=data_t.edge_attr,
            num_nodes=None,
        )

    def __repr__(self):
        return '{}({}, {}, sample={})'.format(self.__class__.__name__,
                                              self.dataset_s, self.dataset_t,
                                              self.sample)
class ValidPairDataset(torch.utils.data.Dataset):
    r"""Combines two datasets, a source dataset and a target dataset, by
    building valid pairs between separate dataset examples.
    A pair is valid if each node class in the source graph also exists in the
    target graph.

    Args:
        dataset_s (torch.utils.data.Dataset): The source dataset.
        dataset_t (torch.utils.data.Dataset): The target dataset.
        sample (bool, optional): If set to :obj:`True`, will sample exactly
            one target example for every source example instead of holding
            the product of all source and target examples.
            (default: :obj:`False`)
    """
    def __init__(self, dataset_s, dataset_t, sample=False):
        self.dataset_s = dataset_s
        self.dataset_t = dataset_t
        self.sample = sample
        self.pairs, self.cumdeg = self.__compute_pairs__()

    def __compute_pairs__(self):
        r"""Computes all valid (source, target) index pairs.

        Returns:
            A tuple :obj:`(pairs, cumdeg)`. :obj:`pairs` is a list of
            :obj:`[source_idx, target_idx]` entries, and :obj:`cumdeg` is a
            list with :obj:`len(dataset_s) + 1` entries such that
            :obj:`pairs[cumdeg[i]:cumdeg[i + 1]]` holds the pairs of source
            example :obj:`i`.
        """
        num_classes = 0
        for data in chain(self.dataset_s, self.dataset_t):
            num_classes = max(num_classes, data.y.max().item() + 1)

        # Multi-hot encoding of the classes present in every example.
        y_s = torch.zeros((len(self.dataset_s), num_classes), dtype=torch.bool)
        y_t = torch.zeros((len(self.dataset_t), num_classes), dtype=torch.bool)

        for i, data in enumerate(self.dataset_s):
            y_s[i, data.y] = 1
        for i, data in enumerate(self.dataset_t):
            y_t[i, data.y] = 1

        # Broadcast to [num_source, num_target, num_classes].
        y_s = y_s.view(len(self.dataset_s), 1, num_classes)
        y_t = y_t.view(1, len(self.dataset_t), num_classes)

        # A pair is valid if the number of classes shared by both examples
        # equals the number of classes present in the source example.
        pairs = ((y_s * y_t).sum(dim=-1) == y_s.sum(dim=-1)).nonzero()

        # `minlength` keeps one entry per source example even if trailing
        # source examples have no valid target; without it `cumdeg` would be
        # too short and `cumdeg[idx + 1]` would fail for those examples.
        degree = pairs[:, 0].bincount(minlength=len(self.dataset_s))
        cumdeg = degree.cumsum(dim=0)

        return pairs.tolist(), [0] + cumdeg.tolist()

    def __len__(self):
        r"""Returns the number of source examples when sampling, and the
        number of valid pairs otherwise."""
        return len(self.dataset_s) if self.sample else len(self.pairs)

    def __getitem__(self, idx):
        r"""Returns the :class:`PairData` object of the pair at position
        :obj:`idx`, including the node correspondences :obj:`y`.

        When sampling, :obj:`idx` selects the source example and the target
        example is drawn uniformly at random among its valid targets.

        Raises:
            ValueError: If sampling is enabled and the source example at
                :obj:`idx` has no valid target example.
        """
        if self.sample:
            data_s = self.dataset_s[idx]
            start, end = self.cumdeg[idx], self.cumdeg[idx + 1]
            if start == end:
                raise ValueError(
                    'Source example {} has no valid target example'.format(
                        idx))
            i = random.randint(start, end - 1)
            data_t = self.dataset_t[self.pairs[i][1]]
        else:
            data_s = self.dataset_s[self.pairs[idx][0]]
            data_t = self.dataset_t[self.pairs[idx][1]]

        # Lookup table from class label to target node index (-1 where the
        # target has no node of that class), then read it out for the class
        # of every source node.
        y = data_s.y.new_full((data_t.y.max().item() + 1, ), -1)
        y[data_t.y] = torch.arange(data_t.num_nodes)
        y = y[data_s.y]

        return PairData(
            x_s=data_s.x,
            edge_index_s=data_s.edge_index,
            edge_attr_s=data_s.edge_attr,
            x_t=data_t.x,
            edge_index_t=data_t.edge_index,
            edge_attr_t=data_t.edge_attr,
            y=y,
            num_nodes=None,
        )

    def __repr__(self):
        return '{}({}, {}, sample={})'.format(self.__class__.__name__,
                                              self.dataset_s, self.dataset_t,
                                              self.sample)