diff --git a/fax/competitive/cga.py b/fax/competitive/cga.py
index eae5f1a..54ac294 100644
--- a/fax/competitive/cga.py
+++ b/fax/competitive/cga.py
@@ -2,14 +2,13 @@
 from functools import partial
 
 import jax
+import jax.numpy as np
 from jax import lax
 from jax import tree_util
-import jax.numpy as np
 from jax.experimental import optimizers
 
 from fax import converge
 from fax import loop
-from fax.competitive import cg
 
 
 CGAState = collections.namedtuple("CGAState", "x y delta_x delta_y")
diff --git a/fax/competitive/extragradient.py b/fax/competitive/extragradient.py
new file mode 100644
index 0000000..20f0f8a
--- /dev/null
+++ b/fax/competitive/extragradient.py
@@ -0,0 +1,57 @@
+from typing import Callable, Tuple
+
+import jax.experimental.optimizers
+from jax import numpy as np, tree_util
+
+import fax.competitive.sgd
+from fax.jax_utils import add
+
+
+def adam_extragradient_optimizer(step_size_x, step_size_y, b1=0.3, b2=0.2, eps=1e-8) -> Tuple[Callable, Callable, Callable]:
+    """Construct an optimizer triple for extragradient with Adam preconditioning.
+
+    Args:
+        step_size_x: positive scalar, or a callable representing a step size schedule
+            that maps the iteration index to a positive scalar, for the first player.
+        step_size_y: positive scalar, or a callable representing a step size schedule
+            that maps the iteration index to a positive scalar, for the second player.
+        b1: optional, a positive scalar value for beta_1, the exponential decay rate
+            for the first moment estimates (default 0.3).
+        b2: optional, a positive scalar value for beta_2, the exponential decay rate
+            for the second moment estimates (default 0.2).
+        eps: optional, a positive scalar value for epsilon, a small constant for
+            numerical stability (default 1e-8).
+
+    Returns:
+        An (init_fun, update_fun, get_params) triple.
+    """
+    step_size_x = jax.experimental.optimizers.make_schedule(step_size_x)
+    step_size_y = jax.experimental.optimizers.make_schedule(step_size_y)
+
+    def init(initial_values):
+        mean_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
+        var_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
+        return initial_values, (mean_avg, var_avg)
+
+    def update(step, grad_fns, state):
+        x0, optimizer_state = state
+        # Negate the first player's step size so that we do gradient ascent-descent.
+        step_sizes = -step_size_x(step), step_size_y(step)
+
+        # Extrapolation step: take an Adam step from x0...
+        grads = grad_fns(*x0)
+        deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)
+        x_bar = add(x0, deltas)
+
+        # ...then re-evaluate the gradient at x_bar but apply the update at x0.
+        grads = grad_fns(*x_bar)
+        deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)
+        x1 = add(x0, deltas)
+
+        return x1, optimizer_state
+
+    def get_params(state):
+        x, _optimizer_state = state
+        return x
+
+    return init, update, get_params
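A minimal usage sketch of the new (init, update, get_params) triple on the toy bilinear game f(x, y) = x * y. It assumes the fax.competitive.sgd.adam_step conventions implied above (step sizes scale the Adam update directly); which player ascends depends on that routine's sign convention.

    import jax
    import jax.numpy as np

    from fax.competitive import extragradient

    def f(x, y):
        # Bilinear saddle: the unique equilibrium is at the origin.
        return x * y

    # update() expects a function mapping (x, y) to the pair of gradients.
    grad_fns = jax.grad(f, argnums=(0, 1))

    init, update, get_params = extragradient.adam_extragradient_optimizer(
        step_size_x=1e-2, step_size_y=1e-2)

    state = init((np.array(1.0), np.array(-1.0)))
    for i in range(1000):
        state = update(i, grad_fns, state)

    x, y = get_params(state)  # both coordinates should shrink toward 0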
diff --git a/fax/constrained/constrained.py b/fax/constrained/constrained.py
index 85df37c..badef05 100644
--- a/fax/constrained/constrained.py
+++ b/fax/constrained/constrained.py
@@ -2,26 +2,24 @@
 """
 import collections
 
-from scipy.optimize import minimize
-
 import jax
-from jax import lax
-from jax import jit
+import jax.numpy as np
 from jax import grad
 from jax import jacrev
-import jax.numpy as np
+from jax import jit
+from jax import lax
 from jax import tree_util
 from jax.experimental import optimizers
 from jax.flatten_util import ravel_pytree
+from scipy.optimize import minimize
 
-from fax import math
 from fax import converge
+from fax import math
 from fax.competitive import cg
 from fax.competitive import cga
-from fax.loop import fixed_point_iteration
 from fax.implicit.twophase import make_adjoint_fixed_point_iteration
 from fax.implicit.twophase import make_forward_fixed_point_iteration
-
+from fax.loop import fixed_point_iteration
 
 ConstrainedSolution = collections.namedtuple(
     "ConstrainedSolution",
diff --git a/fax/constrained/constrained_test.py b/fax/constrained/constrained_test.py
index fa081a3..7ddc216 100644
--- a/fax/constrained/constrained_test.py
+++ b/fax/constrained/constrained_test.py
@@ -1,28 +1,33 @@
+import absl.testing
+import absl.testing.parameterized
 import hypothesis.extra.numpy
 import hypothesis.strategies
+import jax
+import jax.experimental.optimizers
+import jax.nn
 import jax.numpy as np
+import jax.scipy.special
 import jax.test_util
+import jax.tree_util
 import numpy as onp
 from absl.testing import absltest
-from absl.testing import parameterized
-from jax import random
-from jax import tree_util
-from jax.config import config
 from jax.experimental import optimizers
 from jax.experimental.stax import softmax
-from jax.scipy.special import logsumexp
 
-import fax.tests.hock_schittkowski_suite
+import fax
+import fax.test_util
 from fax import converge
 from fax import test_util
+from fax.competitive import extragradient
 from fax.constrained import cga_ecp
 from fax.constrained import cga_lagrange_min
 from fax.constrained import implicit_ecp
 from fax.constrained import make_lagrangian
 from fax.constrained import slsqp_ecp
 
-config.update("jax_enable_x64", True)
-benchmarks = list(fax.tests.hock_schittkowski_suite.load_HockSchittkowski_models())
+jax.config.update("jax_enable_x64", True)
+test_params = dict(rtol=1e-4, atol=1e-4, check_dtypes=False)
+convergence_params = dict(rtol=1e-5, atol=1e-5)
 
 
 class CGATest(jax.test_util.JaxTestCase):
@@ -34,8 +39,8 @@ def test_cga_lagrange_min(self):
 
         init_mult, lagrangian, get_x = make_lagrangian(func, eq_constraints)
 
-        rng = random.PRNGKey(8413)
-        init_params = random.uniform(rng, (n,))
+        rng = jax.random.PRNGKey(8413)
+        init_params = jax.random.uniform(rng, (n,))
         lagr_params = init_mult(init_params)
 
         lr = 0.5
@@ -48,7 +53,8 @@ def convergence_test(x_new, x_old):
         @jax.jit
         def step(i, opt_state):
             params = get_params(opt_state)
-            grads = jax.grad(lagrangian, (0, 1))(*params)
+            grad_fn = jax.grad(lagrangian, (0, 1))
+            grads = grad_fn(*params)
             return opt_update(i, grads, opt_state)
 
         opt_state = opt_init(lagr_params)
@@ -65,10 +71,10 @@ def step(i, opt_state):
                             check_dtypes=False)
 
         h = eq_constraints(get_x(final_params))
-        self.assertAllClose(h, tree_util.tree_map(np.zeros_like, h),
+        self.assertAllClose(h, jax.tree_util.tree_map(np.zeros_like, h),
                             check_dtypes=False)
 
-    @parameterized.parameters(
+    @absl.testing.parameterized.parameters(
         {'method': cga_ecp, 'kwargs': {'max_iter': 1000, 'lr_func': 0.5}},
         {'method': slsqp_ecp, 'kwargs': {'max_iter': 1000}},
     )
     @hypothesis.settings(max_examples=10, deadline=5000.)
@@ -86,8 +92,8 @@ def objective(x, y):
         def constraints(x, y):
             return 1 - np.linalg.norm(np.asarray([x, y]))
 
-        rng = random.PRNGKey(8413)
-        initial_values = random.uniform(rng, (len(v),))
+        rng = jax.random.PRNGKey(8413)
+        initial_values = jax.random.uniform(rng, (len(v),))
 
         solution = method(objective, constraints, initial_values, **kwargs)
@@ -96,7 +102,7 @@ def constraints(x, y):
                             objective(*solution.value),
                             check_dtypes=False)
 
-    @parameterized.parameters(
+    @absl.testing.parameterized.parameters(
         {'method': implicit_ecp, 'kwargs': {'max_iter': 1000, 'lr_func': 0.01, 'optimizer': optimizers.adam}},
         {'method': cga_ecp, 'kwargs': {'max_iter': 1000, 'lr_func': 0.15, 'lr_multipliers': 0.925}},
@@ -115,8 +121,7 @@ def test_omd(self, method, kwargs):
 
         def smooth_bellman_optimality_operator(x, params):
             transition, reward, discount, temperature = params
-            return reward + discount * np.einsum('ast,t->sa', transition, temperature *
-                                                 logsumexp((1. / temperature) * x, axis=1))
+            return reward + discount * np.einsum('ast,t->sa', transition, temperature * jax.scipy.special.logsumexp((1. / temperature) * x, axis=1))
 
         @jax.jit
         def objective(x, params):
@@ -143,5 +148,53 @@ def equality_constraints(x, params):
         self.assertAllClose(objective(*solution.value), optimal_value, check_dtypes=False)
 
 
+class EGTest(jax.test_util.JaxTestCase):
+    @absl.testing.parameterized.parameters(fax.test_util.load_HockSchittkowski_models())
+    def test_eg_HockSchittkowski(self, objective_function, equality_constraints, hs_optimal_value: np.ndarray, initial_value):
+        def convergence_test(x_new, x_old):
+            return fax.converge.max_diff_test(x_new, x_old, **convergence_params)
+
+        initialize_multipliers, lagrangian, get_x = make_lagrangian(objective_function, equality_constraints)
+
+        x0 = initial_value()
+        initial_values = initialize_multipliers(x0)
+
+        final_val, h, x, multipliers = self.eg_solve(lagrangian, convergence_test, equality_constraints, objective_function, get_x, initial_values)
+
+        import scipy.optimize
+        constraints = ({'type': 'eq', 'fun': equality_constraints},)
+
+        res = scipy.optimize.minimize(lambda *args: -objective_function(*args), initial_values[0], method='SLSQP', constraints=constraints)
+        scipy_optimal_value = -res.fun
+        scipy_constraint = equality_constraints(res.x)
+
+        self.assertAllClose(final_val, scipy_optimal_value, **test_params)
+        self.assertAllClose(h, scipy_constraint, **test_params)
+
+    def eg_solve(self, lagrangian, convergence_test, equality_constraints, objective_function, get_x, initial_values):
+        optimizer_init, optimizer_update, optimizer_get_params = extragradient.adam_extragradient_optimizer(
+            step_size_x=jax.experimental.optimizers.inverse_time_decay(1e-1, 50, 0.3, staircase=True),
+            step_size_y=5e-2,
+        )
+
+        @jax.jit
+        def update(i, opt_state):
+            grad_fn = jax.grad(lagrangian, (0, 1))
+            return optimizer_update(i, grad_fn, opt_state)
+
+        solution = fax.loop.fixed_point_iteration(
+            init_x=optimizer_init(initial_values),
+            func=update,
+            convergence_test=convergence_test,
+            max_iter=100000000,
+            get_params=optimizer_get_params,
+            f=lagrangian,
+        )
+        x, multipliers = get_x(solution)
+        final_val = objective_function(x)
+        h = equality_constraints(x)
+        return final_val, h, x, multipliers
+
+
 if __name__ == "__main__":
     absltest.main()
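The new EGTest drives the Lagrangian's saddle point with the extragradient rule and cross-checks the result against SLSQP. Stripped of the Adam preconditioning, the two-evaluation structure it relies on reduces to this sketch (signs assume the first player descends and the second ascends):

    def extragradient_step(grad_fn, x, y, eta_x, eta_y):
        # Extrapolation: evaluate the gradient field at the current iterate...
        gx, gy = grad_fn(x, y)
        x_bar, y_bar = x - eta_x * gx, y + eta_y * gy
        # ...then re-evaluate at the extrapolated point, but apply at (x, y).
        gx, gy = grad_fn(x_bar, y_bar)
        return x - eta_x * gx, y + eta_y * gy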
diff --git a/fax/jax_utils.py b/fax/jax_utils.py
new file mode 100644
index 0000000..bcbb470
--- /dev/null
+++ b/fax/jax_utils.py
@@ -0,0 +1,32 @@
+import functools
+
+from jax import tree_util, lax, numpy as np
+
+# Leafwise arithmetic over pytrees with matching structure.
+division = functools.partial(tree_util.tree_multimap, lax.div)
+add = functools.partial(tree_util.tree_multimap, lax.add)
+sub = functools.partial(tree_util.tree_multimap, lax.sub)
+mul = functools.partial(tree_util.tree_multimap, lax.mul)
+square = functools.partial(tree_util.tree_map, lax.square)
+
+
+def division_constant(constant):
+    def divide(a):
+        return tree_util.tree_multimap(lambda _a: _a / constant, a)
+
+    return divide
+
+
+def multiply_constant(constant):
+    return functools.partial(mul, constant)
+
+
+def expand_like(a, b):
+    return a * np.ones(b.shape, b.dtype)
+
+
+def make_exp_smoothing(beta):
+    def exp_smoothing(state, var):
+        return multiply_constant(beta)(state) + multiply_constant((1 - beta))(var)
+
+    return exp_smoothing
diff --git a/fax/loop.py b/fax/loop.py
index bf575eb..7e52cc0 100644
--- a/fax/loop.py
+++ b/fax/loop.py
@@ -2,6 +2,7 @@
 import warnings
 
 import jax
+import jax.lax
 import jax.numpy as np
 
 FixedPointSolution = collections.namedtuple(
@@ -28,8 +29,7 @@ def unrolled(i, init_x, func, num_iter, return_last_two=False):
     x_old = None
 
     for _ in range(num_iter):
-        x_old = x
-        x = func(i, x_old)
+        x, x_old = func(i, x), x
         i = i + 1
 
     if return_last_two:
@@ -38,8 +38,7 @@ def unrolled(i, init_x, func, num_iter, return_last_two=False):
     return i, x
 
 
-def fixed_point_iteration(init_x, func, convergence_test, max_iter,
-                          batched_iter_size=1, unroll=False):
+def fixed_point_iteration(init_x, func, convergence_test, max_iter, batched_iter_size=1, unroll=False, get_params=lambda x: x, f=None) -> FixedPointSolution:
     """Find a fixed point of `func` by repeatedly applying `func`.
 
     Use this function to find a fixed point of `func` by repeatedly applying
@@ -104,6 +103,7 @@ def fixed_point_iteration(init_x, func, convergence_test, max_iter,
 
     def cond(args):
         i, x_new, x_old = args
+        x_new, x_old = get_params(x_new), get_params(x_old)
        converged = convergence_test(x_new, x_old)
 
         if max_iter is not None:
@@ -136,13 +136,13 @@ def scan_step(args, idx):
             xs=np.arange(max_batched_iter - 1),
         )
         converged = convergence_test(sol, prev_sol)
-
     else:
         iterations, sol, prev_sol = jax.lax.while_loop(
             cond,
             body,
             init_vals,
         )
+        sol, prev_sol = get_params(sol), get_params(prev_sol)
         converged = max_iter is None or iterations < max_iter
 
     return FixedPointSolution(
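A small sketch of how the new fax.jax_utils helpers compose; both pytrees must share structure and leaf shapes, since lax primitives do not broadcast (which is what expand_like is for):

    import jax.numpy as np

    from fax.jax_utils import add, square, expand_like

    a = {'w': np.ones(3), 'b': np.zeros(2)}
    b = {'w': np.full(3, 2.0), 'b': np.ones(2)}

    c = add(a, b)  # leafwise lax.add over two pytrees of matching structure
    d = square(b)  # leafwise lax.square over a single pytree

    # lax.mul and friends reject mismatched shapes, so scalars have to be
    # expanded to a leaf's shape first:
    half = expand_like(0.5, np.ones((2, 2)))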
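The get_params hook added to fixed_point_iteration above lets the loop iterate full optimizer states while convergence is tested, and the solution reported, in parameter space. A sketch with hypothetical placeholders x0, opt_state0 and update:

    from fax import converge, loop

    def convergence_test(x_new, x_old):
        return converge.max_diff_test(x_new, x_old, rtol=1e-5, atol=1e-5)

    # The loop state is (params, optimizer_state); only the params enter
    # the convergence test and the reported solution.
    sol = loop.fixed_point_iteration(
        init_x=(x0, opt_state0),
        func=update,                 # maps (i, state) -> state
        convergence_test=convergence_test,
        max_iter=10000,
        get_params=lambda state: state[0],
    )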
diff --git a/fax/loop_test.py b/fax/loop_test.py
index 1b83c92..af22e4e 100644
--- a/fax/loop_test.py
+++ b/fax/loop_test.py
@@ -1,17 +1,16 @@
+import jax
+import jax.numpy as np
+import jax.test_util
+import numpy as onp
 from absl.testing import absltest
 from absl.testing import parameterized
-
-import numpy as onp
+from jax.config import config
 from numpy import testing
 
 from fax import converge
 from fax import loop
 from fax import test_util
 
-import jax
-import jax.numpy as np
-import jax.test_util
-from jax.config import config
 config.update("jax_enable_x64", True)
@@ -80,13 +79,13 @@ def step(i, x_old):
             return x_old + 1
 
         sol = loop.fixed_point_iteration(
-                init_x=init_x,
-                func=step,
-                convergence_test=convergence_test,
-                max_iter=max_steps,
-                batched_iter_size=1,
-                unroll=unroll,
-        )
+            init_x=init_x,
+            func=step,
+            convergence_test=convergence_test,
+            max_iter=max_steps,
+            batched_iter_size=1,
+            unroll=unroll,
+        )
 
         self.assertFalse(sol.converged)
         self.assertEqual(sol.iterations, max_steps)
@@ -233,7 +232,7 @@ def testUnrollGrad(self, jit):
         def step(i, x):
             del i
-            return x*0.1
+            return x * 0.1
 
         def converge_test(x_new, x_old):
             return np.max(x_new - x_old) < 1e-3
@@ -357,27 +356,27 @@ def _fixedpoint_iteration_solver(unroll,
                                  default_atol=1e-10,
                                  default_max_iter=200,
                                  default_batched_iter_size=1):
+    def fixed_point_iteration_solver(init_x, params):
+        rtol, atol = converge.adjust_tol_for_dtype(default_rtol,
+                                                   default_atol,
+                                                   init_x.dtype)
 
-    def fixed_point_iteration_solver(init_x, params):
-        rtol, atol = converge.adjust_tol_for_dtype(default_rtol,
-                                                   default_atol,
-                                                   init_x.dtype)
+        def convergence_test(x_new, x_old):
+            return converge.max_diff_test(x_new, x_old, rtol, atol)
 
-        def convergence_test(x_new, x_old):
-            return converge.max_diff_test(x_new, x_old, rtol, atol)
+        func = param_func(params)
+        sol = loop.fixed_point_iteration(
+            init_x=init_x,
+            func=func,
+            convergence_test=convergence_test,
+            max_iter=default_max_iter,
+            batched_iter_size=default_batched_iter_size,
+            unroll=unroll,
+        )
 
-        func = param_func(params)
-        sol = loop.fixed_point_iteration(
-            init_x=init_x,
-            func=func,
-            convergence_test=convergence_test,
-            max_iter=default_max_iter,
-            batched_iter_size=default_batched_iter_size,
-            unroll=unroll,
-        )
+        return sol
 
-        return sol
-    return fixed_point_iteration_solver
+    return fixed_point_iteration_solver
 
 
 class UnrolledFixedPointIterationTest(test_util.FixedPointTestCase):
diff --git a/fax/test_util.py b/fax/test_util.py
index 21caa74..069729d 100644
--- a/fax/test_util.py
+++ b/fax/test_util.py
@@ -140,7 +140,6 @@ def testGradient(self):
         solver = self.make_solver(param_ax_plus_b)
 
         def loss(x, params): return np.sum(solver(x, params).value)
-
        jax.test_util.check_grads(
             loss,
             (x0, (matrix, offset),),
@@ -177,6 +176,9 @@ def equality_constraints(params):
 
 def dot_product_minimization(v):
     """Problem: find a u such that np.dot(u, v) is maximum, subject to np.linalg.norm(u) = 1.
+
+    Args:
+        v (array): the fixed vector against which the dot product is maximized.
     """
 
     def func(u):
@@ -191,3 +193,225 @@ def equality_constraints(u):
     return func, equality_constraints, optimal_solution, optimal_value
 
 
+# The APM parsing helpers below additionally rely on these stdlib modules,
+# hoisted here so the added block is self-contained. APM_TESTS (the URL of
+# the Hock-Schittkowski APM archive) is assumed to be defined at module level.
+import collections
+import itertools
+import os
+import tempfile
+import urllib.request
+import zipfile
+from typing import Text, Union
+
+
+def get_list(rows):
+    param_list = []
+    skipped = []
+    for row in rows:
+        row_text = row.lstrip()
+        if not row_text:
+            continue
+
+        if row_text.startswith("!"):
+            skipped.append(row_text)
+            continue
+
+        # NOTE: intersecting a set of strings with a string compares single
+        # characters, so this effectively tests for "<" or ">".
+        if {">=", "<=", "<", ">"}.intersection(row_text):
+            raise NotImplementedError("no inequalities")
+
+        if row_text[0].isupper() and row.replace(" ", "").isalpha():
+            assert row_text.startswith("End")
+            return param_list, skipped
+        else:
+            param_list.append(row_text)
+    raise ValueError
+
+
+def get_struct(rows):
+    struct = {}
+    skipped = []
+    for row in rows:
+        if not row:
+            continue
+
+        if row[0] == '!' or row[0] == '#':
+            skipped.append(row)
+            continue
+
+        row_text = row.lstrip()
+        if row_text == "End Model":
+            continue
+
+        if row_text[0].isupper():
+            struct[row_text], skipped_ = get_struct(rows)
+            skipped.extend(skipped_)
+        else:
+            params, skipped_ = get_list(itertools.chain([row], rows))
+            skipped.extend(skipped_)
+            return params, skipped
+    return struct, skipped
+
+
+def text_to_code(variable, equation, closure):
+    cost_function = equation.replace("^", "**")
+    # Convert APM's 1-based indexing, e.g. x[3], to Python's 0-based x[2].
+    seq = []
+    for a in cost_function.split("]"):
+        if "[" not in a:
+            seq.append(a)
+        else:
+            rest, num = a.split("[")
+            b = f"{rest}[{int(num) - 1}"
+            seq.append(b)
+    cost_function = "]".join(seq)
+    scope = ", ".join(k for k in closure.keys() if not k.startswith("__"))
+    cost_function_ = f"{variable} = lambda {scope}: {cost_function}"
+    return cost_function_
+
+
+def apm_to_python(text: Text) -> Union[Text, None]:
+    """Convert the contents of an APM model file to Python source code.
+
+    Args:
+        text: contents of the APM file.
+    """
+    if "Intermediates" in text:
+        raise NotImplementedError("Not implemented yet, maybe never.")
+    if "does not exist" in text:
+        raise NotImplementedError("I'm not sure how to handle those.")
+    rows = iter(text.splitlines())
+
+    struct, skipped = get_struct(rows)
+
+    if len(struct) != 1:
+        raise NotImplementedError(f"Found {len(struct)} models in a file, only one is supported.")
+    (model_name, model_struct), = struct.items()
+
+    python_code = f"class {model_name.split('Model ')[1].title()}(Hs):\n"
+
+    var_sizes, python_code = _parse_initialization(model_struct, python_code)
+    python_code = _parse_equations(model_struct, python_code, var_sizes)
+
+    skipped, python_code = _parse_optimal_solution(python_code, skipped)
+    python_code = _parse_constraints(model_struct, python_code)
+
+    python_code = python_code.replace("\t", "    ")
+    return python_code
+
+
+def _parse_constraints(model_struct, python_code):
+    constraints = []
+    for equation in model_struct['Equations']:
+        lhs, rhs = equation.split("=")
+        if lhs.strip() != 'obj':
+            if not set(rhs.strip()).difference({'0', '.', ','}):
+                lhs = f"{lhs} - {rhs}"
+
+            constraint_variable = f"h{len(constraints)}"
+
+            cost_function = text_to_code(constraint_variable, lhs, {'x': None})
+            python_code += f"\t{cost_function}\n"
+            constraints.append(constraint_variable)
+
+    if constraints:
+        python_code += f"""
+\t@classmethod
+\tdef constraints(cls, x):
+\t\treturn stack((cls.{'(x), cls.'.join(constraints)}(x), ))
+"""
+    return python_code
+
+
+def _parse_optimal_solution(python_code, skipped):
+    for idx, comment in enumerate(skipped):
+        if "! best known objective =" in comment:
+            _, optimal_solution = comment.split("=")
+
+            python_code += f"\toptimal_solution = -array({optimal_solution.strip()})\n"
+            break
+    else:
+        raise ValueError("No solution found")
+    del skipped[idx]
+    return skipped, python_code
+
+
+def _parse_equations(model_struct, python_code, var_sizes):
+    for obj in model_struct["Equations"]:
+        variable, equation = (o.strip() for o in obj.split("="))
+
+        if "obj" in variable:
+            # By default we maximize here.
+            equation = "-(" + equation + ")"
+
+        cost_function = text_to_code(variable, equation, var_sizes)
+        cost_function = cost_function.replace("obj =", "objective_function =")
+        python_code += f"\t{cost_function}\n"
+    return python_code
+
+
+def _parse_initialization(model_struct, python_code):
+    var_sizes = collections.defaultdict(int)
+    for obj in model_struct["Variables"]:
+        if obj == "obj":
+            continue
+
+        variable, value = obj.split("=")
+        var, size = variable.split("[")
+        size, _ = size.split("]")
+        if ":" in size:
+            size = max(int(s) for s in size.split(":"))
+        else:
+            size = int(size)
+
+        var_sizes[var] = max(var_sizes[var], size)
+
+    if var_sizes:
+        python_code += "\t@staticmethod\n"
+        python_code += "\tdef initialize():\n"
+        if len(var_sizes) != 1:
+            raise NotImplementedError("There should only be one (multidimensional) state variable")
+
+        (k, v), = var_sizes.items()
+        python_code += f"\t\treturn zeros({v})  # {k}\n\n\n"
+
+    return var_sizes, python_code
+
+
+def maybe_download_tests(work_directory):
+    if not os.path.exists(work_directory):
+        os.mkdir(work_directory)
+    filepath = os.path.join(work_directory, "hs.zip")
+    if not os.path.exists(filepath):
+        filepath, _ = urllib.request.urlretrieve(APM_TESTS, filepath)
+        print('Downloaded test file to', work_directory)
+    return filepath
+
+
+def parse_HockSchittkowski_models(test_folder):  # noqa
+    zip_file_path = maybe_download_tests(tempfile.gettempdir())
+    if not os.path.exists(test_folder):
+        os.makedirs(test_folder, exist_ok=True)
+
+    with open(os.path.join(test_folder, "HockSchittkowski.py"), "w") as test_definitions:
+        test_definitions.write("from jax.numpy import *\n\n\n")
+        test_definitions.write("class Hs:\n")
+        test_definitions.write("    constraints = lambda *args: 0.\n\n\n")
+
+        with zipfile.ZipFile(zip_file_path) as test_archive:
+            for test_case_path in test_archive.filelist:
+                try:
+                    with test_archive.open(test_case_path) as test_case:
+                        python_code = apm_to_python(test_case.read().decode('utf-8'))
+                except NotImplementedError:
+                    continue
+                else:
+                    test_definitions.write(python_code + "\n\n\n")
+
+
+def load_HockSchittkowski_models():  # noqa
+    import fax.tests.hock_schittkowski_suite
+    for model in fax.tests.hock_schittkowski_suite.load_suite():
+        yield model.objective_function, model.constraints, model.optimal_solution, model.initialize
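Each generated model then surfaces as the 4-tuple EGTest parameterizes over. A consumption sketch, assuming the suite module under fax/tests has already been generated via parse_HockSchittkowski_models:

    import fax.test_util

    # Each model is exposed as (objective, constraints, optimal_solution, initialize).
    for objective, constraints, optimal_solution, initialize in fax.test_util.load_HockSchittkowski_models():
        x0 = initialize()  # zeros with the model's state shape
        print(objective(x0), constraints(x0))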