-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmpi_consensus_admm.py
More file actions
80 lines (66 loc) · 2.05 KB
/
mpi_consensus_admm.py
File metadata and controls
80 lines (66 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""
Copyright 2013 Steven Diamond
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
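# Consensus ADMM
# See section 7 (Consensus and Sharing) of Boyd et al., "Distributed Optimization
# and Statistical Learning via the Alternating Direction Method of Multipliers":
# xi = argmin [f(xi) + (rho/2)||xi - xbar + ui||_2^2]
# ui = ui + xi - xbar
# where xbar is the average of the local iterates xi.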
import operator as op
from functools import reduce
from multiprocessing import Pool
import dill
import numpy as np
from numpy.random import randn
from cvxpy import Minimize, Parameter, Problem, Variable, norm, sum_squares
# Problem dimensions and ADMM settings.
n, m = 1000, 500  # n: length of the shared variable x; m: rows of each random data matrix
rho = 1.0  # weight of the (rho/2)*||x - xbar + u||^2 augmented term
xbar = 0  # starting consensus value; replaced by an averaged array after each iteration
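# TODO: this sends every serialized update closure (with its data) to the
# workers on each iteration; distribute the problem objects once instead,
# e.g. with MPI.
# Create function to perform local update from penalty f.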
def create_update(f):
    """Build the local consensus-ADMM update for one penalty term ``f``.

    Each call ``local_update(xbar)`` first takes the dual step
    ``u <- u + x - xbar`` (or sets ``u`` to zeros on the first call, when
    ``x`` has no value yet). It then solves
    ``minimize f(x) + (rho/2) * sum_squares(x - xbar + u)`` and returns the
    new ``x.value``.

    The iterate ``x`` and the dual variable ``u`` live in this closure. The
    update is therefore only correct if the same closure, or a copy carrying
    their current values, is called on every iteration.

    Args:
        f: callable that maps a cvxpy ``Variable`` of length ``n`` to a convex
            scalar expression. It is evaluated exactly once, here.

    Returns:
        The ``local_update(xbar)`` function described above.

    Raises (from ``local_update``):
        RuntimeError: if the solve leaves ``x`` without a value.
    """
    x = Variable(n)
    u = Parameter(n)
    # Evaluate the penalty once. If f builds data on each call (for example
    # with randn), calling it on every update would change the problem being
    # solved from one iteration to the next.
    penalty = f(x)

    def local_update(xbar):
        # Dual update; there is no previous iterate on the first call.
        if x.value is None:
            u.value = np.zeros(n)
        else:
            u.value += x.value - xbar
        # Primal update.
        obj = penalty + (rho/2)*sum_squares(x - xbar + u)
        Problem(Minimize(obj)).solve()
        if x.value is None:
            # Fail here rather than handing None to the averaging step.
            raise RuntimeError("local ADMM update failed: solver returned no value for x")
        return x.value

    return local_update
# Penalty functions.
def _residual_norm(A, b):
    """Return the penalty ``x -> norm(A*x + b, 2)`` with ``A`` and ``b`` held fixed."""
    return lambda x: norm(A*x + b, 2)


# Draw each data set once, outside the penalty, so that a penalty is the same
# function every time it is evaluated. Build a list rather than a map so the
# result has a len() and can be iterated repeatedly; a Python 3 map object
# supports neither.
penalties = [_residual_norm(randn(m, n), randn(m)) for _ in range(4)]
penalties.append(lambda x: norm(x, 1))
functions = [dill.dumps(create_update(f)) for f in penalties]
# Do ADMM iterations in parallel.
def _run_update(args):
    """Run one serialized local update and return it together with its result.

    Args:
        args: ``(pickled_update, xbar)``, where ``pickled_update`` is a
            ``dill``-serialized update closure.

    Returns:
        ``(x, repickled_update)``: the new local iterate, plus the closure
        serialized again *after* the call. Feeding the second element back in
        on the next iteration keeps the state the closure mutated. Re-sending
        the original bytes would reset that state.
    """
    pickled_update, xbar_value = args
    update = dill.loads(pickled_update)
    x_new = update(xbar_value)
    return x_new, dill.dumps(update)


def apply_f(args):
    """Deserialize the update in ``args[0]``, call it on ``args[1]``, and return the result.

    Any state the update mutates is lost with the deserialized copy.
    """
    f = dill.loads(args[0])
    return f(args[1])


if __name__ == "__main__":
    # Keep pool creation out of import time: with the "spawn" start method,
    # every worker re-imports this module, and an unguarded Pool() recurses.
    updates = list(functions)  # `functions` may be a one-shot map iterator
    pool = Pool(processes=len(updates))
    try:
        for i in range(10):
            results = pool.map(_run_update, [(upd, xbar) for upd in updates])
            # Carry each worker's updated closure (its last x and its u)
            # into the next iteration.
            updates = [upd for _, upd in results]
            total = reduce(op.add, [x_i for x_i, _ in results])
            xbar = total/len(updates)
            print(i)
    finally:
        pool.terminate()
        pool.join()