From 697356dd1875859a9656d5e5ef431f4645f3f22b Mon Sep 17 00:00:00 2001
From: Zhong Hui
Date: Tue, 16 Mar 2021 20:46:55 +0800
Subject: [PATCH 1/2] [Model][LGNN] Add line graph neural network.

---
 examples/lgnn/README.md      |   9 ++
 examples/lgnn/cora_binary.py | 164 ++++++++++++++++++++++++++++++++
 examples/lgnn/model.py       | 103 ++++++++++++++++++++
 examples/lgnn/train.py       | 179 +++++++++++++++++++++++++++++++++++
 4 files changed, 455 insertions(+)
 create mode 100644 examples/lgnn/README.md
 create mode 100644 examples/lgnn/cora_binary.py
 create mode 100644 examples/lgnn/model.py
 create mode 100644 examples/lgnn/train.py

diff --git a/examples/lgnn/README.md b/examples/lgnn/README.md
new file mode 100644
index 00000000..29695c4e
--- /dev/null
+++ b/examples/lgnn/README.md
@@ -0,0 +1,9 @@
+# Line Graph Neural Networks
+
+[Line Graph Neural Networks](https://arxiv.org/pdf/1705.08415.pdf) is a neural network for community detection.
+
+## How to run
+
+```shell
+python train.py
+```
diff --git a/examples/lgnn/cora_binary.py b/examples/lgnn/cora_binary.py
new file mode 100644
index 00000000..c9c8b7db
--- /dev/null
+++ b/examples/lgnn/cora_binary.py
@@ -0,0 +1,164 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Cora Binary dataset
+import os
+import sys
+
+import numpy as np
+import pickle as pkl
+import networkx as nx
+import scipy.sparse as sp
+
+from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url
+from .utils import generate_mask_tensor
+from .utils import deprecate_property, deprecate_function
+from .dgl_dataset import DGLBuiltinDataset
+from .. import convert
+from .. import batch
+from .. import backend as F
+from ..convert import graph as dgl_graph
+from ..convert import from_networkx, to_networkx
+
+backend = os.environ.get('DGLBACKEND', 'pytorch')
+
+
+def _pickle_load(pkl_file):
+    if sys.version_info > (3, 0):
+        return pkl.load(pkl_file, encoding='latin1')
+    else:
+        return pkl.load(pkl_file)
+
+
+class CoraBinary(DGLBuiltinDataset):
+    """A mini-dataset for a binary classification task using Cora.
+
+    After loading, it has the following members:
+
+    graphs : list of :class:`~dgl.DGLGraph`
+    pmpds : list of :class:`scipy.sparse.coo_matrix`
+    labels : list of :class:`numpy.ndarray`
+
+    Parameters
+    ----------
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+ """ + + def __init__(self, raw_dir=None, force_reload=False, verbose=True): + name = 'cora_binary' + url = _get_dgl_url('dataset/cora_binary.zip') + super(CoraBinary, self).__init__( + name, + url=url, + raw_dir=raw_dir, + force_reload=force_reload, + verbose=verbose) + + def process(self): + root = self.raw_path + # load graphs + self.graphs = [] + with open("{}/graphs.txt".format(root), 'r') as f: + elist = [] + for line in f.readlines(): + if line.startswith('graph'): + if len(elist) != 0: + self.graphs.append(dgl_graph(tuple(zip(*elist)))) + elist = [] + else: + u, v = line.strip().split(' ') + elist.append((int(u), int(v))) + if len(elist) != 0: + self.graphs.append(dgl_graph(tuple(zip(*elist)))) + with open("{}/pmpds.pkl".format(root), 'rb') as f: + self.pmpds = _pickle_load(f) + self.labels = [] + with open("{}/labels.txt".format(root), 'r') as f: + cur = [] + for line in f.readlines(): + if line.startswith('graph'): + if len(cur) != 0: + self.labels.append(np.asarray(cur)) + cur = [] + else: + cur.append(int(line.strip())) + if len(cur) != 0: + self.labels.append(np.asarray(cur)) + # sanity check + assert len(self.graphs) == len(self.pmpds) + assert len(self.graphs) == len(self.labels) + + def has_cache(self): + graph_path = os.path.join(self.save_path, self.save_name + '.bin') + if os.path.exists(graph_path): + return True + + return False + + def save(self): + """save the graph list and the labels""" + graph_path = os.path.join(self.save_path, self.save_name + '.bin') + labels = {} + for i, label in enumerate(self.labels): + labels['{}'.format(i)] = F.tensor(label) + save_graphs(str(graph_path), self.graphs, labels) + if self.verbose: + print('Done saving data into cached files.') + + def load(self): + graph_path = os.path.join(self.save_path, self.save_name + '.bin') + self.graphs, labels = load_graphs(str(graph_path)) + + self.labels = [] + for i in range(len(lables)): + self.labels.append(labels['{}'.format(i)].asnumpy()) + # load pmpds under self.raw_path + with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f: + self.pmpds = _pickle_load(f) + if self.verbose: + print('Done loading data into cached files.') + # sanity check + assert len(self.graphs) == len(self.pmpds) + assert len(self.graphs) == len(self.labels) + + def __len__(self): + return len(self.graphs) + + def __getitem__(self, i): + r"""Gets the idx-th sample. + Parameters + ----------- + idx : int + The sample index. + Returns + ------- + (dgl.DGLGraph, scipy.sparse.coo_matrix, int) + The graph, scipy sparse coo_matrix and its label. + """ + return (self.graphs[i], self.pmpds[i], self.labels[i]) + + @property + def save_name(self): + return self.name + '_dgl_graph' + + @staticmethod + def collate_fn(cur): + graphs, pmpds, labels = zip(*cur) + batched_graphs = batch.batch(graphs) + batched_pmpds = sp.block_diag(pmpds) + batched_labels = np.concatenate(labels, axis=0) + return batched_graphs, batched_pmpds, batched_labels diff --git a/examples/lgnn/model.py b/examples/lgnn/model.py new file mode 100644 index 00000000..c6f0ff2a --- /dev/null +++ b/examples/lgnn/model.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import itertools
+import pgl
+# import pgl.function as fn
+import networkx as nx
+import paddle as th
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+
+class GNNLayer(nn.Layer):
+    def __init__(self, in_feats, out_feats, radius):
+        super().__init__()
+        self.out_feats = out_feats
+        self.radius = radius
+
+        new_linear = lambda: nn.Linear(in_feats, out_feats)
+        new_linear_list = lambda: nn.LayerList([new_linear() for i in range(radius)])
+
+        self.theta_x, self.theta_deg, self.theta_y = \
+            new_linear(), new_linear(), new_linear()
+        self.theta_list = new_linear_list()
+
+        self.gamma_y, self.gamma_deg, self.gamma_x = \
+            new_linear(), new_linear(), new_linear()
+        self.gamma_list = new_linear_list()
+
+        self.bn_x = nn.BatchNorm1D(out_feats)
+        self.bn_y = nn.BatchNorm1D(out_feats)
+
+    def aggregate(self, g, z):
+        z_list = []
+        g.ndata['z'] = z
+        g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
+        z_list.append(g.ndata['z'])
+        for i in range(self.radius - 1):
+            for j in range(2**i):
+                g.update_all(
+                    fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
+            z_list.append(g.ndata['z'])
+        return z_list
+
+    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
+        pmpd_x = F.embedding(pm_pd, x)
+
+        sum_x = sum(
+            theta(z)
+            for theta, z in zip(self.theta_list, self.aggregate(g, x)))
+
+        g.edata['y'] = y
+        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
+        pmpd_y = g.ndata.pop('pmpd_y')
+
+        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
+        n = self.out_feats // 2
+        x = paddle.cat([x[:, :n], F.relu(x[:, n:])], 1)
+        x = self.bn_x(x)
+
+        sum_y = sum(
+            gamma(z)
+            for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
+
+        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
+        y = paddle.cat([y[:, :n], F.relu(y[:, n:])], 1)
+        y = self.bn_y(y)
+
+        return x, y
+
+
+class GNN(nn.Layer):
+    def __init__(self, feats, radius, n_classes):
+        super(GNN, self).__init__()
+        self.linear = nn.Linear(feats[-1], n_classes)
+        self.module_list = nn.LayerList(
+            [GNNLayer(m, n, radius) for m, n in zip(feats[:-1], feats[1:])])
+
+    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
+        x, y = deg_g, deg_lg
+        for module in self.module_list:
+            x, y = module(g, lg, x, y, deg_g, deg_lg, pm_pd)
+        return self.linear(x)
+
+
+if __name__ == "__main__":
+    g = GNN([10, 10, 10], 3, 7)
diff --git a/examples/lgnn/train.py b/examples/lgnn/train.py
new file mode 100644
index 00000000..dccaa941
--- /dev/null
+++ b/examples/lgnn/train.py
@@ -0,0 +1,179 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Supervised Community Detection with Line Graph Neural Networks https://arxiv.org/abs/1705.08415
+
+from __future__ import division
+import time
+
+import argparse
+from itertools import permutations
+
+import numpy as np
+import paddle as th
+import paddle.nn.functional as F
+import paddle.optim as optim
+from paddle.utils.data import DataLoader
+
+from pgl.data import SBMMixtureDataset
+import model as gnn
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
+parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
+parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
+parser.add_argument(
+    '--n-communities', type=int, help='Number of communities', default=2)
+parser.add_argument(
+    '--n-epochs', type=int, help='Number of epochs', default=100)
+parser.add_argument(
+    '--n-features', type=int, help='Number of features', default=16)
+parser.add_argument(
+    '--n-graphs', type=int, help='Number of graphs', default=10)
+parser.add_argument(
+    '--n-layers', type=int, help='Number of layers', default=30)
+parser.add_argument(
+    '--n-nodes', type=int, help='Number of nodes', default=10000)
+parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
+parser.add_argument('--radius', type=int, help='Radius', default=3)
+parser.add_argument('--verbose', action='store_true')
+args = parser.parse_args()
+
+dev = paddle.device('cpu') if args.gpu < 0 else paddle.device('cuda:%d' %
+                                                              args.gpu)
+K = args.n_communities
+
+training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
+training_loader = DataLoader(
+    training_dataset,
+    args.batch_size,
+    collate_fn=training_dataset.collate_fn,
+    drop_last=True)
+
+ones = paddle.ones(args.n_nodes // K)
+y_list = [
+    paddle.cat([x * ones for x in p]).long().to(dev)
+    for p in permutations(range(K))
+]
+
+feats = [1] + [args.n_features] * args.n_layers + [K]
+model = gnn.GNN(feats, args.radius, K).to(dev)
+optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
+
+
+def compute_overlap(z_list):
+    ybar_list = [paddle.max(z, 1)[1] for z in z_list]
+    overlap_list = []
+    for y_bar in ybar_list:
+        accuracy = max(paddle.sum(y_bar == y).item()
+                       for y in y_list) / args.n_nodes
+        overlap = (accuracy - 1 / K) / (1 - 1 / K)
+        overlap_list.append(overlap)
+    return sum(overlap_list) / len(overlap_list)
+
+
+def from_np(f, *args):
+    def wrap(*args):
+        new = [
+            paddle.to_tensor(x) if isinstance(x, np.ndarray) else x
+            for x in args
+        ]
+        return f(*new)
+
+    return wrap
+
+
+@from_np
+def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
+    """One step of training."""
""" + g = g.to(dev) + lg = lg.to(dev) + deg_g = deg_g.to(dev).unsqueeze(1) + deg_lg = deg_lg.to(dev).unsqueeze(1) + pm_pd = pm_pd.to(dev) + t0 = time.time() + z = model(g, lg, deg_g, deg_lg, pm_pd) + t_forward = time.time() - t0 + + z_list = paddle.chunk(z, args.batch_size, 0) + loss = sum(min(F.cross_entropy(z, y) for y in y_list) + for z in z_list) / args.batch_size + overlap = compute_overlap(z_list) + + optimizer.zero_grad() + t0 = time.time() + loss.backward() + t_backward = time.time() - t0 + optimizer.step() + + return loss, overlap, t_forward, t_backward + + +@from_np +def inference(g, lg, deg_g, deg_lg, pm_pd): + g = g.to(dev) + lg = lg.to(dev) + deg_g = deg_g.to(dev).unsqueeze(1) + deg_lg = deg_lg.to(dev).unsqueeze(1) + pm_pd = pm_pd.to(dev) + + z = model(g, lg, deg_g, deg_lg, pm_pd) + + return z + + +def test(): + p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0] + q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6] + N = 1 + overlap_list = [] + for p, q in zip(p_list, q_list): + dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N) + loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn) + g, lg, deg_g, deg_lg, pm_pd = next(iter(loader)) + z = inference(g, lg, deg_g, deg_lg, pm_pd) + overlap_list.append(compute_overlap(paddle.chunk(z, N, 0))) + return overlap_list + + +n_iterations = args.n_graphs // args.batch_size +for i in range(args.n_epochs): + total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0 + for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader): + loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, + pm_pd) + + total_loss += loss + total_overlap += overlap + s_forward += t_forward + s_backward += t_backward + + epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) + iteration = '0' * (len(str(n_iterations)) - len(str(j))) + if args.verbose: + print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f' % + (epoch, i, iteration, j, loss, overlap)) + + epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) + loss = total_loss / (j + 1) + overlap = total_overlap / (j + 1) + t_forward = s_forward / (j + 1) + t_backward = s_backward / (j + 1) + print( + '[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs' + % (epoch, i, loss, overlap, t_forward, t_backward)) + + overlap_list = test() + overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list]) + print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str)) From 50f96f3d2dcbdc7fc2a85cbe476aaf9fd1e8db89 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Fri, 19 Mar 2021 20:57:57 +0800 Subject: [PATCH 2/2] [Model] add lgnn. --- examples/lgnn/cora_binary.py | 157 +++++------------------- examples/lgnn/model.py | 149 ++++++++++++----------- examples/lgnn/train.py | 225 ++++++++++------------------------- 3 files changed, 168 insertions(+), 363 deletions(-) diff --git a/examples/lgnn/cora_binary.py b/examples/lgnn/cora_binary.py index c9c8b7db..38b36d86 100644 --- a/examples/lgnn/cora_binary.py +++ b/examples/lgnn/cora_binary.py @@ -17,148 +17,45 @@ import sys import numpy as np -import pickle as pkl -import networkx as nx -import scipy.sparse as sp -from .utils import save_graphs, load_graphs, save_info, load_info, makedirs, _get_dgl_url -from .utils import generate_mask_tensor -from .utils import deprecate_property, deprecate_function -from .dgl_dataset import DGLBuiltinDataset -from .. import convert -from .. import batch -from .. 
-from ..convert import graph as dgl_graph
-from ..convert import from_networkx, to_networkx
+import pgl
+from pgl.utils.data import Dataset
 
-backend = os.environ.get('DGLBACKEND', 'pytorch')
-
-
-def _pickle_load(pkl_file):
-    if sys.version_info > (3, 0):
-        return pkl.load(pkl_file, encoding='latin1')
-    else:
-        return pkl.load(pkl_file)
-
 
-class CoraBinary(DGLBuiltinDataset):
+class CoraBinary(Dataset):
     """A mini-dataset for a binary classification task using Cora.
-
-    After loading, it has the following members:
-
-    graphs : list of :class:`~dgl.DGLGraph`
-    pmpds : list of :class:`scipy.sparse.coo_matrix`
-    labels : list of :class:`numpy.ndarray`
-
-    Parameters
-    ----------
-    raw_dir : str
-        Raw file directory to download/contains the input data directory.
-        Default: ~/.dgl/
-    force_reload : bool
-        Whether to reload the dataset. Default: False
-    verbose : bool
-        Whether to print out progress information. Default: True.
     """
 
-    def __init__(self, raw_dir=None, force_reload=False, verbose=True):
-        name = 'cora_binary'
-        url = _get_dgl_url('dataset/cora_binary.zip')
-        super(CoraBinary, self).__init__(
-            name,
-            url=url,
-            raw_dir=raw_dir,
-            force_reload=force_reload,
-            verbose=verbose)
-
-    def process(self):
-        root = self.raw_path
-        # load graphs
-        self.graphs = []
-        with open("{}/graphs.txt".format(root), 'r') as f:
-            elist = []
-            for line in f.readlines():
-                if line.startswith('graph'):
-                    if len(elist) != 0:
-                        self.graphs.append(dgl_graph(tuple(zip(*elist))))
-                    elist = []
-                else:
-                    u, v = line.strip().split(' ')
-                    elist.append((int(u), int(v)))
-            if len(elist) != 0:
-                self.graphs.append(dgl_graph(tuple(zip(*elist))))
-        with open("{}/pmpds.pkl".format(root), 'rb') as f:
-            self.pmpds = _pickle_load(f)
-        self.labels = []
-        with open("{}/labels.txt".format(root), 'r') as f:
-            cur = []
-            for line in f.readlines():
-                if line.startswith('graph'):
-                    if len(cur) != 0:
-                        self.labels.append(np.asarray(cur))
-                    cur = []
-                else:
-                    cur.append(int(line.strip()))
-            if len(cur) != 0:
-                self.labels.append(np.asarray(cur))
-        # sanity check
-        assert len(self.graphs) == len(self.pmpds)
-        assert len(self.graphs) == len(self.labels)
-
-    def has_cache(self):
-        graph_path = os.path.join(self.save_path, self.save_name + '.bin')
-        if os.path.exists(graph_path):
-            return True
-
-        return False
-
-    def save(self):
-        """Save the graph list and the labels."""
-        graph_path = os.path.join(self.save_path, self.save_name + '.bin')
-        labels = {}
-        for i, label in enumerate(self.labels):
-            labels['{}'.format(i)] = F.tensor(label)
-        save_graphs(str(graph_path), self.graphs, labels)
-        if self.verbose:
-            print('Done saving data into cached files.')
+    def __init__(self, raw_dir=None):
+        super(CoraBinary, self).__init__()
+        self.num = 21
+        self.save_path = "./cora_binary"
+        if not os.path.exists(self.save_path):
+            os.system(
+                "wget http://10.255.129.12:8122/cora_binary.zip && unzip cora_binary.zip"
+            )
+        self.graphs, self.line_graphs, self.labels = [], [], []
+        self.load()
 
     def load(self):
-        graph_path = os.path.join(self.save_path, self.save_name + '.bin')
-        self.graphs, labels = load_graphs(str(graph_path))
-
-        self.labels = []
-        for i in range(len(labels)):
-            self.labels.append(labels['{}'.format(i)].asnumpy())
-        # load pmpds under self.raw_path
-        with open("{}/pmpds.pkl".format(self.raw_path), 'rb') as f:
-            self.pmpds = _pickle_load(f)
-        if self.verbose:
-            print('Done loading data from cached files.')
-        # sanity check
-        assert len(self.graphs) == len(self.pmpds)
-        assert len(self.graphs) == len(self.labels)
+        for idx in range(self.num):
+            self.graphs.append(
+                pgl.Graph.load(
+                    os.path.join(self.save_path, str(idx), "graph")))
+            self.line_graphs.append(
+                pgl.Graph.load(
+                    os.path.join(self.save_path, str(idx), "line_graph")))
+            self.labels.append(
+                np.load(os.path.join(self.save_path, str(idx), "labels.npy")))
 
     def __len__(self):
         return len(self.graphs)
 
     def __getitem__(self, i):
-        r"""Gets the i-th sample.
-
-        Parameters
-        ----------
-        i : int
-            The sample index.
-
-        Returns
-        -------
-        (dgl.DGLGraph, scipy.sparse.coo_matrix, numpy.ndarray)
-            The graph, scipy sparse coo_matrix and its label.
-        """
-        return (self.graphs[i], self.pmpds[i], self.labels[i])
+        return (self.graphs[i], self.line_graphs[i], self.labels[i])
 
-    @property
-    def save_name(self):
-        return self.name + '_dgl_graph'
-
-    @staticmethod
-    def collate_fn(cur):
-        graphs, pmpds, labels = zip(*cur)
-        batched_graphs = batch.batch(graphs)
-        batched_pmpds = sp.block_diag(pmpds)
-        batched_labels = np.concatenate(labels, axis=0)
-        return batched_graphs, batched_pmpds, batched_labels
+
+if __name__ == "__main__":
+    c = CoraBinary()
+    for data in c:
+        print(data)
diff --git a/examples/lgnn/model.py b/examples/lgnn/model.py
index c6f0ff2a..5a08e136 100644
--- a/examples/lgnn/model.py
+++ b/examples/lgnn/model.py
@@ -15,89 +15,96 @@
 import copy
 import itertools
 import pgl
-# import pgl.function as fn
-import networkx as nx
-import paddle as th
 import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
+import paddle
 
 
-class GNNLayer(nn.Layer):
+def aggregate_radius(radius, graph, feature):
+    """Return the 2^j-hop aggregated features for j = 0 .. radius - 1."""
+    feat_list = []
+    feature = graph.send_recv(feature, "sum")
+    feat_list.append(feature)
+    for i in range(radius - 1):
+        for j in range(2**i):
+            feature = graph.send_recv(feature, "sum")
+        feat_list.append(feature)
+    return feat_list
+
+
+class LGNNCore(nn.Layer):
     def __init__(self, in_feats, out_feats, radius):
-        super().__init__()
+        super(LGNNCore, self).__init__()
         self.out_feats = out_feats
         self.radius = radius
 
-        new_linear = lambda: nn.Linear(in_feats, out_feats)
-        new_linear_list = lambda: nn.LayerList([new_linear() for i in range(radius)])
-
-        self.theta_x, self.theta_deg, self.theta_y = \
-            new_linear(), new_linear(), new_linear()
-        self.theta_list = new_linear_list()
-
-        self.gamma_y, self.gamma_deg, self.gamma_x = \
-            new_linear(), new_linear(), new_linear()
-        self.gamma_list = new_linear_list()
-
-        self.bn_x = nn.BatchNorm1D(out_feats)
-        self.bn_y = nn.BatchNorm1D(out_feats)
-
-    def aggregate(self, g, z):
-        z_list = []
-        g.ndata['z'] = z
-        g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
-        z_list.append(g.ndata['z'])
-        for i in range(self.radius - 1):
-            for j in range(2**i):
-                g.update_all(
-                    fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
-            z_list.append(g.ndata['z'])
-        return z_list
-
-    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
-        pmpd_x = F.embedding(pm_pd, x)
-
-        sum_x = sum(
-            theta(z)
-            for theta, z in zip(self.theta_list, self.aggregate(g, x)))
-
-        g.edata['y'] = y
-        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
-        pmpd_y = g.ndata.pop('pmpd_y')
-
-        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
+        self.linear_prev = nn.Linear(in_feats, out_feats)
+        self.linear_deg = nn.Linear(in_feats, out_feats)
+        self.linear_radius = nn.LayerList(
+            [nn.Linear(in_feats, out_feats) for i in range(radius)])
+        # self.linear_fuse = nn.Linear(in_feats, out_feats)
+        self.bn = nn.BatchNorm1D(out_feats)
+
+    def forward(self, graph, feat_a, feat_b, deg):
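+        # One half of an LGNN update (summarizing the code below): project
+        # the node feature itself ("prev"), the degree-scaled feature
+        # ("deg"), and a series of 2^j-hop aggregates ("radius"), sum the
+        # projections, then apply a half-ReLU skip connection and batch
+        # norm. feat_b is reserved for the fused line-graph term (see the
+        # TODO below).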
+        # term "prev"
+        prev_proj = self.linear_prev(feat_a)
+        # term "deg"
+        deg_proj = self.linear_deg(deg * feat_a)
+        # term "radius": aggregate 2^j-hop features
+        hop2j_list = aggregate_radius(self.radius, graph, feat_a)
+        # apply linear transformation
+        hop2j_list = [
+            linear(x) for linear, x in zip(self.linear_radius, hop2j_list)
+        ]
+        radius_proj = sum(hop2j_list)
+
+        # TODO add fuse
+        # sum them together
+        result = prev_proj + deg_proj + radius_proj
+
+        # skip connection and batch norm
         n = self.out_feats // 2
-        x = paddle.cat([x[:, :n], F.relu(x[:, n:])], 1)
-        x = self.bn_x(x)
-
-        sum_y = sum(
-            gamma(z)
-            for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
-
-        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
-        y = paddle.cat([y[:, :n], F.relu(y[:, n:])], 1)
-        y = self.bn_y(y)
-
-        return x, y
+        result = paddle.concat([result[:, :n], F.relu(result[:, n:])], 1)
+        result = self.bn(result)
+        return result
 
 
-class GNN(nn.Layer):
-    def __init__(self, feats, radius, n_classes):
-        super(GNN, self).__init__()
-        self.linear = nn.Linear(feats[-1], n_classes)
-        self.module_list = nn.LayerList(
-            [GNNLayer(m, n, radius) for m, n in zip(feats[:-1], feats[1:])])
-
-    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
-        x, y = deg_g, deg_lg
-        for module in self.module_list:
-            x, y = module(g, lg, x, y, deg_g, deg_lg, pm_pd)
-        return self.linear(x)
+class LGNNLayer(nn.Layer):
+    def __init__(self, in_feats, out_feats, radius):
+        super(LGNNLayer, self).__init__()
+        self.g_layer = LGNNCore(in_feats, out_feats, radius)
+        self.lg_layer = LGNNCore(in_feats, out_feats, radius)
+
+    def forward(self, graph, line_graph, feat, lg_feat, deg_g, deg_lg):
+        next_feat = self.g_layer(graph, feat, lg_feat, deg_g)
+        next_lg_feat = self.lg_layer(line_graph, lg_feat, feat, deg_lg)
+        return next_feat, next_lg_feat
+
+
+class LGNN(nn.Layer):
+    def __init__(self, radius):
+        super(LGNN, self).__init__()
+        self.layer1 = LGNNLayer(1, 16, radius)  # input is a scalar feature
+        self.layer2 = LGNNLayer(16, 16, radius)  # hidden size is 16
+        self.layer3 = LGNNLayer(16, 16, radius)
+        self.linear = nn.Linear(16, 2)  # predict two classes
+
+    def forward(self, graph, line_graph):
+        # compute the degrees
+        deg_g = graph.indegree().astype("float32").unsqueeze(-1)
+        deg_lg = line_graph.indegree().astype("float32").unsqueeze(-1)
+        # use degree as the input feature
+        feat, lg_feat = deg_g, deg_lg
+        feat, lg_feat = self.layer1(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        feat, lg_feat = self.layer2(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        feat, lg_feat = self.layer3(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        return self.linear(feat)
 
 
 if __name__ == "__main__":
-    g = GNN([10, 10, 10], 3, 7)
+    g = LGNN(3)
diff --git a/examples/lgnn/train.py b/examples/lgnn/train.py
index dccaa941..6c4fc203 100644
--- a/examples/lgnn/train.py
+++ b/examples/lgnn/train.py
@@ -12,168 +12,69 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# Supervised Community Detection with Line Graph Neural Networks https://arxiv.org/abs/1705.08415
-
-from __future__ import division
-import time
-
-import argparse
-from itertools import permutations
+import os
+
+import paddle
 import numpy as np
-import paddle as th
+import paddle.nn as nn
 import paddle.nn.functional as F
-import paddle.optim as optim
-from paddle.utils.data import DataLoader
-
-from pgl.data import SBMMixtureDataset
-import model as gnn
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
-parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
-parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
-parser.add_argument(
-    '--n-communities', type=int, help='Number of communities', default=2)
-parser.add_argument(
-    '--n-epochs', type=int, help='Number of epochs', default=100)
-parser.add_argument(
-    '--n-features', type=int, help='Number of features', default=16)
-parser.add_argument(
-    '--n-graphs', type=int, help='Number of graphs', default=10)
-parser.add_argument(
-    '--n-layers', type=int, help='Number of layers', default=30)
-parser.add_argument(
-    '--n-nodes', type=int, help='Number of nodes', default=10000)
-parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
-parser.add_argument('--radius', type=int, help='Radius', default=3)
-parser.add_argument('--verbose', action='store_true')
-args = parser.parse_args()
-
-dev = paddle.device('cpu') if args.gpu < 0 else paddle.device('cuda:%d' %
-                                                              args.gpu)
-K = args.n_communities
-
-training_dataset = SBMMixtureDataset(args.n_graphs, args.n_nodes, K)
-training_loader = DataLoader(
-    training_dataset,
-    args.batch_size,
-    collate_fn=training_dataset.collate_fn,
-    drop_last=True)
-
-ones = paddle.ones(args.n_nodes // K)
-y_list = [
-    paddle.cat([x * ones for x in p]).long().to(dev)
-    for p in permutations(range(K))
-]
-
-feats = [1] + [args.n_features] * args.n_layers + [K]
-model = gnn.GNN(feats, args.radius, K).to(dev)
-optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
-
-
-def compute_overlap(z_list):
-    ybar_list = [paddle.max(z, 1)[1] for z in z_list]
-    overlap_list = []
-    for y_bar in ybar_list:
-        accuracy = max(paddle.sum(y_bar == y).item()
-                       for y in y_list) / args.n_nodes
-        overlap = (accuracy - 1 / K) / (1 - 1 / K)
-        overlap_list.append(overlap)
-    return sum(overlap_list) / len(overlap_list)
-
-
-def from_np(f, *args):
-    def wrap(*args):
-        new = [
-            paddle.to_tensor(x) if isinstance(x, np.ndarray) else x
-            for x in args
-        ]
-        return f(*new)
-
-    return wrap
-
-
-@from_np
-def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
-    """One step of training."""
""" - g = g.to(dev) - lg = lg.to(dev) - deg_g = deg_g.to(dev).unsqueeze(1) - deg_lg = deg_lg.to(dev).unsqueeze(1) - pm_pd = pm_pd.to(dev) - t0 = time.time() - z = model(g, lg, deg_g, deg_lg, pm_pd) - t_forward = time.time() - t0 - - z_list = paddle.chunk(z, args.batch_size, 0) - loss = sum(min(F.cross_entropy(z, y) for y in y_list) - for z in z_list) / args.batch_size - overlap = compute_overlap(z_list) - - optimizer.zero_grad() - t0 = time.time() - loss.backward() - t_backward = time.time() - t0 - optimizer.step() - - return loss, overlap, t_forward, t_backward - - -@from_np -def inference(g, lg, deg_g, deg_lg, pm_pd): - g = g.to(dev) - lg = lg.to(dev) - deg_g = deg_g.to(dev).unsqueeze(1) - deg_lg = deg_lg.to(dev).unsqueeze(1) - pm_pd = pm_pd.to(dev) - - z = model(g, lg, deg_g, deg_lg, pm_pd) - - return z - - -def test(): - p_list = [6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0] - q_list = [0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6] - N = 1 - overlap_list = [] - for p, q in zip(p_list, q_list): - dataset = SBMMixtureDataset(N, args.n_nodes, K, pq=[[p, q]] * N) - loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn) - g, lg, deg_g, deg_lg, pm_pd = next(iter(loader)) - z = inference(g, lg, deg_g, deg_lg, pm_pd) - overlap_list.append(compute_overlap(paddle.chunk(z, N, 0))) - return overlap_list - - -n_iterations = args.n_graphs // args.batch_size -for i in range(args.n_epochs): - total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0 - for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader): - loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, - pm_pd) - - total_loss += loss - total_overlap += overlap - s_forward += t_forward - s_backward += t_backward - - epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) - iteration = '0' * (len(str(n_iterations)) - len(str(j))) - if args.verbose: - print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f' % - (epoch, i, iteration, j, loss, overlap)) - - epoch = '0' * (len(str(args.n_epochs)) - len(str(i))) - loss = total_loss / (j + 1) - overlap = total_overlap / (j + 1) - t_forward = s_forward / (j + 1) - t_backward = s_backward / (j + 1) - print( - '[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs' - % (epoch, i, loss, overlap, t_forward, t_backward)) - - overlap_list = test() - overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list]) - print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str)) +import pgl +from pgl.utils.data import Dataloader + +from model import LGNN +from cora_binary import CoraBinary + + +def main(): + train_set = CoraBinary() + training_loader = Dataloader(train_set, batch_size=1) + + model = LGNN(radius=3) + optimizer = paddle.optimizer.Adam( + parameters=model.parameters(), learning_rate=4e-3) + for i in range(20): + all_loss = [] + all_acc = [] + #for idx, (g, lg, label) in enumerate(training_loader): + for idx, inputs in enumerate(training_loader): + #print(xxx) + (p_g, p_lg, label) = inputs[0] + # Generate the line graph. + p_g.tensor() + p_lg.tensor() + # Create paddle tensors + label = paddle.to_tensor(label) + # Forward + z = model(p_g, p_lg) + + # Calculate loss: + # Since there are only two communities, there are only two permutations + # of the community labels. 
+            loss_perm1 = F.cross_entropy(z, label)
+            loss_perm2 = F.cross_entropy(z, 1 - label)
+            loss = paddle.minimum(loss_perm1, loss_perm2)
+
+            # Calculate accuracy, again over both label permutations.
+            pred = paddle.argmax(z, axis=1)
+            acc_perm1 = (pred == label).astype("float32").mean()
+            acc_perm2 = (pred == 1 - label).astype("float32").mean()
+            acc = paddle.maximum(acc_perm1, acc_perm2)
+            all_loss.append(float(loss))
+            all_acc.append(float(acc))
+
+            optimizer.clear_grad()
+            loss.backward()
+            optimizer.step()
+        niters = len(all_loss)
+        print("Epoch %d | loss %.4f | accuracy %.4f" %
+              (i, sum(all_loss) / niters, sum(all_acc) / niters))
+
+
+if __name__ == "__main__":
+    main()
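
For reference, a minimal sketch of how the pieces added in this series fit together (a sketch, not part of the patch itself; it assumes the `cora_binary` archive downloads successfully and is run from `examples/lgnn/`):

```python
# Sketch: load one CoraBinary sample and run a single LGNN forward pass.
from cora_binary import CoraBinary
from model import LGNN

dataset = CoraBinary()
graph, line_graph, labels = dataset[0]  # pgl.Graph, pgl.Graph, np.ndarray

graph.tensor()       # convert the graph's numpy arrays to paddle tensors
line_graph.tensor()

model = LGNN(radius=3)
logits = model(graph, line_graph)  # community scores, shape [num_nodes, 2]
print(logits.shape, labels.shape)
```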