diff --git a/examples/lgnn/README.md b/examples/lgnn/README.md
new file mode 100644
index 00000000..29695c4e
--- /dev/null
+++ b/examples/lgnn/README.md
@@ -0,0 +1,9 @@
+# Line Graph Neural Networks
+
+[Line Graph Neural Networks](https://arxiv.org/pdf/1705.08415.pdf) (LGNN) is a neural network architecture for community detection. It runs one GNN on the input graph and a second GNN on its line graph, whose nodes correspond to the edges of the input graph.
+
+## How to run
+
+```shell
+python train.py
+```
diff --git a/examples/lgnn/cora_binary.py b/examples/lgnn/cora_binary.py
new file mode 100644
index 00000000..38b36d86
--- /dev/null
+++ b/examples/lgnn/cora_binary.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Cora Binary dataset
+import os
+
+import numpy as np
+
+import pgl
+from pgl.utils.data import Dataset
+
+
+class CoraBinary(Dataset):
+    """A mini Cora-based dataset for binary community detection.
+
+    Each sample is a (graph, line_graph, labels) triple with two communities.
+    """
+
+    def __init__(self, raw_dir=None):
+        super(CoraBinary, self).__init__()
+        self.num = 21
+        self.save_path = "./cora_binary"
+        # Download and extract the dataset if it is not cached locally.
+        if not os.path.exists(self.save_path):
+            os.system(
+                "wget http://10.255.129.12:8122/cora_binary.zip && unzip cora_binary.zip"
+            )
+        self.graphs, self.line_graphs, self.labels = [], [], []
+        self.load()
+
+    def load(self):
+        for idx in range(self.num):
+            self.graphs.append(
+                pgl.Graph.load(
+                    os.path.join(self.save_path, str(idx), "graph")))
+            self.line_graphs.append(
+                pgl.Graph.load(
+                    os.path.join(self.save_path, str(idx), "line_graph")))
+            self.labels.append(
+                np.load(os.path.join(self.save_path, str(idx), "labels.npy")))
+
+    def __len__(self):
+        return len(self.graphs)
+
+    def __getitem__(self, i):
+        return (self.graphs[i], self.line_graphs[i], self.labels[i])
+
+
+if __name__ == "__main__":
+    c = CoraBinary()
+    for data in c:
+        print(data)
diff --git a/examples/lgnn/model.py b/examples/lgnn/model.py
new file mode 100644
index 00000000..5a08e136
--- /dev/null
+++ b/examples/lgnn/model.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
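+
+"""Line Graph Neural Network (LGNN) for binary community detection.
+
+LGNN (https://arxiv.org/abs/1705.08415) runs one GNN on the input graph G
+and a second GNN on its line graph L(G), whose nodes are the edges of G.
+Each LGNNCore layer combines a "prev" term, a degree term, and features
+aggregated over 2^j-hop neighborhoods; the fuse term that couples the two
+graphs in the paper is left as a TODO below.
+"""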
+
+import pgl
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+def aggregate_radius(radius, graph, feature):
+    """Aggregate features over 1, 2, 4, ..., 2^(radius - 1) hop neighborhoods.
+
+    Each outer iteration applies send_recv 2**i more times, doubling the
+    receptive field, so feat_list[j] holds the 2**j-hop aggregate.
+    """
+    feat_list = []
+    feature = graph.send_recv(feature, "sum")
+    feat_list.append(feature)
+    for i in range(radius - 1):
+        for j in range(2**i):
+            feature = graph.send_recv(feature, "sum")
+        feat_list.append(feature)
+    return feat_list
+
+
+class LGNNCore(nn.Layer):
+    def __init__(self, in_feats, out_feats, radius):
+        super(LGNNCore, self).__init__()
+        self.out_feats = out_feats
+        self.radius = radius
+
+        self.linear_prev = nn.Linear(in_feats, out_feats)
+        self.linear_deg = nn.Linear(in_feats, out_feats)
+        self.linear_radius = nn.LayerList(
+            [nn.Linear(in_feats, out_feats) for _ in range(radius)])
+        # TODO: fuse term coupling graph and line-graph features
+        # self.linear_fuse = nn.Linear(in_feats, out_feats)
+        self.bn = nn.BatchNorm1D(out_feats)
+
+    def forward(self, graph, feat_a, feat_b, deg):
+        # term "prev": project the previous-layer features
+        prev_proj = self.linear_prev(feat_a)
+        # term "deg": project the degree-scaled features
+        deg_proj = self.linear_deg(deg * feat_a)
+        # term "radius": aggregate 2^j-hop features
+        hop2j_list = aggregate_radius(self.radius, graph, feat_a)
+        # apply one linear transformation per hop radius
+        hop2j_list = [
+            linear(x) for linear, x in zip(self.linear_radius, hop2j_list)
+        ]
+        radius_proj = sum(hop2j_list)
+
+        # TODO: add the fuse term (feat_b is unused until then)
+        # sum the terms together
+        result = prev_proj + deg_proj + radius_proj
+
+        # skip connection: keep half the channels linear, apply ReLU to the
+        # other half, then batch norm
+        n = self.out_feats // 2
+        result = paddle.concat([result[:, :n], F.relu(result[:, n:])], 1)
+        result = self.bn(result)
+        return result
+
+
+class LGNNLayer(nn.Layer):
+    def __init__(self, in_feats, out_feats, radius):
+        super(LGNNLayer, self).__init__()
+        self.g_layer = LGNNCore(in_feats, out_feats, radius)
+        self.lg_layer = LGNNCore(in_feats, out_feats, radius)
+
+    def forward(self, graph, line_graph, feat, lg_feat, deg_g, deg_lg):
+        next_feat = self.g_layer(graph, feat, lg_feat, deg_g)
+        next_lg_feat = self.lg_layer(line_graph, lg_feat, feat, deg_lg)
+        return next_feat, next_lg_feat
+
+
+class LGNN(nn.Layer):
+    def __init__(self, radius):
+        super(LGNN, self).__init__()
+        self.layer1 = LGNNLayer(1, 16, radius)  # input is a scalar feature
+        self.layer2 = LGNNLayer(16, 16, radius)  # hidden size is 16
+        self.layer3 = LGNNLayer(16, 16, radius)
+        self.linear = nn.Linear(16, 2)  # predict two communities
+
+    def forward(self, graph, line_graph):
+        # use the in-degree of each node as its input feature
+        deg_g = graph.indegree().astype("float32").unsqueeze(-1)
+        deg_lg = line_graph.indegree().astype("float32").unsqueeze(-1)
+        feat, lg_feat = deg_g, deg_lg
+        feat, lg_feat = self.layer1(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        feat, lg_feat = self.layer2(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        feat, lg_feat = self.layer3(graph, line_graph, feat, lg_feat, deg_g,
+                                    deg_lg)
+        return self.linear(feat)
+
+
+if __name__ == "__main__":
+    model = LGNN(radius=3)
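+    # Minimal smoke test, a sketch rather than part of training: build a
+    # toy 4-cycle and a hand-constructed line graph, then run one forward
+    # pass. The real line graphs ship precomputed with the CoraBinary data.
+    edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
+    graph = pgl.Graph(num_nodes=4, edges=edges)
+    # In L(G) the nodes are the edges of G; connect edge i to edge j when
+    # edge i's destination is edge j's source.
+    lg_edges = [(i, j) for i, (_, di) in enumerate(edges)
+                for j, (sj, _) in enumerate(edges) if i != j and di == sj]
+    line_graph = pgl.Graph(num_nodes=len(edges), edges=lg_edges)
+    graph.tensor()
+    line_graph.tensor()
+    print(model(graph, line_graph).shape)  # expect [4, 2]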
diff --git a/examples/lgnn/train.py b/examples/lgnn/train.py
new file mode 100644
index 00000000..6c4fc203
--- /dev/null
+++ b/examples/lgnn/train.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn.functional as F
+from pgl.utils.data import Dataloader
+
+from model import LGNN
+from cora_binary import CoraBinary
+
+
+def main():
+    train_set = CoraBinary()
+    training_loader = Dataloader(train_set, batch_size=1)
+
+    model = LGNN(radius=3)
+    optimizer = paddle.optimizer.Adam(
+        parameters=model.parameters(), learning_rate=4e-3)
+    for epoch in range(20):
+        all_loss = []
+        all_acc = []
+        for inputs in training_loader:
+            (p_g, p_lg, label) = inputs[0]
+            # Convert the graphs and labels to paddle tensors.
+            p_g.tensor()
+            p_lg.tensor()
+            label = paddle.to_tensor(label)
+            # Forward
+            z = model(p_g, p_lg)
+
+            # Calculate loss:
+            # with only two communities, the task is invariant to swapping
+            # the community labels, so take the minimum of the loss over
+            # both permutations.
+            loss_perm1 = F.cross_entropy(z, label)
+            loss_perm2 = F.cross_entropy(z, 1 - label)
+            loss = paddle.minimum(loss_perm1, loss_perm2)
+
+            # Calculate accuracy, again over both label permutations:
+            # predict community 0 when its logit is larger, else community 1.
+            pred = paddle.where(z[:, 0] > z[:, 1],
+                                paddle.zeros_like(z[:, 0]),
+                                paddle.ones_like(z[:, 0]))
+            acc_perm1 = (pred == label).astype("float32").mean()
+            acc_perm2 = (pred == 1 - label).astype("float32").mean()
+            acc = paddle.maximum(acc_perm1, acc_perm2)
+            all_loss.append(float(loss))
+            all_acc.append(float(acc))
+
+            optimizer.clear_grad()
+            loss.backward()
+            optimizer.step()
+        niters = len(all_loss)
+        print("Epoch %d | loss %.4f | accuracy %.4f" %
+              (epoch, sum(all_loss) / niters, sum(all_acc) / niters))
+
+
+if __name__ == "__main__":
+    main()