diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 34de2f7..34b51c5 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,6 +1,9 @@ name: Publish Python 🐍 distribution 📦 to PyPI and TestPyPI +on: + push: + branches: + - master # Set this to your default branch -on: push jobs: build: diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 3f6852b..f21e816 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -1,10 +1,7 @@ name: Python Tests on: - push: - branches: [ master ] pull_request: - branches: [ master ] jobs: test: diff --git a/docs/usage/faq.md b/docs/usage/faq.md index ae69cd5..d4a2d61 100644 --- a/docs/usage/faq.md +++ b/docs/usage/faq.md @@ -10,11 +10,11 @@ delays = torch.tensor([1.0, 2.0]) history_function = lambda t : ... ts = ... -def simple_dde(t, y, args, *, history): +def simple_dde(t, y, func_args, *, history): # this correspond to y'(t) = -y(t-1) - y(t-2) return - history[0] - history[1] -ys = torchdde.integrate(f, solver, ts[0], ts[-1], ts, history_function, args=None, dt0=ts[1]-ts[0], delays=delays) +ys = torchdde.integrate(f, solver, ts[0], ts[-1], ts, history_function, func_args=None, dt0=ts[1]-ts[0], delays=delays) ``` ## How about if I want a neural network to have also several delays ? diff --git a/docs/usage/getting-started.md b/docs/usage/getting-started.md index c519e70..6e691bb 100644 --- a/docs/usage/getting-started.md +++ b/docs/usage/getting-started.md @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from torchdde import RK4 import torch -def simple_dde(t, y, args, *, history): +def simple_dde(t, y, func_args, *, history): # `history` corresponds to the list of # delayed states defined in your DDE # i.e. 
here history=[y(t-2)] @@ -54,7 +54,7 @@ import matplotlib.pyplot as plt from torchdde import RK4 import torch -def simple_ode(t, y, args): +def simple_ode(t, y, func_args): return -y**2 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/docs/usage/integration-de.md b/docs/usage/integration-de.md index e704d42..8a2dfa8 100644 --- a/docs/usage/integration-de.md +++ b/docs/usage/integration-de.md @@ -13,10 +13,10 @@ What essentially differentiates DDEs with ODEs are : In practice, your function will be defined like this : ```python - def f_ode(t,y,args): + def f_ode(t,y,func_args): return ... - def f_dde(t,y,args, history): + def f_dde(t,y,func_args, history): return ... ``` diff --git a/docs/usage/neural-dde.md b/docs/usage/neural-dde.md index 21aa635..9d69391 100644 --- a/docs/usage/neural-dde.md +++ b/docs/usage/neural-dde.md @@ -43,7 +43,7 @@ class NDDE(nn.Module): hidden_channels=depth * [width_size] + [out_size], ) - def forward(self, t, z, args, *, history): + def forward(self, t, z, func_args, *, history): # `history` corresponds to the list of # delayed states defined in your DDE # i.e. here history=[y(t-tau1), ..., y(t-taun)] @@ -54,11 +54,11 @@ We generate the toy dataset of the [delayed logistic equation](https://www.math. 
```python def get_data(y0, ts, tau=torch.tensor([1.0])): - def f(t, y, args, history): + def f(t, y, func_args, history): return y * (1 - history[0]) history_function = lambda t: torch.unsqueeze(y0, dim=1) - ys = integrate(f, Euler(), ts[0], ts[-1], ts, history_function, args=None, dt0=ts[1]-ts[0], delays=tau) + ys = integrate(f, Euler(), ts[0], ts[-1], ts, history_function, func_args=None, dt0=ts[1]-ts[0], delays=tau) return ys @@ -136,7 +136,7 @@ def main( plt.plot(ts.cpu(), data[0].cpu(), c="dodgerblue", label="Real") history_values = data[0, 0][..., None] history_fn = lambda t: history_values - ys_pred = integrate(model, Euler(), ts[0], ts[-1], ts, history_fn, args=None, dt0=ts[1]-ts[0], delays=tau) + ys_pred = integrate(model, Euler(), ts[0], ts[-1], ts, history_fn, func_args=None, dt0=ts[1]-ts[0], delays=tau) plt.plot( ts.cpu(), ys_pred[0].cpu().detach(), diff --git a/docs/usage/training-de.md b/docs/usage/training-de.md index 1e9724c..b022c08 100644 --- a/docs/usage/training-de.md +++ b/docs/usage/training-de.md @@ -22,7 +22,7 @@ for step, data in enumerate(train_loader): t1=ts[-1], ts=ts, y0=..., - args=None, + func_args=..., dt0=ts[1] - ts[0], delays=..., ) diff --git a/pyproject.toml b/pyproject.toml index 168fd62..603aa0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "torchdde" -version = "0.1.2" +version = "0.2.0" description = "DDE numerical solvers in Python." 
readme = "README.md" requires-python =">=3.9" diff --git a/torchdde/__init__.py b/torchdde/__init__.py index 9fabdc5..a1f303e 100644 --- a/torchdde/__init__.py +++ b/torchdde/__init__.py @@ -11,6 +11,7 @@ FourthOrderPolynomialInterpolation as FourthOrderPolynomialInterpolation, ThirdOrderPolynomialInterpolation as ThirdOrderPolynomialInterpolation, ) +from .misc import TupleTensorTransformer as TupleTensorTransformer from .solver import ( AbstractOdeSolver as AbstractOdeSolver, Bosh3 as Bosh3, diff --git a/torchdde/adjoint_dde.py b/torchdde/adjoint_dde.py index 53def7c..a84ea1e 100644 --- a/torchdde/adjoint_dde.py +++ b/torchdde/adjoint_dde.py @@ -6,6 +6,7 @@ from torchdde.global_interpolation.linear_interpolation import TorchLinearInterpolator from torchdde.integrate import _integrate_dde, _integrate_ode +from torchdde.misc import TupleTensorTransformer from torchdde.solver.base import AbstractOdeSolver from torchdde.step_size_controller.base import AbstractStepSizeController from torchdde.step_size_controller.constant import ConstantStepSizeController @@ -22,7 +23,7 @@ def forward( # type: ignore history_func: Callable[ [Float[torch.Tensor, ""]], Float[torch.Tensor, "batch ..."] ], - args: Any, + func_args: Any, solver: AbstractOdeSolver, stepsize_controller: AbstractStepSizeController = ConstantStepSizeController(), dt0: Optional[Float[torch.Tensor, ""]] = None, @@ -34,7 +35,7 @@ def forward( # type: ignore ctx.stepsize_controller = stepsize_controller ctx.solver = solver ctx.func = func - ctx.args = args + ctx.func_args = func_args ctx.ts = ts ctx.t0 = t0 ctx.t1 = t1 @@ -49,8 +50,8 @@ def forward( # type: ignore ts, history_func(t0), history_func, - args, - func.delays, + func_args, + func.delays, # type: ignore solver, stepsize_controller, dt0=dt0, @@ -69,8 +70,9 @@ def backward(ctx, *grad_y) -> Any: # type: ignore # as learnable parameter alongside with the neural network. 
grad_output = grad_y[0] - args = ctx.args - dt = ctx.ts[1] - ctx.ts[0] + func_args = ctx.func_args + ts = ctx.ts + dt = ts[1] - ts[0] solver = ctx.solver stepsize_controller = ctx.stepsize_controller params = ctx.saved_tensors @@ -113,35 +115,71 @@ def backward(ctx, *grad_y) -> Any: # type: ignore torch.concat([adjoint_ys_final, adjoint_ys_final], dim=1), ) - def adjoint_dyn(t, adjoint_y, args): - h_t = torch.autograd.Variable( + # augment_state = [adjoint_state, params_incr] + aug_state = [torch.zeros_like(adjoint_state)] + aug_state.extend([torch.zeros_like(param) for param in params]) + transformer = TupleTensorTransformer.from_tuple(aug_state) + + def augment_dyn(t, aug_state, func_args): + adjoint_y, *params_inc = transformer.unflatten(aug_state) + y_t = torch.autograd.Variable( state_interpolator(t) if t > ctx.t0 else ctx.history_func(t), requires_grad=True, ) h_t_minus_tau = [ ( - state_interpolator(t - tau) + torch.autograd.Variable( + state_interpolator(t - tau), requires_grad=True + ) if t - tau > ctx.t0 - else ctx.history_func(t - tau) + else torch.autograd.Variable( + ctx.history_func(t - tau), requires_grad=True + ) ) for tau in ctx.func.delays ] - out = ctx.func(t, h_t, args, history=h_t_minus_tau) - # This correspond to the term adjoint(t) df(t, y(t), y(t-tau))_dy(t) - rhs_adjoint_1 = torch.autograd.grad( - out, - h_t, + func_t = ctx.func(t, y_t, func_args, history=h_t_minus_tau) + # This correspond to the both terms : + # \lambda_t \partial{f_\theta(t)}{g} in adjoint dynamics + # \lambda_t \partial{f_\theta(t)}{\theta} and in loss equation + rhs_adjoint_1, *params_inc = torch.autograd.grad( + func_t, + (y_t,) + params, -adjoint_y, retain_graph=True, allow_unused=True, - )[0] + ) - # we need to add the second term of rhs too in rhs_adjoint computation + # we need to add the second term of rhs in rhs_adjoint computation delay_derivative_inc = torch.zeros_like(ctx.func.delays)[..., None] for idx, tau_i in enumerate(ctx.func.delays): + # This is 
computing part of the contribution of the gradient's + # loss w.r.t the parameters + # \lambda(t) \partial{f_\theta(t)}{g} + # where g(t) = y(t-tau_i) + params_inc2 = torch.autograd.grad( + func_t, + h_t_minus_tau[idx], + -adjoint_y, + retain_graph=True, + allow_unused=True, + )[0] + + if params_inc2 is None: + pass + else: + delay_derivative_inc[idx] += torch.sum( + params_inc2 * grad_ys[:, -1 - i], + dim=(tuple(range(len(params_inc2.shape)))), + ) + + # if t+ tau_i > T then \lambda(t+tau_i) = 0 + # computing second term of the adjoint dynamics + # \lambda_t \partial{f_\theta(t)}{g} if t < ctx.t1 - tau_i: + # NOTE(review): t + tau_i <= t1 here, so \lambda(t+tau_i) is defined adjoint_t_plus_tau = adjoint_interpolator(t + tau_i) - h_t_plus_tau = state_interpolator(t + tau_i) + y_t_plus_tau = state_interpolator(t + tau_i) history = [ ( state_interpolator(t + tau_i - tau_j) if t + tau_i - tau_j >= ctx.t0 else ctx.history_func(t + tau_i - tau_j) ) for tau_j in ctx.func.delays ] - history[idx] = h_t - out_other = ctx.func(t + tau_i, h_t_plus_tau, args, history=history) + history[idx] = y_t + func_t_plus_tau_i = ctx.func( + t + tau_i, y_t_plus_tau, func_args, history=history + ) # This correspond to the term - # adjoint(t+tau) df(t+tau, y(t+tau), y(t))_dy(t) + # \lambda(t+tau_i) \partial{f_\theta(t+tau_i)}{y_i} + # where y_i(t) = g(t-\tau_i) rhs_adjoint_2 = torch.autograd.grad( - out_other, h_t, -adjoint_t_plus_tau + func_t_plus_tau_i, y_t, -adjoint_t_plus_tau )[0] rhs_adjoint_1 += rhs_adjoint_2 - # contribution of the delay in the gradient's loss - # ie int_0^{T-\tau} - lambda(t+\tau) \ - # \pdv{f(x_{t+\tau}, x_{t})}{x_t} x'(t) dt - delay_derivative_inc[idx] += torch.sum( - rhs_adjoint_2 * grad_ys[:, -1 - j], - dim=(tuple(range(len(rhs_adjoint_2.shape)))), - ) - - param_derivative_inc = torch.autograd.grad( - out, - params, - -adjoint_y, - retain_graph=True, - allow_unused=True, + params_inc = tuple( + [ + -param + if param is not None + else torch.zeros(transformer.original_shapes[i]) + for i, param in 
enumerate(params) + ] ) - return rhs_adjoint_1, ( - param_derivative_inc, - delay_derivative_inc, + + return transformer.flatten( + ( + rhs_adjoint_1, + -delay_derivative_inc.squeeze(1) + params_inc[0], + *params_inc[1:], + ) ) # computing the adjoint dynamics - out2, out3 = None, None - delay_derivative_inc = torch.zeros_like(ctx.func.delays)[..., None] current_num_steps = 0 - for j in range(len(ctx.ts) - 1, 0, -1): + for i in range(len(ts) - 1, 0, -1): current_num_steps += 1 if current_num_steps > ctx.max_steps: raise RuntimeError("Maximum number of steps reached") - tprev, tnext = ctx.ts[j], ctx.ts[j - 1] - dt = tnext - tprev + t0, t1 = ts[i], ts[i - 1] + dt = t1 - t0 dt = torch.clamp(dt, max=torch.min(ctx.func.delays)) with torch.enable_grad(): - adjoint_state = adjoint_state - grad_output[:, j] - adjoint_interpolator.add_point(tprev, adjoint_state) - ( - adjoint_state, - (param_derivative_inc, delay_derivative_inc), - ) = _integrate_ode( - adjoint_dyn, - tprev, - tnext, - tnext[None], - adjoint_state, - args, + aug_state[0] += grad_output[:, i] + adjoint_interpolator.add_point(t0, aug_state[0]) + aug_state = transformer.flatten(aug_state) + new_aug_state, _ = _integrate_ode( + augment_dyn, + t0, + t1, + t1[None], + aug_state, + func_args, solver, stepsize_controller, dt, ctx.max_steps, - has_aux=True, ) - adjoint_state = adjoint_state.squeeze(dim=1) - if out2 is None: - out2 = tuple([dt.abs() * p for p in param_derivative_inc]) - else: - for _1, _2 in zip([*out2], [*param_derivative_inc]): - if _2 is not None: - _1 += dt.abs() * _2 - - if out3 is None: - out3 = tuple([-dt.abs() * p for p in delay_derivative_inc]) - else: - for _1, _2 in zip([*out3], [*delay_derivative_inc]): - if _2 is not None: - _1 += -dt.abs() * _2 - - # Checking if the history function is a nn.Module - # If it is, we need to compute the last contribution - # of the dL/dtheta - if isinstance(ctx.history_func, nn.Module): - # adding the last contribution of the delay - # parameters in the 
loss w.r.t. the parameters - # ie which is the last part of the integration - # from t = 0 to t = -tau - # we must have that T > tau otherwise - # the integral isn't properly defined - # There is no mention of this anywhere in - # the litterature so this an assumption - if (ctx.t1 - ctx.t0) < max(ctx.func.delays): - raise ValueError( - "The integration span `t1-t0` must \ - be greater than the maximum delay" - ) - for idx, tau_i in enumerate(ctx.func.delays): - ts_history_i = torch.linspace( - ctx.t0 - tau_i.item(), ctx.t0, int(tau_i.item() / dt.abs()) - ).to(ctx.ts.device) - for k in range(len(ts_history_i) - 1, 0, -1): - t = ts_history_i[k] - with torch.enable_grad(): - h_t = torch.autograd.Variable( - ( - state_interpolator(t) - if t > ctx.t0 - else ctx.history_func(t) - ), - requires_grad=True, - ) - adjoint_t_plus_tau = adjoint_interpolator(t + tau_i) - h_t_plus_tau = state_interpolator(t + tau_i) - history = [ - ( - state_interpolator(t + tau_i - tau_j) - if t + tau_i - tau_j >= ctx.t0 - else ctx.history_func(t + tau_i - tau_j) - ) - for tau_j in ctx.func.delays - ] - history[idx] = h_t - out_other = ctx.func( - t + tau_i, h_t_plus_tau, args, history=history - ) - rhs_adjoint_inc = torch.autograd.grad( - out_other, h_t, -adjoint_t_plus_tau - )[0] - # remaining contribution of the delay in the gradient's loss - # int_{-\tau}^{0} \pdv{f(x_{t+\tau}, x_{t})}{x_t} x'(t) dt - delay_derivative_inc[idx] += torch.sum( - rhs_adjoint_inc * grad_ys[:, k], - dim=(tuple(range(len(rhs_adjoint_inc.shape)))), - ) + aug_state = transformer.unflatten(new_aug_state) - if out3 is not None: - for _1, _2 in zip([*out3], [*delay_derivative_inc]): - if _2 is not None: - _1 += -dt.abs() * _2 + params_incr = aug_state[1:] tuple_nones = (None, None, None, None, None, None, None, None, None, None) - if out3 is not None and out2 is not None: - return *tuple_nones, *(out3[0] + out2[0], *out2[1:]) # type: ignore - elif out3 is None and out2 is not None: - return *tuple_nones, *(out2[0], 
*out2[1:]) # type: ignore - else: - return *tuple_nones, *(out2[0], *out2[1:]) # type: ignore + return *tuple_nones, *params_incr def ddesolve_adjoint( @@ -301,7 +257,7 @@ def ddesolve_adjoint( t1: Float[torch.Tensor, ""], ts: Float[torch.Tensor, " time"], history_func: Callable[[Float[torch.Tensor, ""]], Float[torch.Tensor, "batch ..."]], - args: Any, + func_args: Any, solver: AbstractOdeSolver, stepsize_controller: AbstractStepSizeController = ConstantStepSizeController(), dt0: Optional[Float[torch.Tensor, ""]] = None, @@ -327,7 +283,7 @@ def ddesolve_adjoint( t1, ts, history_func, - args, + func_args, solver, stepsize_controller, dt0, diff --git a/torchdde/adjoint_ode.py b/torchdde/adjoint_ode.py index 76bb003..a98fb0b 100644 --- a/torchdde/adjoint_ode.py +++ b/torchdde/adjoint_ode.py @@ -5,6 +5,7 @@ from jaxtyping import Float from torchdde.integrate import _integrate_ode +from torchdde.misc import TupleTensorTransformer from torchdde.solver.base import AbstractOdeSolver from torchdde.step_size_controller.base import AbstractStepSizeController from torchdde.step_size_controller.constant import ConstantStepSizeController @@ -31,6 +32,7 @@ def forward( # type: ignore ctx.ts = ts ctx.y0 = y0 ctx.solver = solver + ctx.dt0 = dt0 ctx.stepsize_controller = stepsize_controller ctx.max_steps = max_steps @@ -57,49 +59,58 @@ def backward(ctx, *grad_y): # type: ignore # grad_output holds the gradient of the # loss w.r.t. 
each evaluation step grad_output = grad_y[0] - dt = ctx.ts[1] - ctx.ts[0] ys = ctx.ys ts = ctx.ts + dt0 = ctx.dt0 args = ctx.args solver = ctx.solver stepsize_controller = ctx.stepsize_controller params = ctx.saved_tensors - adjoint_state = grad_output[:, -1] + # aug_state will hold the [y_t, adjoint_state, params_incr] + aug_state = [torch.zeros_like(ys[:, -1]), torch.zeros_like(ys[:, -1])] + aug_state.extend([torch.zeros_like(param) for param in params]) + transformer = TupleTensorTransformer.from_tuple(aug_state) + + def augmented_dyn(t, aug_state, args): + y_t, adjoint_state, *params_inc = transformer.unflatten(aug_state) + out = ctx.func(t, y_t, args) + adjoint_state, *params_inc = torch.autograd.grad( + out, + (y_t,) + params, + -adjoint_state, + retain_graph=True, + allow_unused=True, + ) + return transformer.flatten((y_t, adjoint_state, *params_inc)) - out2 = None for i in range(len(ts) - 1, 0, -1): t0, t1 = ts[i], ts[i - 1] - dt = t1 - t0 + dt0 = t1 - t0 y_t = torch.autograd.Variable(ys[:, i], requires_grad=True) + + aug_state[0] = y_t + aug_state[1] += grad_output[:, i] + with torch.enable_grad(): - out = ctx.func(ts[i], y_t, args) - adj_dyn = lambda t, adj_y, args: torch.autograd.grad( - out, y_t, -adj_y, retain_graph=True - )[0] - adjoint_state, _ = _integrate_ode( - adj_dyn, + aug_state[0] = y_t + aug_state = transformer.flatten(aug_state) + new_aug_state, _ = _integrate_ode( + augmented_dyn, t0, t1, t1[None], - adjoint_state, + aug_state, args, solver, stepsize_controller, - dt, + dt0, ctx.max_steps, ) - adjoint_state = adjoint_state.squeeze(dim=1) - adjoint_state = adjoint_state - grad_output[:, i] - param_inc = torch.autograd.grad( - out, params, -adjoint_state, retain_graph=True - ) + aug_state = transformer.unflatten(new_aug_state) - if out2 is None: - out2 = tuple([dt.abs() * p for p in param_inc]) - else: - for _1, _2 in zip([*out2], [*param_inc]): - _1 += dt.abs() * _2 + adjoint_state = aug_state[1] + params_incr = aug_state[2:] return ( # 
type: ignore None, None, @@ -111,7 +122,7 @@ def backward(ctx, *grad_y): # type: ignore None, None, None, - *out2, # type: ignore + *params_incr, # type: ignore ) diff --git a/torchdde/integrate.py b/torchdde/integrate.py index ac2d69f..30c3fd5 100644 --- a/torchdde/integrate.py +++ b/torchdde/integrate.py @@ -161,7 +161,6 @@ def _integrate( dt0: Optional[Float[torch.Tensor, ""]] = None, delays: Optional[Float[torch.Tensor, " delays"]] = None, max_steps: int = 100, - has_aux: bool = False, ) -> tuple[ Float[torch.Tensor, "batch time ..."], Union[Callable[[Float[torch.Tensor, ""]], Float[torch.Tensor, "batch ..."]], Any], @@ -172,7 +171,6 @@ def _integrate( - `func`: Pytorch model, i.e vector field - `ts`: Integration span - `y0`: Initial condition for ODE / History function for DDE - - `has_aux`: Whether the model has an auxiliary output. **Returns:** @@ -195,7 +193,6 @@ def _integrate( stepsize_controller, dt0, max_steps=max_steps, - has_aux=has_aux, ) else: assert isinstance(y0, torch.Tensor) @@ -210,7 +207,6 @@ def _integrate( stepsize_controller, dt0, max_steps=max_steps, - has_aux=has_aux, ) @@ -227,7 +223,6 @@ def _integrate_dde( stepsize_controller: AbstractStepSizeController, dt0: Optional[Float[torch.Tensor, ""]] = None, max_steps: Optional[int] = 100, - has_aux: bool = False, ) -> tuple[ Float[torch.Tensor, "batch time ..."], tuple[ @@ -297,7 +292,6 @@ def cond(t, tau): state.dt, state.solver_state, func_args, - has_aux=has_aux, ) ( keep_step, @@ -398,7 +392,6 @@ def _integrate_ode( stepsize_controller: AbstractStepSizeController, dt0: Optional[Float[torch.Tensor, ""]] = None, max_steps: Optional[int] = 100, - has_aux: bool = False, ) -> tuple[Float[torch.Tensor, "batch time ..."], Any]: assert max_steps is not None @@ -439,7 +432,6 @@ def _integrate_ode( state.dt, state.solver_state, func_args, - has_aux=has_aux, ) ( keep_step, @@ -468,9 +460,7 @@ def _integrate_ode( while torch.any(state.tnext >= ts[state.save_idx + step_save_idx :]): idx = 
state.save_idx + step_save_idx out = interp(ts[idx]) - ys[:, idx] = ( - out.unsqueeze(1) if len(out.shape) != len(ys[:, idx].shape) else out - ) + ys[:, idx] = out.unsqueeze(1) if out.ndim != ys[:, idx].ndim else out step_save_idx += 1 ######################################## @@ -503,6 +493,7 @@ def _integrate_ode( save_idx, ) cond = state.tprev < t1 if (t1 > t0) else state.tprev > t1 + if state.num_steps >= max_steps: raise RuntimeError( f"Maximum number of steps reached \ diff --git a/torchdde/local_interpolation/fourth_order_interpolation.py b/torchdde/local_interpolation/fourth_order_interpolation.py index ba18b78..4e21c81 100644 --- a/torchdde/local_interpolation/fourth_order_interpolation.py +++ b/torchdde/local_interpolation/fourth_order_interpolation.py @@ -35,6 +35,7 @@ def __init__( self.t1 = t1 self.dt = t1 - t0 self.c_mid = c_mid + self.c_mid = self.c_mid.to(dense_info["y0"].device) self.coeffs = self._calculate(dense_info) def _calculate( diff --git a/torchdde/misc.py b/torchdde/misc.py new file mode 100644 index 0000000..3747f5d --- /dev/null +++ b/torchdde/misc.py @@ -0,0 +1,115 @@ +import torch + + +class TupleTensorTransformer: + """ + Transforms a PyTorch tensor representation of a tuple back into the + original tuple. + + This class is designed to handle the specific structure of the 'aug_state' + tuple described in the problem, which is commonly used in adjoint + sensitivity analysis. It can rebuild the tuple given a flattened tensor + and metadata about the original tensor shapes and dtypes. + + Attributes: + original_shapes (list[torch.Size]): The shapes of the tensors in the + original tuple. + original_dtypes (list[torch.dtype]): The data types of the tensors + in the original tuple. + original_devices (list[torch.device]): The devices of the tensors + in the original tuple. + param_indices (list[int]): The starting index within the flattened + tensor where each parameter's adjoint begins. 
+ + """ + + def __init__( + self, original_shapes, original_dtypes, original_devices, param_indices + ): + """ + Initializes the TupleTensorTransformer. + + Args: + original_shapes (list[torch.Size]): The shapes of the tensors + in the original tuple. + original_dtypes (list[torch.dtype]): The data types of the tensors + in the original tuple. + original_devices (list[torch.device]): The devices of the tensors + in the original tuple. + param_indices (list[int]): The starting index within the + flattened tensor where each parameter's adjoint begins. + """ + self.original_shapes = original_shapes + self.original_dtypes = original_dtypes + self.original_devices = original_devices + self.param_indices = param_indices + + @classmethod + def from_tuple(cls, tuple_data): + """ + Creates a TupleTensorTransformer from a sample tuple. + Analyzes the tuple to determine the shapes, dtypes, and devices of + each element, which is crucial for reconstructing the tuple later. + + Args: + tuple_data (tuple or list): A sample of the tuple to + be transformed. + + Returns: + TupleTensorTransformer: An initialized TupleTensorTransformer obj. + """ + original_shapes = [item.shape for item in tuple_data] + original_dtypes = [item.dtype for item in tuple_data] + original_devices = [item.device for item in tuple_data] + + # Find the index of params (start from aug_state[3]) + param_indices = [sum(item.numel() for item in tuple_data[:3])] + current_index = param_indices[0] + + for i in range(3, len(tuple_data)): + current_index += tuple_data[i - 1].numel() + param_indices.append(current_index) + + return cls(original_shapes, original_dtypes, original_devices, param_indices) + + def flatten(self, tuple_data): + """ + Flattens the tuple into a single PyTorch tensor with size [1, N]. + Concatenates all the tensors in the tuple along a single dimension + and reshapes the result. + + Args: + tuple_data (tuple or list): The tuple to be flattened. 
+ + Returns: + torch.Tensor: A flattened tensor representing the tuple, + reshaped to [1, N]. + """ + flat_list = [item.flatten() for item in tuple_data] + concatenated_tensor = torch.cat(flat_list) + return concatenated_tensor.reshape(1, -1) # Reshape to [1, N] + + def unflatten(self, flat_tensor): + """ + Reconstructs the original tuple from the flattened tensor. + Splits the tensor based on the stored shapes, dtypes, and devices, + and reshapes each part to match the original structure. + + Args: + flat_tensor (torch.Tensor): The flattened tensor to be unflattened. + + Returns: + tuple: The reconstructed tuple. + """ + flat_tensor = flat_tensor.reshape(-1) # Reshape back to 1D + reconstructed_tuple = [] + current_index = 0 + for shape, dtype, device in zip( + self.original_shapes, self.original_dtypes, self.original_devices + ): + num_elements = torch.Size(shape).numel() + tensor_slice = flat_tensor[current_index : current_index + num_elements] + reconstructed_tensor = tensor_slice.reshape(shape).to(dtype).to(device) + reconstructed_tuple.append(reconstructed_tensor) + current_index += num_elements + return reconstructed_tuple diff --git a/torchdde/solver/base.py b/torchdde/solver/base.py index e5cf8f5..a17e2fc 100644 --- a/torchdde/solver/base.py +++ b/torchdde/solver/base.py @@ -62,7 +62,6 @@ def step( dt: Float[torch.Tensor, ""], solver_state: Union[Tuple[Any, ...], None], func_args: Any, - has_aux: bool = False, ) -> Tuple[ Float[torch.Tensor, "batch ..."], Union[Float[torch.Tensor, "batch ..."], None], @@ -78,16 +77,10 @@ def step( - `t`: Current time step `t` - `y`: Current state `y` - `dt`: Step size `dt` - - `has_aux`: Whether the model/callable has an auxiliary output. - - ??? tip "has_aux ?" - - A function with an auxiliary output can look like - ```python - def f(t,y,func_args): - return -y, ("Hello World",1) - ``` - The `has_aux` `kwargs` argument is used to compute the adjoint method + - `solver_state`: State of the solver. 
It is the output of the previous + `step` method. + - `func_args`: Arguments to be passed along to `func` when it's called + (e.g., func(t, y, func_args)). **Returns:** diff --git a/torchdde/solver/bosh3.py b/torchdde/solver/bosh3.py index dbd5b0d..0501f98 100644 --- a/torchdde/solver/bosh3.py +++ b/torchdde/solver/bosh3.py @@ -11,17 +11,22 @@ class Bosh3(ExplicitRungeKutta): interpolation for dense/ts output. """ - a_list = [[], [1 / 2], [0.0, 3 / 4], [2 / 9, 1 / 3, 4 / 9]] - b_list = [2 / 9, 1 / 3, 4 / 9, 0.0] - b_err_list = [2 / 9 - 7 / 24, 1 / 3 - 1 / 4, 4 / 9 - 1 / 3, -1 / 8] c_list = [0.0, 1 / 2, 3 / 4, 1.0] + a_list = [ + [], + [1 / 2], + [0.0, 3 / 4], + [2 / 9, 1 / 3, 4 / 9], + ] + b_list = [2 / 9, 1 / 3, 4 / 9, 0.0] + b_low_order = [7 / 24, 1 / 4, 1 / 3, 1 / 8] def __init__(self): bosh3_tableau = ButcherTableau.from_lists( c=self.c_list, a=self.a_list, b=self.b_list, - b_err=self.b_err_list, + b_low_order=self.b_low_order, ) super().__init__( diff --git a/torchdde/solver/euler.py b/torchdde/solver/euler.py index 17de8d6..f9f750b 100644 --- a/torchdde/solver/euler.py +++ b/torchdde/solver/euler.py @@ -35,7 +35,6 @@ def step( dt: Float[torch.Tensor, ""], solver_state: Union[Tuple[Any, ...], None], func_args: Any, - has_aux: bool = False, ) -> Tuple[ Float[torch.Tensor, "batch ..."], None, @@ -45,13 +44,8 @@ def step( ]: assert solver_state is None, "Euler solver should be stateless" - if has_aux: - k1, aux = func(t, y, func_args) - y1 = y + dt * k1 - return y1, None, dict(y0=y, y1=y1), None, aux - else: - y1 = y + dt * func(t, y, func_args) - return y1, None, dict(y0=y, y1=y1), None, None + y1 = y + dt * func(t, y, func_args) + return y1, None, dict(y0=y, y1=y1), None, None def build_interpolation( self, diff --git a/torchdde/solver/implicit_euler.py b/torchdde/solver/implicit_euler.py index 047bcd9..fe84586 100644 --- a/torchdde/solver/implicit_euler.py +++ b/torchdde/solver/implicit_euler.py @@ -41,12 +41,8 @@ def _residual( dt: Float[torch.Tensor, ""], 
y_sol: Float[torch.Tensor, "batch ..."], func_args: Any, - has_aux=False, ) -> Float[torch.Tensor, ""]: - if has_aux: - f_sol, _ = func(t, y_sol, func_args) - else: - f_sol = func(t, y_sol, func_args) + f_sol = func(t, y_sol, func_args) return torch.sum((y_sol - y - dt * f_sol) ** 2) def step( @@ -57,7 +53,6 @@ def step( dt: Float[torch.Tensor, ""], solver_state: Union[Tuple[Any, ...], None], func_args: Any, - has_aux: bool = False, ) -> tuple[ Float[torch.Tensor, "batch ..."], None, @@ -82,20 +77,14 @@ def step( def closure() -> Float[torch.Tensor, ""]: opt.zero_grad() - residual = ImplicitEuler._residual( - func, t, y, dt, y_sol, func_args, has_aux=has_aux - ) + residual = ImplicitEuler._residual(func, t, y, dt, y_sol, func_args) (y_sol.grad,) = torch.autograd.grad( residual, y_sol, only_inputs=True, allow_unused=False ) return residual opt.step(closure) # type: ignore - if has_aux: - _, aux = func(t, y, func_args) - return y_sol, None, dict(y0=y, y1=y_sol), None, aux - else: - return y_sol, None, dict(y0=y, y1=y_sol), None, None + return y_sol, None, dict(y0=y, y1=y_sol), None, None def build_interpolation( self, diff --git a/torchdde/solver/runge_kutta.py b/torchdde/solver/runge_kutta.py index db9eec2..3a4325b 100644 --- a/torchdde/solver/runge_kutta.py +++ b/torchdde/solver/runge_kutta.py @@ -162,6 +162,8 @@ def init( **kwargs, ) -> Tuple[Optional[Float[torch.Tensor, "batch ..."]]]: del dt0, args, kwargs + self.tableau = self.tableau.to(y0.device, y0.dtype, t0.dtype) + if self.tableau.fsal: if f0 is None: prev_vf1 = func(t0, y0, func_args) @@ -180,7 +182,6 @@ def step( dt: Float[torch.Tensor, ""], solver_state: Union[Tuple[Any, ...], None], func_args: Any, - has_aux: bool = False, ) -> Tuple[ Float[torch.Tensor, "batch ..."], Float[torch.Tensor, "batch ..."], @@ -193,7 +194,6 @@ def step( f0, *_ = solver_state else: f0 = func(t, y, func_args) - print(f0.shape) y_i = y t_nodes = torch.addcmul(t, self.tableau.c, dt) k = f0.new_empty((self.tableau.n_stages, 
f0.shape[0], f0.shape[1])) diff --git a/torchdde/step_size_controller/adaptive.py b/torchdde/step_size_controller/adaptive.py index 54e6249..65469f0 100644 --- a/torchdde/step_size_controller/adaptive.py +++ b/torchdde/step_size_controller/adaptive.py @@ -1,3 +1,5 @@ +import warnings + import torch from .base import AbstractStepSizeController @@ -249,6 +251,10 @@ def adapt_step_size( new_dt = torch.max(new_dt, torch.tensor(self.dtmin)) if self.dtmax is not None: new_dt = torch.min(new_dt, torch.tensor(self.dtmax)) + if torch.abs(new_dt) < 1e-8: + warnings.warn( + "AdaptiveStepSizeController yielded an abs(dt) value smaller than 1e-8" + ) t0 = torch.where(keep_step, t1, t0) t1 = torch.where(keep_step, t1 + new_dt, t0 + new_dt)