diff --git a/examples/in_memory/datasets.py b/examples/in_memory/datasets.py
index a7e24ecf..763f7b5a 100644
--- a/examples/in_memory/datasets.py
+++ b/examples/in_memory/datasets.py
@@ -12,28 +12,57 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Wraps OGBN and Planetoid datasets to use within tfgnn in-memory example.
+"""Infrastructure and implementations of in-memory datasets.
 
-* classes `OgbnDataset` and `PlanetoidDataset`, respectively, wrap datasets of
-  OGBN and Planetoid. Both classes inherit class
-  `NodeClassificationDatasetWrapper`. Therefore, they inherit methods
-  `export_to_graph_tensor` and `iterate_once`, respectively, which return
-  `GraphTensor` object (that can be fed into TF-GNN model) and return a
-  tf.data which yields the `GraphTensor` object (once -- you may call .repeat())
-
-* `create_graph_schema_from_directed` creates `tfgnn.GraphSchema` proto.
+Abstract classes:
+
+  * `Dataset`: provides nodes, edges, and features for a heterogeneous graph.
+  * `NodeClassificationDataset`: a `Dataset` that also provides lists of
+    {train, test, validate} nodes, as well as their labels.
+  * `LinkPredictionDataset`: a `Dataset` that also provides lists of edges for
+    {train, test, validate}.
+
+
+All `Dataset` implementations automatically inherit the abilities of:
+
+  * `as_graph_tensor()`, which constructs a `GraphTensor` holding the entire
+    graph.
+  * `graph_schema()`, which returns a `GraphSchema` describing the above
+    `GraphTensor`.
+  * More importantly, they can be plugged into training pipelines, e.g., for
+    node classification (see `tf_trainer.py` and `keras_trainer.py`).
+  * In addition, they can be plugged into in-memory sampling (see
+    `int_arithmetic_sampler.py`, and the example trainer script,
+    `keras_minibatch_trainer.py`).
+
+
+Concrete implementations:
+
+  * Node classification (inheriting `NodeClassificationDataset`):
+
+    * `OgbnDataset`: wraps node classification datasets from OGB, i.e., with
+      name prefix "ogbn-", such as "ogbn-arxiv".
+
+    * `PlanetoidDataset`: wraps the datasets popularized by the GCN paper
+      (cora, citeseer, pubmed).
+
+  * Link prediction (inheriting `LinkPredictionDataset`):
+
+    * `OgblDataset`: wraps link-prediction datasets from OGB, i.e., with name
+      prefix "ogbl-", such as "ogbl-citation2".
 """
 
 import collections
+import functools
 import os
 import pickle
 import sys
 from typing import Any, Dict, List, Mapping, MutableMapping, NamedTuple, Optional, Tuple, Union
 import urllib.request
 
+from absl import logging
 import apache_beam as beam
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.runners.interactive import interactive_beam as ib
 import numpy as np
+import ogb.linkproppred
 import ogb.nodeproppred
 import scipy
 import tensorflow as tf
@@ -43,29 +72,12 @@
 Example = tf.train.Example
 
 
-def get_ogbn_dataset(dataset_name, root_dir):
-  return ogb.nodeproppred.NodePropPredDataset(dataset_name, root=root_dir)
-
+class Dataset:
+  """Abstract class for holding a dataset in memory."""
 
-class NodeSplit(NamedTuple):
-  """Contains 1D int tensors holding positions of {train, valid, test} nodes."""
-  train: tf.Tensor
-  valid: tf.Tensor
-  test: tf.Tensor
-
-
-class NodeClassificationDatasetWrapper:
-  """Wraps graph datasets (nodes, edges, features).
-
-  Inheriting classes implement straight-forward functions to adapt any external
-  dataset into TFGNN, by exposing methods `iterate_once` and
-  `export_to_graph_tensor` that yield GraphTensor objects that can be passed to
-  TFGNN's modeling framework.
-  """
-
-  def num_classes(self) -> int:
-    """Number of node classes. Max of `labels` should be `< num_classes`."""
-    raise NotImplementedError('num_classes')
+  def graph_schema(
+      self, make_undirected: bool = False) -> tfgnn.GraphSchema:
+    raise NotImplementedError()
 
   def node_features_dicts(self, add_id=True) -> Mapping[
       tfgnn.NodeSetName, MutableMapping[str, tf.Tensor]]:
@@ -84,6 +96,73 @@ def edge_lists(self) -> Mapping[
     """
     raise NotImplementedError()
 
+  def node_sets(self, node_features_dicts_fn=None) -> MutableMapping[
+      tfgnn.NodeSetName, tfgnn.NodeSet]:
+    """Returns node sets of entire graph (dict: node set name -> NodeSet)."""
+    node_features_dicts_fn = node_features_dicts_fn or self.node_features_dicts
+    node_counts = self.node_counts()
+    node_features_dicts = node_features_dicts_fn()
+
+    node_sets = {}
+    for node_set_name, node_features_dict in node_features_dicts.items():
+      node_sets[node_set_name] = tfgnn.NodeSet.from_fields(
+          sizes=as_tensor([node_counts[node_set_name]]),
+          features=node_features_dict)
+    return node_sets
+
+  def edge_sets(
+      self, add_self_connections: bool = False,
+      make_undirected: bool = False) -> MutableMapping[
+          tfgnn.EdgeSetName, tfgnn.EdgeSet]:
+    """Returns edge sets of entire graph (dict: edge set name -> EdgeSet)."""
+    edge_sets = {}
+    node_counts = self.node_counts() if add_self_connections else None
+    for edge_type, edge_list in self.edge_lists().items():
+      (source_node_set_name, edge_set_name, target_node_set_name) = edge_type
+
+      if make_undirected and source_node_set_name == target_node_set_name:
+        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=0)
+      if add_self_connections and source_node_set_name == target_node_set_name:
+        all_nodes = tf.range(node_counts[source_node_set_name],
+                             dtype=edge_list.dtype)
+        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
+        edge_list = tf.concat([edge_list, self_connections], axis=0)
+      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
+          sizes=tf.shape(edge_list)[1:2],
+          adjacency=tfgnn.Adjacency.from_indices(
+              source=(source_node_set_name, edge_list[0]),
+              target=(target_node_set_name, edge_list[1])))
+      if not make_undirected:
+        edge_sets['rev_' + edge_set_name] = tfgnn.EdgeSet.from_fields(
+            sizes=tf.shape(edge_list)[1:2],
+            adjacency=tfgnn.Adjacency.from_indices(
+                source=(target_node_set_name, edge_list[1]),
+                target=(source_node_set_name, edge_list[0])))
+    return edge_sets
+
+
+class NodeSplit(NamedTuple):
+  """Contains 1D int tensors holding positions of {train, valid, test} nodes.
+
+  This is returned by `NodeClassificationDataset.node_split()`.
+  """
+  train: tf.Tensor
+  valid: tf.Tensor
+  test: tf.Tensor
+
+
+class NodeClassificationDataset(Dataset):
+  """Wraps graph datasets for node classification (adds labels and splits).
+
+  Inheriting classes implement straightforward functions to adapt any external
+  dataset into TF-GNN, by exposing methods `iterate_once` and `as_graph_tensor`
+  that yield GraphTensor objects that can be passed to TF-GNN's models.
+  """
+
+  def num_classes(self) -> int:
+    """Number of node classes. Max of `labels` should be `< num_classes`."""
+    raise NotImplementedError('num_classes')
+
   def node_split(self) -> NodeSplit:
     """Returns dict with keys "train", "valid", "test" to node indices.
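
(For orientation — a minimal usage sketch of the new `Dataset` API, not part
of the patch; it assumes the classes land as defined above, and 'ogbn-arxiv'
is only an illustrative dataset name:)

    import tensorflow_gnn as tfgnn

    dataset = OgbnDataset('ogbn-arxiv')  # a NodeClassificationDataset
    schema = dataset.graph_schema(make_undirected=True)    # GraphSchema proto
    spec = tfgnn.create_graph_spec_from_schema_pb(schema)  # tf.TypeSpec
    graph = dataset.as_graph_tensor(split='train', make_undirected=True)
    ds = dataset.iterate_once(split='train', make_undirected=True)  # tf.data
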
@@ -111,7 +190,7 @@ def iterate_once(self, add_self_connections: bool = False,
                    split: Union[str, List[str]] = 'train',
                    make_undirected: bool = False) -> tf.data.Dataset:
     """tf.data iterator with one example containg entire graph (full-batch)."""
-    graph_tensor = self.export_to_graph_tensor(
+    graph_tensor = self.as_graph_tensor(
         add_self_connections, split, make_undirected)
     spec = graph_tensor.spec
@@ -120,7 +199,18 @@ def once():
     return tf.data.Dataset.from_generator(once, output_signature=spec)
 
-  def export_to_graph_tensor(
+  def node_feature_dicts_with_labels(
+      self, split: Union[str, List[str]] = 'train') -> Mapping[
+          tfgnn.NodeSetName, MutableMapping[str, tf.Tensor]]:
+    node_features_dicts = self.node_features_dicts()
+    splits = split if isinstance(split, (tuple, list)) else [split]
+    if 'test' in splits:
+      node_features_dicts[self.labeled_nodeset]['label'] = self.test_labels()
+    else:
+      node_features_dicts[self.labeled_nodeset]['label'] = self.labels()
+    return node_features_dicts
+
+  def as_graph_tensor(
       self, add_self_connections: bool = False,
       split: Union[str, List[str]] = 'train',
       make_undirected: bool = False) -> tfgnn.GraphTensor:
@@ -149,52 +239,14 @@ def export_to_graph_tensor(
     Returns:
       GraphTensor containing the entire graph at-once.
     """
-    # Prepare node sets, edge sets, context, to construct graph tensor.
-    ## Node sets.
-    node_counts = self.node_counts()
-    node_features_dicts = self.node_features_dicts()
-
-    if not isinstance(split, (tuple, list)):
-      splits = [split]
-    else:
-      splits = split
-    if 'test' in split:
-      node_features_dicts[self.labeled_nodeset]['label'] = self.test_labels()
-    else:
-      node_features_dicts[self.labeled_nodeset]['label'] = self.labels()
-
-    node_sets = {}
-    for node_set_name, node_features_dict in node_features_dicts.items():
-      node_sets[node_set_name] = tfgnn.NodeSet.from_fields(
-          sizes=as_tensor([node_counts[node_set_name]]),
-          features=node_features_dict)
-
-    ## Edge set.
-    edge_sets = {}
-    for edge_type, edge_list in self.edge_lists().items():
-      (source_node_set_name, edge_set_name, target_node_set_name) = edge_type
-
-      if make_undirected and source_node_set_name == target_node_set_name:
-        edge_list = tf.concat([edge_list, edge_list[::-1]], axis=0)
-      if add_self_connections and source_node_set_name == target_node_set_name:
-        all_nodes = tf.range(node_counts[source_node_set_name],
-                             dtype=edge_list.dtype)
-        self_connections = tf.stack([all_nodes, all_nodes], axis=0)
-        edge_list = tf.concat([edge_list, self_connections], axis=0)
-      edge_sets[edge_set_name] = tfgnn.EdgeSet.from_fields(
-          sizes=tf.shape(edge_list)[1:2],
-          adjacency=tfgnn.Adjacency.from_indices(
-              source=(source_node_set_name, edge_list[0]),
-              target=(target_node_set_name, edge_list[1])))
-      if not make_undirected:
-        edge_sets['rev_' + edge_set_name] = tfgnn.EdgeSet.from_fields(
-            sizes=tf.shape(edge_list)[1:2],
-            adjacency=tfgnn.Adjacency.from_indices(
-                source=(target_node_set_name, edge_list[1]),
-                target=(source_node_set_name, edge_list[0])))
-
-    ## Context.
-    # Expand seed nodes.
+    # Node and edge sets.
+    node_sets = self.node_sets(functools.partial(
+        self.node_feature_dicts_with_labels, split=split))
+    edge_sets = self.edge_sets(add_self_connections=add_self_connections,
+                               make_undirected=make_undirected)
+
+    # Context.
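+    # Positions of the seed (labeled) nodes for the requested split(s) are
+    # attached as a context feature named 'seed_nodes.<labeled node set>',
+    # which readout later uses to locate them.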
+    splits = split if isinstance(split, (tuple, list)) else [split]
     node_split = self.node_split()
     seed_nodes = tf.concat(
         [getattr(node_split, split) for split in splits], axis=0)
@@ -208,16 +260,39 @@
 
     return graph_tensor
 
-  def export_graph_schema(
+  def graph_schema(
       self, make_undirected: bool = False) -> tfgnn.GraphSchema:
-    return create_graph_schema_from_directed(
+    graph_schema = _create_graph_schema_from_directed(
         self, make_undirected=make_undirected)
+    context_features = graph_schema.context.features
+    context_features['seed_nodes.' + self.labeled_nodeset].dtype = (
+        tf.int64.as_datatype_enum)
+    return graph_schema
+
+
+class LinkPredictionDataset(Dataset):
+  """Subclasses wrap graph dataset(s) for link-prediction tasks."""
 
-def create_graph_schema_from_directed(
-    dataset: NodeClassificationDatasetWrapper,
-    make_undirected=False,
-) -> tfgnn.GraphSchema:
+  def as_graph_tensor(
+      self, add_self_connections: bool = False,
+      make_undirected: bool = False) -> tfgnn.GraphTensor:
+    node_sets = self.node_sets()
+    edge_sets = self.edge_sets(add_self_connections=add_self_connections,
+                               make_undirected=make_undirected)
+    return tfgnn.GraphTensor.from_pieces(
+        node_sets=node_sets, edge_sets=edge_sets)
+
+  def graph_schema(
+      self, make_undirected: bool = False) -> tfgnn.GraphSchema:
+    return _create_graph_schema_from_directed(
+        self, make_undirected=make_undirected)
+
+  def edge_split(self):
+    raise NotImplementedError()
+
+
+def _create_graph_schema_from_directed(
+    dataset: Dataset, make_undirected=False) -> tfgnn.GraphSchema:
   """Creates `GraphSchema` proto from directed OGBN graph.
 
   Output of this function can be used to create a `tf.TypeSpec` object as:
@@ -232,7 +307,7 @@
   input pipeline.
 
   Args:
-    dataset: NodeClassificationDatasetWrapper. Feature shapes and types returned
+    dataset: `Dataset` object. Feature shapes and types returned
       by `dataset.node_features_dict()` will be added to graph schema.
     make_undirected: If set, only edge with type name 'edges' will be registerd.
       Otherwise, edges with name 'rev_edges' will additionally be registered in
@@ -261,34 +336,123 @@
       schema.edge_sets['rev_' + edge_set_name].source = dst_node_set_name
       schema.edge_sets['rev_' + edge_set_name].target = src_node_set_name
 
-  schema.context.features['seed_nodes.' + dataset.labeled_nodeset].dtype = (
-      tf.int64.as_datatype_enum)
-
   return schema
 
 
-class _OgbnGraph(NamedTuple):
-  # Maps "node set name" -> number of nodes.
-  num_nodes_dict: Mapping[str, int]
+class _OgbGraph:
+  """Wraps data exposed by OGB graph objects, while enforcing heterogeneity."""
+
+  @property
+  def num_nodes_dict(self) -> Mapping[str, int]:
+    """Maps "node set name" -> number of nodes."""
+    return self._num_nodes_dict
+
+  @property
+  def node_feat_dict(self) -> Mapping[str, MutableMapping[str, tf.Tensor]]:
+    """Maps "node set name" to dict of "feature name" -> tf.Tensor."""
+    return self._node_feat_dict
+
+  @property
+  def edge_index_dict(self) -> Mapping[
+      Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName],
+      tf.Tensor]:
+    """Adjacency lists for all edge sets.
+
+    Returns:
+      Dict (source node set name, edge set name, target node set name) ->
+      edges, where `edges` is a tf.Tensor of shape (2, num edges), with
+      `edges[0]` and `edges[1]`, respectively, containing source and target
+      node IDs (as 1D int tf.Tensor).
+ """ + return self._edge_index_dict + + def __init__(self, graph: Mapping[str, Any]): + if 'edge_index_dict' in graph: # Heterogeneous graph + assert 'num_nodes_dict' in graph + assert 'node_feat_dict' in graph + + # node set name -> feature name -> feature matrix (numNodes x featDim). + node_set = {node_set_name: {'feat': as_tensor(feat)} + for node_set_name, feat in graph['node_feat_dict'].items() + if feat is not None} + # Populate remaining features + for key, node_set_name_to_feat in graph.items(): + if key.startswith('node_') and key != 'node_feat_dict': + feat_name = key.split('node_', 1)[-1] + for node_set_name, feat in node_set_name_to_feat.items(): + node_set[node_set_name][feat_name] = as_tensor(feat) + self._num_nodes_dict = graph['num_nodes_dict'] + self._node_feat_dict = node_set + self._edge_index_dict = tf.nest.map_structure( + as_tensor, graph['edge_index_dict']) + else: # Homogenous graph. Make heterogeneous. + if graph.get('node_feat', None) is not None: + node_features = { + tfgnn.NODES: {'feat': as_tensor(graph['node_feat'])} + } + else: + node_features = { + tfgnn.NODES: { + 'feat': tf.zeros([graph['num_nodes'], 0], dtype=tf.float32) + } + } + + self._edge_index_dict = { + (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): as_tensor( + graph['edge_index']), + } + self._num_nodes_dict = {tfgnn.NODES: graph['num_nodes']} + self._node_feat_dict = node_features + + +class OgblDataset(LinkPredictionDataset): + """Wraps link prediction datasets of ogbl-* for in-memory learning.""" + + def __init__(self, dataset_name, cache_dir=None): + if cache_dir is None: + cache_dir = os.environ.get( + 'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb'))) + + self.ogb_dataset = ogb.linkproppred.LinkPropPredDataset( + dataset_name, root=cache_dir) + + # dict with keys 'train', 'valid', 'test' + self._edge_split = self.ogb_dataset.get_edge_split() + self.ogb_graph = _OgbGraph(self.ogb_dataset.graph) + + def node_features_dicts(self, add_id=True) -> Mapping[ + tfgnn.NodeSetName, MutableMapping[str, tf.Tensor]]: + features = self.ogb_graph.node_feat_dict + features = {node_set_name: {feat: value for feat, value in features.items()} + for node_set_name, features in features.items()} + if add_id: + counts = self.node_counts() + for node_set_name, feats in features.items(): + feats['#id'] = tf.range(counts[node_set_name], dtype=tf.int32) + return features - # Maps "node set name" to dict of "feature name"->tf.Tensor. - node_feat_dict: Mapping[str, MutableMapping[str, tf.Tensor]] + def node_counts(self) -> Mapping[tfgnn.NodeSetName, int]: + return self.ogb_graph.num_nodes_dict + + def edge_lists(self) -> Mapping[ + Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName], + tf.Tensor]: + return self.ogb_graph.edge_index_dict - # maps (source node set name, edge set name, target node set name) -> edges, - # where edges is tf.Tensor of shape (2, num edges). 
-  edge_index_dict: Mapping[
-      Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName], tf.Tensor]
+  def edge_split(self):
+    return self._edge_split
 
 
-class OgbnDataset(NodeClassificationDatasetWrapper):
-  """Wraps OGBN dataset for in-memory learning."""
+class OgbnDataset(NodeClassificationDataset):
+  """Wraps node classification datasets of ogbn-* for in-memory learning."""
 
   def __init__(self, dataset_name, cache_dir=None):
     if cache_dir is None:
       cache_dir = os.environ.get(
           'OGB_CACHE_DIR', os.path.expanduser(os.path.join('~', 'data', 'ogb')))
-    self.ogb_dataset = get_ogbn_dataset(dataset_name, cache_dir)
+    self.ogb_dataset = ogb.nodeproppred.NodePropPredDataset(
+        dataset_name, root=cache_dir)
     self._graph, self._node_labels, self._node_split, self._labeled_nodeset = (
         OgbnDataset._to_heterogenous(self.ogb_dataset))
 
@@ -305,9 +469,9 @@ def __init__(self, dataset_name, cache_dir=None):
   @staticmethod
   def _to_heterogenous(
       ogb_dataset: ogb.nodeproppred.NodePropPredDataset) -> Tuple[
-          _OgbnGraph,  # graph_dict.
+          _OgbGraph,   # ogb_graph.
           np.ndarray,  # node_labels.
-          NodeSplit,   # idx_split.
+          NodeSplit,   # idx_split.
           str]:
     """Returns heterogeneous dicts from homogenous or heterogeneous dataset.
 
@@ -319,8 +483,8 @@ def _to_heterogenous(
       node set will be named "nodes" and the edge set will be named "edges".
 
     Returns:
-      tuple: `(graph_dict, node_labels, idx_split, labeled_nodeset)`, where:
-      `graph_dict` is instance of _OgbnGraph.
+      tuple: `(ogb_graph, node_labels, idx_split, labeled_nodeset)`, where:
+      `ogb_graph` is an instance of `_OgbGraph`.
       `node_labels`: np.array of labels, with .shape[0] equals number of nodes
         in node set with name `labeled_nodeset`.
       `idx_split`: instance of NodeSplit. Members `train`, `test` and `valid`,
@@ -330,7 +494,8 @@ def _to_heterogenous(
        designed over.
     """
     graph, node_labels = ogb_dataset[0]
-    if 'edge_index_dict' in graph:  # Graph is already heterogeneous
+    ogb_graph = _OgbGraph(graph)
+    if 'edge_index_dict' in graph:  # Graph is heterogeneous.
       assert 'num_nodes_dict' in graph
       assert 'node_feat_dict' in graph
       labeled_nodeset = list(node_labels.keys())
@@ -346,40 +511,17 @@ def _to_heterogenous(
       idx_split = {split_name: as_tensor(split_dict[labeled_nodeset])
                    for split_name, split_dict in idx_split.items()}
       idx_split = NodeSplit(**idx_split)
-      # node set name -> feature name -> feature matrix (numNodes x featDim).
-      node_set = {node_set_name: {'feat': as_tensor(feat)}
-                  for node_set_name, feat in graph['node_feat_dict'].items()}
-      # Populate remaining features
-      for key, node_set_name_to_feat in graph.items():
-        if key.startswith('node_') and key != 'node_feat_dict':
-          feat_name = key.split('node_', 1)[-1]
-          for node_set_name, feat in node_set_name_to_feat.items():
-            node_set[node_set_name][feat_name] = as_tensor(feat)
-      ogbn_graph = _OgbnGraph(
-          num_nodes_dict=graph['num_nodes_dict'],
-          node_feat_dict=node_set,
-          edge_index_dict={k: as_tensor(v)
-                           for k, v in graph['edge_index_dict'].items()})
-
-      return ogbn_graph, node_labels, idx_split, labeled_nodeset
-
-    # Homogenous graph. Make heterogeneous.
-    ogbn_graph = _OgbnGraph(
-        edge_index_dict={
-            (tfgnn.NODES, tfgnn.EDGES, tfgnn.NODES): as_tensor(
-                graph['edge_index']),
-        },
-        num_nodes_dict={tfgnn.NODES: graph['num_nodes']},
-        node_feat_dict={tfgnn.NODES: {'feat': as_tensor(graph['node_feat'])}},
-    )
+
+      return ogb_graph, node_labels, idx_split, labeled_nodeset
+
+    # Copy other node information.
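+    # (Homogeneous case.) Any remaining `node_*` entries of the OGB graph
+    # dict are copied below as extra features of the single node set.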
     for key, value in graph.items():
       if key != 'node_feat' and key.startswith('node_'):
         key = key.split('node_', 1)[-1]
-        ogbn_graph.node_feat_dict[tfgnn.NODES][key] = as_tensor(value)
+        ogb_graph.node_feat_dict[tfgnn.NODES][key] = as_tensor(value)
     idx_split = NodeSplit(**tf.nest.map_structure(
         tf.convert_to_tensor, ogb_dataset.get_idx_split()))
-    return ogbn_graph, node_labels, idx_split, tfgnn.NODES
+    return ogb_graph, node_labels, idx_split, tfgnn.NODES
 
   def num_classes(self) -> int:
     return self.ogb_dataset.num_classes
@@ -434,7 +576,7 @@ def _maybe_download_file(source_url, destination_path, make_dirs=True):
       fout.write(fin.read())
 
 
-class PlanetoidDataset(NodeClassificationDatasetWrapper):
+class PlanetoidDataset(NodeClassificationDataset):
   """Wraps Planetoid node-classificaiton datasets.
 
   These datasets first appeared in the Planetoid [1] paper and popularized by
@@ -559,9 +701,11 @@ def test_labels(self) -> tf.Tensor:
     return self._node_labels
 
 
-def get_dataset(dataset_name):
+def get_dataset(dataset_name) -> Dataset:
   if dataset_name.startswith('ogbn-'):
     return OgbnDataset(dataset_name)
+  elif dataset_name.startswith('ogbl-'):
+    return OgblDataset(dataset_name)
   elif dataset_name in ('cora', 'citeseer', 'pubmed'):
     return PlanetoidDataset(dataset_name)
   else:
@@ -574,13 +718,28 @@ def as_tensor(obj: Any) -> tf.Tensor:
   return tf.convert_to_tensor(obj)
 
 
-class UnigraphInMemeoryDataset(NodeClassificationDatasetWrapper):
+# Copied from tensorflow_gnn/graph/graph_tensor_random.py
+def _get_feature_values(feature: tf.train.Feature) -> Union[List[str],
+                                                            List[int],
+                                                            List[float]]:
+  """Returns the values from a TF feature proto."""
+  if feature.HasField('float_list'):
+    return list(feature.float_list.value)
+  elif feature.HasField('int64_list'):
+    return list(feature.int64_list.value)
+  elif feature.HasField('bytes_list'):
+    return list(feature.bytes_list.value)
+  return []
+
+
+class UnigraphInMemeoryDataset(Dataset):
   """Implementation of in-memory dataset loader for unigraph format."""
 
   def __init__(self,
                graph_schema: tfgnn.GraphSchema,
               graph_directory: Optional[str] = None,
-               pipeline_options: Optional[PipelineOptions] = None):
+               pipeline_options: Optional[PipelineOptions] = None,
+               autocompress=True):
     """Constructor to represent a unigraph in memory.
 
     Args:
@@ -589,11 +748,24 @@ def __init__(self,
         relative paths.
       pipeline_options: Optional beam pipeline options that can be passed to
        the interactive runner. `pipeline_options` will probably not be needed.
+      autocompress: If set (default), populates the tf.Tensors needed for
+        model training and sampling, without keeping the intermediate data
+        structures. Once the graph is compressed, `.get_adjacency_list()` and
+        `.node_features` can no longer be accessed. Constructing with
+        `autocompress=False` and then calling `.compress()` is equivalent to
+        setting `autocompress=True`.
     """
-    self.graph_schema = graph_schema
+    self._graph_schema = graph_schema
     self.graph_directory = graph_directory
     self.pipeline_options = pipeline_options
 
+    # Node set name -> node ID -> auto-incrementing int (`node_idx`).
+    self.compression_maps: Dict[
+        tfgnn.NodeSetName, Dict[bytes, int]] = collections.defaultdict(dict)
+    # Node set name -> list of node IDs (from above), sorted by int (`node_idx`).
+    self.rev_compression_maps: Dict[
+        tfgnn.NodeSetName, List[bytes]] = collections.defaultdict(list)
+
     # Mapping from node set name to a mapping of node id to tf.train.Example
     # pairs.
     self.node_features: Dict[str, Dict[bytes, tf.train.Example]] = {}
@@ -604,13 +776,16 @@ def __init__(self,
       with beam.Pipeline(
           runner='InteractiveRunner', options=pipeline_options) as p:
         graph_pcoll: Dict[str, Dict[str, beam.PCollection]] = unigraph.read_graph(
-            self.graph_schema, self.graph_directory, p)
-
+            graph_schema, self.graph_directory, p)
         for node_set_name, ns in graph_pcoll[tfgnn.NODES].items():
           self.node_features[node_set_name] = {}
           df = ib.collect(ns)
-          for node_id, example in zip(df[0], df[1]):
+          for node_order, (node_id, example) in enumerate(zip(df[0], df[1])):
+            if node_id in self.node_features[node_set_name]:
+              raise ValueError('More than one node with ID %s' % node_id)
             self.node_features[node_set_name][node_id] = example
+            self.compression_maps[node_set_name][node_id] = node_order
+            self.rev_compression_maps[node_set_name].append(node_id)
 
         for edge_set_name, es in graph_pcoll[tfgnn.EDGES].items():
           self.flat_edge_list[edge_set_name] = []
@@ -618,6 +793,91 @@ def __init__(self,
           for src, target, example in zip(df[0], df[1], df[2]):
             self.flat_edge_list[edge_set_name].append((src, target, example))
 
+    self._node_features_dict: Dict[tfgnn.NodeSetName, Dict[str, tf.Tensor]] = {}
+    self._edge_lists: Dict[
+        Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName],
+        tf.Tensor] = {}
+
+    if autocompress:
+      self.compress()
+
+  def compress(self, cleanup=False):
+    """Creates compression maps from node IDs, to store edge endpoints as ints.
+
+    Calling this enables the functions `.edge_lists()` and
+    `.node_features_dicts()`.
+
+    Args:
+      cleanup: If set, intermediate data structures will be removed, making
+        function `.get_adjacency_list()` and member `.node_features` return
+        empty results.
+    """
+    schema = self.graph_schema()
+    # Node set name -> feature name -> feature tensor.
+    # All features under a node set must have same `feature_tensor.shape[0]`.
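+    # Feature values are gathered per node, in compression-map order, as numpy
+    # arrays; they are stacked into one dense tensor per feature further below.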
+    node_features_dict: Dict[
+        tfgnn.NodeSetName, Dict[str, List[np.ndarray]]] = (
+            collections.defaultdict(lambda: collections.defaultdict(list)))
+
+    for node_set_name, node_order in self.rev_compression_maps.items():
+      feature_schema = schema.node_sets[node_set_name]
+      for node_id in node_order:
+        example = self.node_features[node_set_name][node_id]
+        for feature_name, feature_value in example.features.feature.items():
+          np_feature_value = np.array(_get_feature_values(feature_value))
+          np_feature_value = np_feature_value.reshape(
+              feature_schema.features[feature_name].shape.dim)
+          node_features_dict[node_set_name][feature_name].append(
+              np_feature_value)
+
+    self._node_features_dict = {}
+    for node_set_name, feature_dict in node_features_dict.items():
+      self._node_features_dict[node_set_name] = {}
+      for feature_name, list_np_feature_values in feature_dict.items():
+        feature_tensor = tf.convert_to_tensor(
+            np.stack(list_np_feature_values, axis=0))
+        self._node_features_dict[node_set_name][feature_name] = feature_tensor
+
+    edge_lists: Dict[
+        Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName],
+        List[np.ndarray]] = collections.defaultdict(list)
+    for edge_set_name, connection in self.flat_edge_list.items():
+      source_node_set_name = schema.edge_sets[edge_set_name].source
+      target_node_set_name = schema.edge_sets[edge_set_name].target
+      edge_key = (source_node_set_name, edge_set_name, target_node_set_name)
+
+      for source_id, target_id, example in connection:
+        for feature_name, feature_value in example.features.feature.items():
+          if feature_name in ('#source', '#target'):
+            continue
+          logging.warning('Ignoring all edge features, including feature (%s) '
+                          'for edge set (%s)', feature_name, edge_set_name)
+        edge_endpoints = (
+            self._compression_id(source_node_set_name, source_id),
+            self._compression_id(target_node_set_name, target_id))
+        edge_lists[edge_key].append(np.array(edge_endpoints))
+
+    # Mapping from edge type triple (source set, edge set, target set) to an
+    # int tensor of shape (2, num edges).
+    self._edge_lists = {}
+    for edge_key, list_np_edge_list in edge_lists.items():
+      self._edge_lists[edge_key] = tf.convert_to_tensor(
+          np.stack(list_np_edge_list, -1))
+
+    # TODO(haija): Ensure that all features are populated, in case
+    # self._compression_id, when assembling `edge_endpoints`, has invented new
+    # nodes. In this case, we should pad with zeros.
+
+    if cleanup:
+      self.flat_edge_list = {}
+      self.node_features = {}
+
+  def _compression_id(self, node_set_name, node_id):
+    if node_id not in self.compression_maps[node_set_name]:
+      next_int = len(self.compression_maps[node_set_name])
+      self.compression_maps[node_set_name][node_id] = next_int
+      self.rev_compression_maps[node_set_name].append(node_id)
+    return self.compression_maps[node_set_name][node_id]
+
   def get_adjacency_list(self) -> Dict[tfgnn.EdgeSetName, Dict[str, Example]]:
     """Returns weighted edges as an adjacency list of nested dictionaries.
 
@@ -632,3 +892,35 @@ def get_adjacency_list(self) -> Dict[tfgnn.EdgeSetName, Dict[str, Example]]:
           adjacency_sets[edge_set_name][source][target] = example
 
     return adjacency_sets
+
+  def graph_schema(self, make_undirected: bool = False) -> tfgnn.GraphSchema:
+    return self._graph_schema
+
+  def node_features_dicts(self, add_id=True) -> Mapping[
+      tfgnn.NodeSetName, MutableMapping[str, tf.Tensor]]:
+    del add_id  # Features should already have '#id' field, with dtype string.
+    return self._node_features_dict
+
+  def node_counts(self) -> Mapping[tfgnn.NodeSetName, int]:
+    """Returns total number of graph nodes per node set."""
+    return {node_set_name: len(ids)
+            for node_set_name, ids in self.rev_compression_maps.items()}
+
+  def edge_lists(self) -> Mapping[
+      Tuple[tfgnn.NodeSetName, tfgnn.EdgeSetName, tfgnn.NodeSetName],
+      tf.Tensor]:
+    """Returns dict from "edge type tuple" to int array of shape (2, num_edges).
+
+    The "edge type tuple" is (src_node_set, edge_set, target_node_set).
+    """
+    return self._edge_lists
+
+  def as_graph_tensor(
+      self, add_self_connections: bool = False,
+      make_undirected: bool = False) -> tfgnn.GraphTensor:
+    node_sets = self.node_sets()
+    edge_sets = self.edge_sets(add_self_connections=add_self_connections,
+                               make_undirected=make_undirected)
+    return tfgnn.GraphTensor.from_pieces(
+        node_sets=node_sets, edge_sets=edge_sets)
+
diff --git a/examples/in_memory/int_arithmetic_sampler.py b/examples/in_memory/int_arithmetic_sampler.py
index 8e9e7db9..48c954fa 100644
--- a/examples/in_memory/int_arithmetic_sampler.py
+++ b/examples/in_memory/int_arithmetic_sampler.py
@@ -12,7 +12,7 @@
 # ==============================================================================
 r"""Random tree walks to make GraphTensor of subgraphs rooted at seed nodes.
 
-The entry point is method `make_sampled_subgraphs_dataset()`, which accepts as
+The entry point is method `make_node_classification_tf_dataset()`, which accepts as
@@ -351,11 +351,11 @@ class GraphSampler:
 
   Sub-graphs are encoded as `GraphTensor` or tf.data.Dataset. Random walks are
   performed using `TypedWalkTree`. Input data graph must be an instance of
-  `NodeClassificationDatasetWrapper`
""" def __init__(self, - dataset: datasets.NodeClassificationDatasetWrapper, + dataset: datasets.Dataset, make_undirected: bool = False, ensure_self_loops: bool = False, reduce_memory_footprint: bool = True, @@ -506,30 +509,6 @@ def sample_one_hop( return next_nodes - def generate_subgraphs( - self, batch_size: int, - sampling_spec: sampling_spec_pb2.SamplingSpec, - split: str = 'train', - sampling=EdgeSampling.WITH_REPLACEMENT): - """Infinitely yields random subgraphs each rooted on node in train set.""" - if isinstance(split, bytes): - split = split.decode() - if not isinstance(split, (tuple, list)): - split = (split,) - - partitions = self.dataset.node_split() - - node_ids = tf.concat([getattr(partitions, s) for s in split], 0) - queue = tf.random.shuffle(node_ids) - - while True: - while queue.shape[0] < batch_size: - queue = tf.concat([queue, tf.random.shuffle(node_ids)], axis=0) - batch = queue[:batch_size] - queue = queue[batch_size:] - yield self.sample_sub_graph_tensor( - batch, sampling_spec=sampling_spec, sampling=sampling) - def random_walk_tree( self, node_idx: tf.Tensor, sampling_spec: sampling_spec_pb2.SamplingSpec, sampling: EdgeSampling = EdgeSampling.WITH_REPLACEMENT) -> TypedWalkTree: @@ -609,14 +588,52 @@ def gather_node_features_dict(self, node_set_name, node_idx): features = self.dataset.node_features_dicts(add_id=True)[node_set_name] features = {feature_name: tf.gather(feature_value, node_idx) for feature_name, feature_value in features.items()} + return features + + +class NodeClassificationGraphSampler(GraphSampler): + """Sampler returning subgraphs for node-classification in-memory datasets.""" + + def __init__(self, + dataset: datasets.NodeClassificationDataset, + **sampler_kwargs): + super().__init__(dataset, **sampler_kwargs) + self.dataset = dataset + + def generate_subgraphs( + self, batch_size: int, + sampling_spec: sampling_spec_pb2.SamplingSpec, + split: str = 'train', + sampling=EdgeSampling.WITH_REPLACEMENT): + """Infinitely yields random subgraphs each rooted on node in train set.""" + if isinstance(split, bytes): + split = split.decode() + if not isinstance(split, (tuple, list)): + split = (split,) + + partitions = self.dataset.node_split() + + node_ids = tf.concat([getattr(partitions, s) for s in split], 0) + queue = tf.random.shuffle(node_ids) + + while True: + while queue.shape[0] < batch_size: + queue = tf.concat([queue, tf.random.shuffle(node_ids)], axis=0) + batch = queue[:batch_size] + queue = queue[batch_size:] + yield self.sample_sub_graph_tensor( + batch, sampling_spec=sampling_spec, sampling=sampling) + + def gather_node_features_dict(self, node_set_name, node_idx): + features = super().gather_node_features_dict(node_set_name, node_idx) if node_set_name == self.dataset.labeled_nodeset: features['label'] = tf.gather(self.dataset.labels(), node_idx) return features -def make_sampled_subgraphs_dataset( - dataset: datasets.NodeClassificationDatasetWrapper, +def make_node_classification_tf_dataset( + dataset: datasets.NodeClassificationDataset, sampling_spec: sampling_spec_pb2.SamplingSpec, batch_size: int = 64, split='train', @@ -624,7 +641,8 @@ def make_sampled_subgraphs_dataset( sampling=EdgeSampling.WITH_REPLACEMENT ) -> Tuple[tf.TensorSpec, tf.data.Dataset]: """Infinite tf.data.Dataset wrapping generate_subgraphs.""" - subgraph_generator = GraphSampler(dataset, make_undirected=make_undirected) + subgraph_generator = NodeClassificationGraphSampler( + dataset, make_undirected=make_undirected) relaxed_spec = None for graph_tensor in 
      batch_size, split=split, sampling_spec=sampling_spec, sampling=sampling):
diff --git a/examples/in_memory/keras_minibatch_trainer.py b/examples/in_memory/keras_minibatch_trainer.py
index d0706aa1..f3005d9d 100644
--- a/examples/in_memory/keras_minibatch_trainer.py
+++ b/examples/in_memory/keras_minibatch_trainer.py
@@ -30,7 +30,7 @@
 import tensorflow_gnn as tfgnn
 
 import datasets
-from tensorflow_gnn.examples.in_memory import int_arithmetic_sampler
+from tensorflow_gnn.examples.in_memory import int_arithmetic_sampler as ia_sampler
 import models
 import reader_utils
 from tensorflow_gnn.sampler import sampling_spec_builder
@@ -58,15 +58,15 @@ def main(unused_argv):
 
-  dataset_wrapper = datasets.get_dataset(FLAGS.dataset)
-  num_classes = dataset_wrapper.num_classes()
+  dataset = datasets.get_dataset(FLAGS.dataset)
+  assert isinstance(dataset, datasets.NodeClassificationDataset)
+  num_classes = dataset.num_classes()
   model_kwargs = json.loads(FLAGS.model_kwargs_json)
   prefers_undirected, model = models.make_model_by_name(
       FLAGS.model, num_classes, l2_coefficient=FLAGS.l2_regularization,
       model_kwargs=model_kwargs)
 
-  graph_schema = dataset_wrapper.export_graph_schema(
-      make_undirected=prefers_undirected)
+  graph_schema = dataset.graph_schema(make_undirected=prefers_undirected)
   type_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
 
   input_graph = tf.keras.layers.Input(type_spec=type_spec)
@@ -94,20 +94,20 @@ def init_node_state(node_set, node_set_name):
   # Subgraph samples for training.
   train_sampling_spec = (sampling_spec_builder.SamplingSpecBuilder(graph_schema)
                         .seed().sample([3, 3]).to_sampling_spec())
-  _, train_dataset = int_arithmetic_sampler.make_sampled_subgraphs_dataset(
-      dataset_wrapper, sampling_spec=train_sampling_spec,
+  _, train_dataset = ia_sampler.make_node_classification_tf_dataset(
+      dataset, sampling_spec=train_sampling_spec,
       batch_size=FLAGS.batch_size,
-      sampling=int_arithmetic_sampler.EdgeSampling.WITH_REPLACEMENT,
+      sampling=ia_sampler.EdgeSampling.WITH_REPLACEMENT,
      make_undirected=prefers_undirected)
   train_labels_dataset = train_dataset.map(
       functools.partial(reader_utils.pair_graphs_with_labels, num_classes))
 
   # Subgraph samples for validation.
-  _, validation_ds = int_arithmetic_sampler.make_sampled_subgraphs_dataset(
-      dataset_wrapper, sampling_spec=train_sampling_spec,
+  _, validation_ds = ia_sampler.make_node_classification_tf_dataset(
+      dataset, sampling_spec=train_sampling_spec,
       batch_size=FLAGS.batch_size,
-      sampling=int_arithmetic_sampler.EdgeSampling.WITHOUT_REPLACEMENT,
+      sampling=ia_sampler.EdgeSampling.WITHOUT_REPLACEMENT,
       split='valid', make_undirected=prefers_undirected)
   validation_ds = validation_ds.map(
       functools.partial(reader_utils.pair_graphs_with_labels, num_classes))
@@ -123,7 +123,7 @@ def init_node_state(node_set, node_set_name):
       validation_steps=10,
       validation_freq=FLAGS.eval_every)
 
-  test_graph = dataset_wrapper.export_to_graph_tensor(
+  test_graph = dataset.as_graph_tensor(
       split='test', make_undirected=prefers_undirected)
   test_graph, test_labels = reader_utils.pair_graphs_with_labels(
       num_classes, test_graph)
diff --git a/examples/in_memory/keras_trainer.py b/examples/in_memory/keras_trainer.py
index 820857aa..802664fb 100644
--- a/examples/in_memory/keras_trainer.py
+++ b/examples/in_memory/keras_trainer.py
@@ -43,7 +43,7 @@
 flags.DEFINE_string('model_kwargs_json', '{}',
                     'JSON object encoding model arguments')
 flags.DEFINE_integer('eval_every', 10, 'Eval every this many steps.')
-flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate.')
+flags.DEFINE_float('learning_rate', 1e-2, 'Learning rate.')
 flags.DEFINE_float('l2_regularization', 1e-5,
                    'L2 Regularization for (non-bias) weights.')
 flags.DEFINE_integer('steps', 101,
@@ -56,15 +56,15 @@ def main(unused_argv):
 
-  dataset_wrapper = datasets.get_dataset(FLAGS.dataset)
-  num_classes = dataset_wrapper.num_classes()
+  dataset = datasets.get_dataset(FLAGS.dataset)
+  assert isinstance(dataset, datasets.NodeClassificationDataset)
+  num_classes = dataset.num_classes()
   model_kwargs = json.loads(FLAGS.model_kwargs_json)
   prefers_undirected, model = models.make_model_by_name(
       FLAGS.model, num_classes, l2_coefficient=FLAGS.l2_regularization,
       model_kwargs=model_kwargs)
 
-  graph_schema = datasets.create_graph_schema_from_directed(
-      dataset_wrapper, make_undirected=prefers_undirected)
+  graph_schema = dataset.graph_schema(make_undirected=prefers_undirected)
   type_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)
 
   input_graph = tf.keras.layers.Input(type_spec=type_spec)
@@ -92,34 +92,35 @@ def init_node_state(node_set, node_set_name):
   train_split = ['train']
   if FLAGS.train_on_validation:
     train_split.append('valid')
-  train_dataset = dataset_wrapper.iterate_once(
-      split='train', make_undirected=prefers_undirected)
+  train_dataset = dataset.iterate_once(
+      split=train_split, make_undirected=prefers_undirected)
   train_labels_dataset = train_dataset.map(
       functools.partial(reader_utils.pair_graphs_with_labels, num_classes))
 
   # Similarly for validation.
   valid_split = 'test' if FLAGS.train_on_validation else 'valid'
-  validation_ds = dataset_wrapper.iterate_once(
+  validation_ds = dataset.iterate_once(
       split=valid_split, make_undirected=prefers_undirected)
   validation_ds = validation_ds.map(
       functools.partial(reader_utils.pair_graphs_with_labels, num_classes))
   validation_repeated_ds = validation_ds.repeat()
 
+  start_alsologtostderr = FLAGS.alsologtostderr
   FLAGS.alsologtostderr = True  # To print accuracy and training progress.
   keras_model.fit(
       train_labels_dataset,
       epochs=FLAGS.steps,
       validation_data=validation_repeated_ds,
       validation_steps=1,
       validation_freq=FLAGS.eval_every)
 
-  test_graph = dataset_wrapper.export_to_graph_tensor(
+  test_graph = dataset.as_graph_tensor(
       split='test', make_undirected=prefers_undirected)
   test_graph, test_labels = reader_utils.pair_graphs_with_labels(
       num_classes, test_graph)
   accuracy = (tf.argmax(keras_model(test_graph), 1) ==
               tf.argmax(test_labels, 1)).numpy().mean()
-
-  print('Final test accuracy=%f' % accuracy)
+  FLAGS.alsologtostderr = start_alsologtostderr
+  print('\n\n ****** \n\n Final test accuracy=%f \n\n' % accuracy)
 
 
 if __name__ == '__main__':
   app.run(main)
diff --git a/examples/in_memory/tf_trainer.py b/examples/in_memory/tf_trainer.py
index b1a20445..2f5f8f25 100644
--- a/examples/in_memory/tf_trainer.py
+++ b/examples/in_memory/tf_trainer.py
@@ -59,8 +59,9 @@ def main(unused_argv):
 
-  dataset_wrapper = datasets.get_dataset(FLAGS.dataset)
-  num_classes = dataset_wrapper.num_classes()
+  dataset = datasets.get_dataset(FLAGS.dataset)
+  assert isinstance(dataset, datasets.NodeClassificationDataset)
+  num_classes = dataset.num_classes()
   model_kwargs = json.loads(FLAGS.model_kwargs_json)
   prefers_undirected, model = models.make_model_by_name(
       FLAGS.model, num_classes, l2_coefficient=FLAGS.l2_regularization,
@@ -69,7 +70,7 @@ def main(unused_argv):
   train_split = ['train']
   if FLAGS.train_on_validation:
     train_split.append('valid')
-  graph_tensor = dataset_wrapper.export_to_graph_tensor(
+  graph_tensor = dataset.as_graph_tensor(
       split=train_split, make_undirected=prefers_undirected)
   graph_tensor, seed_y = reader_utils.pair_graphs_with_labels(
       num_classes, graph_tensor)
@@ -88,7 +89,7 @@ def init_node_state(node_set, node_set_name):
   def train_step():
     with tf.GradientTape() as tape:
       # Model output.
-      model_out_graph_tensor = model(graph_tensor)
+      model_out_graph_tensor = model(graph_tensor, training=True)
       seed_logits = reader_utils.readout_seed_node_features(
           model_out_graph_tensor)
       # Compare with ground-truth.
@@ -102,14 +103,14 @@ def train_step():
     opt.apply_gradients(zip(gradients, model.trainable_variables))
 
   valid_split = 'test' if FLAGS.train_on_validation else 'valid'
-  valid_graph = dataset_wrapper.export_to_graph_tensor(
+  valid_graph = dataset.as_graph_tensor(
       split=valid_split, make_undirected=prefers_undirected)
   valid_graph, valid_y = reader_utils.pair_graphs_with_labels(
       num_classes, valid_graph)
   valid_graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=init_node_state)(
       valid_graph)
 
-  test_graph = dataset_wrapper.export_to_graph_tensor(
+  test_graph = dataset.as_graph_tensor(
       split='test', make_undirected=prefers_undirected)
   test_graph, test_y = reader_utils.pair_graphs_with_labels(
       num_classes, test_graph)
@@ -117,7 +118,7 @@ def train_step():
       test_graph)
 
   def estimate_validation_accuracy(y=valid_y, graph=valid_graph):
-    graph = model(graph)
+    graph = model(graph, training=False)
     predictions = tf.argmax(
         reader_utils.readout_seed_node_features(graph), axis=1)
     labels = tf.argmax(y, axis=1)
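
(For reference — how the renamed pieces compose end-to-end. A sketch, not part
of the patch; it assumes the APIs exactly as introduced above, and the dataset
name, batch size, and fanouts are only illustrative:)

    import tensorflow_gnn as tfgnn
    from tensorflow_gnn.sampler import sampling_spec_builder

    import datasets
    from tensorflow_gnn.examples.in_memory import int_arithmetic_sampler as ia_sampler

    dataset = datasets.get_dataset('ogbn-arxiv')  # a NodeClassificationDataset
    schema = dataset.graph_schema(make_undirected=True)

    # Full-batch path (tf_trainer.py, keras_trainer.py):
    full_graph = dataset.as_graph_tensor(split='train', make_undirected=True)

    # Sampled mini-batch path (keras_minibatch_trainer.py):
    sampling_spec = (sampling_spec_builder.SamplingSpecBuilder(schema)
                     .seed().sample([3, 3]).to_sampling_spec())
    _, train_ds = ia_sampler.make_node_classification_tf_dataset(
        dataset, sampling_spec=sampling_spec, batch_size=64,
        sampling=ia_sampler.EdgeSampling.WITH_REPLACEMENT,
        make_undirected=True)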