Extragradient and tests #21

Open: manuel-delverme wants to merge 22 commits into gehring:master from manuel-delverme:extragradient_test
Commits (22):

- 3c4d321 Basic implementation of extra-gradient (pierrelux)
- 157e2e3 Merge remote-tracking branch 'manuel-delverme/master' into extragradient (manuel-delverme)
- 6a10a40 Added tests, forwarding the cost function for debugging purposes (not… (manuel-delverme)
- a2c0a8e This commit includes only non functional changes, some of them might … (manuel-delverme)
- c8ae755 added more details to setup.py (manuel-delverme)
- fd79e66 added rprop extra gradient (manuel-delverme)
- e8accf6 rprop EG solves ~half of the constrained tasks (manuel-delverme)
- c2af669 some tests fail, im not sure why (manuel-delverme)
- 927c37e passing all the tests (manuel-delverme)
- 967f73e Merge remote-tracking branch 'manuel-delverme/HockSchittkowski_tests'… (manuel-delverme)
- bc12646 3/17 tests fail (manuel-delverme)
- bf6d58c non working extragradient.py cleanup (manuel-delverme)
- 0b3a821 2/20 tests fail (manuel-delverme)
- 139b9f5 some cleanup (manuel-delverme)
- fa21e13 removing basic tests, the HS test suite is broad enough; (manuel-delverme)
- da1baa3 removed extragradient_test.py, the tests are in constrained_test.py (manuel-delverme)
- 5350c21 Merge branch 'master' into extragradient_test (manuel-delverme)
- 4f3883b Addressing some reviews: (manuel-delverme)
- 7bc1c92 Merge remote-tracking branch 'manuel-delverme/extragradient_test' int… (manuel-delverme)
- 5d33f8e reverted setup.py changes and indentation levels (manuel-delverme)
- 498f2eb Merge branch 'master' into extragradient_test (manuel-delverme)
- 19ae12b Fixed jax.numpy import (manuel-delverme)
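For orientation before the diffs: the PR's core addition is an extragradient optimizer, in which gradients are evaluated at an extrapolated midpoint `x_bar` but the resulting step is applied at the original iterate `x0` (see the `update` function in the first file below). A minimal plain-gradient sketch of the scheme, separate from the PR's Adam variant; the objective, step sizes, and iteration count here are illustrative only:

```python
import jax
from jax import numpy as np

def extragradient_step(grad_x, grad_y, x, y, lr_x, lr_y):
    # Extrapolation: look ahead from (x, y) to the midpoint (x_bar, y_bar).
    x_bar = x + lr_x * grad_x(x, y)  # first player ascends
    y_bar = y - lr_y * grad_y(x, y)  # second player descends
    # Update: gradients are evaluated at the midpoint but applied at (x, y).
    x_new = x + lr_x * grad_x(x_bar, y_bar)
    y_new = y - lr_y * grad_y(x_bar, y_bar)
    return x_new, y_new

# On the bilinear saddle point f(x, y) = x * y, extragradient spirals in to
# the equilibrium (0, 0), whereas plain simultaneous ascent-descent diverges.
f = lambda x, y: x * y
gx, gy = jax.grad(f, argnums=0), jax.grad(f, argnums=1)
x, y = np.array(1.0), np.array(1.0)
for _ in range(1000):
    x, y = extragradient_step(gx, gy, x, y, 0.1, 0.1)
```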
@@ -0,0 +1,55 @@

```python
from typing import Callable, Tuple

import jax.experimental.optimizers
from jax import numpy as np, tree_util

import fax.competitive.sgd
from fax.jax_utils import add


def adam_extragradient_optimizer(step_size_x, step_size_y, b1=0.3, b2=0.2, eps=1e-8) -> Tuple[Callable, Callable, Callable]:
    """Construct an optimizer triple for extragradient with Adam-style steps.

    Args:
        step_size_x: positive scalar, or a callable representing a step size
            schedule that maps the iteration index to a positive scalar for
            the first player.
        step_size_y: positive scalar, or a callable representing a step size
            schedule that maps the iteration index to a positive scalar for
            the second player.
        b1: optional, a positive scalar value for beta_1, the exponential
            decay rate for the first moment estimates (default 0.3).
        b2: optional, a positive scalar value for beta_2, the exponential
            decay rate for the second moment estimates (default 0.2).
        eps: optional, a positive scalar value for epsilon, a small constant
            for numerical stability (default 1e-8).

    Returns:
        An (init_fun, update_fun, get_params) triple.
    """
    step_size_x = jax.experimental.optimizers.make_schedule(step_size_x)
    step_size_y = jax.experimental.optimizers.make_schedule(step_size_y)

    def init(initial_values):
        mean_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
        var_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
        return initial_values, (mean_avg, var_avg)

    def update(step, grad_fns, state):
        x0, optimizer_state = state
        # Negate the first step size so that the update performs gradient
        # ascent for the first player and descent for the second.
        step_sizes = -step_size_x(step), step_size_y(step)

        # Extrapolation step: take an Adam step from x0 to the midpoint x_bar.
        grads = grad_fns(*x0)
        deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)
        x_bar = add(x0, deltas)

        # Update step: the gradient is evaluated at x_bar ...
        grads = grad_fns(*x_bar)
        deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)
        # ... but the step is applied at x0.
        x1 = add(x0, deltas)

        return x1, optimizer_state

    def get_params(state):
        x, _optimizer_state = state
        return x

    return init, update, get_params
```
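A hypothetical driver loop for the returned triple. The toy objective, initial point, and iteration count are mine, and whether `fax.competitive.sgd.adam_step` (not shown in this diff) threads the moment state through a parameter pair exactly this way is an assumption:

```python
import jax
from jax import numpy as np

# Hypothetical usage sketch: grad_fns must return a gradient pair matching
# the (x, y) parameter pair unpacked inside update.
def f(x, y):
    return x * y  # first player maximizes over x, second minimizes over y

grad_fns = jax.grad(f, argnums=(0, 1))  # returns (df/dx, df/dy)

init, update, get_params = adam_extragradient_optimizer(step_size_x=0.1,
                                                        step_size_y=0.1)
state = init((np.array(1.0), np.array(1.0)))
for step in range(500):
    state = update(step, grad_fns, state)
x, y = get_params(state)  # both coordinates should approach the (0, 0) equilibrium
```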
@@ -0,0 +1,31 @@

```python
import functools

from jax import tree_util, lax, numpy as np

# Leaf-wise arithmetic over pytrees: each helper maps a lax primitive across
# two pytrees of identical structure (lax primitives do not broadcast, so
# corresponding leaves must also share shape and dtype).
division = functools.partial(tree_util.tree_multimap, lax.div)
add = functools.partial(tree_util.tree_multimap, lax.add)
sub = functools.partial(tree_util.tree_multimap, lax.sub)
mul = functools.partial(tree_util.tree_multimap, lax.mul)
square = functools.partial(tree_util.tree_map, lax.square)


def division_constant(constant):
    """Return a function that divides every leaf of a pytree by `constant`."""
    def divide(a):
        return tree_util.tree_multimap(lambda _a: _a / constant, a)

    return divide


def multiply_constant(constant):
    """Return a function that multiplies a pytree leaf-wise by `constant`."""
    return functools.partial(mul, constant)


def expand_like(a, b):
    """Broadcast the scalar `a` to an array with the shape and dtype of `b`."""
    return a * np.ones(b.shape, b.dtype)


def make_exp_smoothing(beta):
    """Build an exponential moving average: state <- beta * state + (1 - beta) * var."""
    def exp_smoothing(state, var):
        # Note: the Python `+` below assumes array arguments rather than
        # nested containers.
        return multiply_constant(beta)(state) + multiply_constant((1 - beta))(var)

    return exp_smoothing
```
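A small illustration of how the helpers above behave; the example pytrees and values are mine. Because the `lax`-based helpers do not broadcast, both trees must match in structure, shape, and dtype, which is presumably why `expand_like` exists:

```python
from jax import numpy as np

params = {"w": np.ones(3), "b": np.zeros(3)}
grads = {"w": np.full(3, 0.5), "b": np.ones(3)}

stepped = add(params, grads)   # leaf-wise: {"w": [1.5 1.5 1.5], "b": [1. 1. 1.]}
squared = square(grads)        # leaf-wise: {"w": [0.25 0.25 0.25], "b": [1. 1. 1.]}

# Lift a Python scalar to a concrete leaf before feeding it to the
# non-broadcasting lax-based helpers.
beta = expand_like(0.9, params["w"])          # array([0.9, 0.9, 0.9])
scaled = mul({"w": beta, "b": beta}, params)  # beta * params, leaf-wise
```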