From 3b91982b5ea92d7c8fe92e8d2f54b34d310ea413 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Wed, 12 Nov 2025 12:52:56 -0800
Subject: [PATCH 1/2] Adding `assign` to `load_state_dict` implementations

Summary:
This commit adds an `assign` argument to `GPyTorchModel.load_state_dict` and
to the `load_state_dict` implementations of other model types, ensuring
consistency with `torch.nn.Module.load_state_dict`.

Reviewed By: hvarfner

Differential Revision: D86870038
---
 botorch/models/fully_bayesian.py | 14 +++++--
 botorch/models/gpytorch.py       | 19 ++++++++--
 botorch/models/model.py          |  3 +-
 test/models/test_gpytorch.py     | 63 ++++++++++++++++++++++++++++++++
 4 files changed, 90 insertions(+), 9 deletions(-)
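For orientation, a minimal sketch of the `assign` semantics this adopts from
`torch.nn.Module.load_state_dict` (the toy modules below are illustrative and
not part of this change):

    import torch
    from torch import nn

    src = nn.Linear(2, 1, dtype=torch.double)  # double-precision source
    dst = nn.Linear(2, 1)  # float32 destination

    # assign=False (default): values are copied into dst's existing float32
    # parameters, so dst keeps its own tensor properties.
    dst.load_state_dict(src.state_dict())
    assert dst.weight.dtype == torch.float32

    # assign=True: the tensors from the state dict are assigned to dst, so
    # dst adopts their double dtype (only requires_grad is kept from dst).
    dst.load_state_dict(src.state_dict(), assign=True)
    assert dst.weight.dtype == torch.double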
""" if not keep_transforms: - super().load_state_dict(state_dict, strict) + super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) return should_outcome_transform = ( @@ -368,10 +374,12 @@ def load_state_dict( BotorchWarning, stacklevel=3, ) - super().load_state_dict(state_dict, strict) + super().load_state_dict( + state_dict=state_dict, strict=strict, assign=assign + ) return - super().load_state_dict(state_dict, strict) + super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) if getattr(self, "input_transform", None) is not None: self.input_transform.eval() @@ -763,8 +771,11 @@ def load_state_dict( self, state_dict: Mapping[str, Any], strict: bool = True, + assign: bool = False, ) -> None: - return ModelList.load_state_dict(self, state_dict, strict) + return ModelList.load_state_dict( + self, state_dict=state_dict, strict=strict, assign=assign + ) # pyre-fixme[14]: Inconsistent override in return types def posterior( diff --git a/botorch/models/model.py b/botorch/models/model.py index 8fb3b69eed..b0b3753b2e 100644 --- a/botorch/models/model.py +++ b/botorch/models/model.py @@ -581,6 +581,7 @@ def load_state_dict( state_dict: Mapping[str, Any], strict: bool = True, keep_transforms: bool = True, + assign: bool = False, ) -> None: """Initialize the fully Bayesian models before loading the state dict.""" for i, m in enumerate(self.models): @@ -589,7 +590,7 @@ def load_state_dict( for k, v in state_dict.items() if k.startswith(f"models.{i}.") } - m.load_state_dict(filtered_dict, strict=strict) + m.load_state_dict(filtered_dict, strict=strict, assign=assign) def fantasize( self, diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index ff92d15ada..bdcb7ec9e6 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -1043,6 +1043,69 @@ def test_load_state_dict_with_transforms(self): ) ) + def test_load_state_dict_assign_parameter(self): + """Test that the assign parameter correctly controls tensor property preservation. + + With assign=False (default): properties of the current model's tensors are preserved. + With assign=True: properties of the state dict's tensors are preserved. 
+ """ + # Create base model with double precision + tkwargs_double = {"device": self.device, "dtype": torch.double} + train_X_double = torch.rand(5, 2, **tkwargs_double) + train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True) + + base_model = SingleTaskGP( + train_X=train_X_double, + train_Y=train_Y_double, + **_get_input_output_transform(d=2, indices=[0, 1], m=1), + ) + state_dict_double = base_model.state_dict() + + # Create a new model with float32 precision (different dtype) + tkwargs_float = {"device": self.device, "dtype": torch.float} + train_X_float = torch.rand(5, 2, **tkwargs_float) + train_Y_float = torch.sin(train_X_float).sum(dim=1, keepdim=True) + + # Test assign=False (default behavior) + model_assign_false = SingleTaskGP( + train_X=train_X_float, + train_Y=train_Y_float, + **_get_input_output_transform(d=2, indices=[0, 1], m=1), + ) + + # Load double precision state dict with assign=False + model_assign_false.load_state_dict( + state_dict_double, keep_transforms=True, assign=False + ) + + # With assign=False, the model should keep its original float32 dtype + self.assertEqual(model_assign_false.train_inputs[0].dtype, torch.float) + + # Test assign=True + model_assign_true = SingleTaskGP( + train_X=train_X_float, + train_Y=train_Y_float, + **_get_input_output_transform(d=2, indices=[0, 1], m=1), + ) + + # Load double precision state dict with assign=True + model_assign_true.load_state_dict( + state_dict_double, keep_transforms=True, assign=True + ) + + # With assign=True, the model should adopt the state dict's double dtype + self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double) + self.assertEqual( + model_assign_true.train_inputs[0].dtype, + state_dict_double["train_inputs.0"].dtype, + ) + + # Verify the two models have different dtypes + self.assertNotEqual( + model_assign_false.train_inputs[0].dtype, + model_assign_true.train_inputs[0].dtype, + ) + def test_load_state_dict_no_transforms(self): tkwargs = {"device": self.device, "dtype": torch.double} From cfa90bfc3abcba405978af42dbec93aa5e5f334c Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 12 Nov 2025 12:52:56 -0800 Subject: [PATCH 2/2] Fixes to load_state_dict and tests to support `assign` Summary: Completed support for `assign` in load_state_dict Reviewed By: SebastianAment Differential Revision: D86894383 --- botorch/models/gpytorch.py | 3 +++ test/models/test_gpytorch.py | 21 ++++++++++++++++----- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 33052d8ee2..4287ce1b45 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -344,6 +344,9 @@ def load_state_dict( exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter` for which the value from the module is preserved. Default: ``False``. """ + if assign: + first_item = next(iter(state_dict.values())) + self.to(first_item) if not keep_transforms: super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) return diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index bdcb7ec9e6..88d89741d2 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -1044,9 +1044,11 @@ def test_load_state_dict_with_transforms(self): ) def test_load_state_dict_assign_parameter(self): - """Test that the assign parameter correctly controls tensor property preservation. + """Test that the assign parameter correctly controls tensor property + preservation. 
- With assign=False (default): properties of the current model's tensors are preserved. + With assign=False (default): properties of the current model's tensors are + preserved. With assign=True: properties of the state dict's tensors are preserved. """ # Create base model with double precision @@ -1054,9 +1056,16 @@ def test_load_state_dict_assign_parameter(self): train_X_double = torch.rand(5, 2, **tkwargs_double) train_Y_double = torch.sin(train_X_double).sum(dim=1, keepdim=True) + # NOTE Due to issues with transformed priors in gpytorch, we refrain from + # instantiating a model with a LogNormal prior here. + model_specs_without_priors = { + "covar_module": RBFKernel(ard_num_dims=2), + "likelihood": GaussianLikelihood(), + } base_model = SingleTaskGP( train_X=train_X_double, train_Y=train_Y_double, + **model_specs_without_priors, **_get_input_output_transform(d=2, indices=[0, 1], m=1), ) state_dict_double = base_model.state_dict() @@ -1070,12 +1079,13 @@ def test_load_state_dict_assign_parameter(self): model_assign_false = SingleTaskGP( train_X=train_X_float, train_Y=train_Y_float, + **model_specs_without_priors, **_get_input_output_transform(d=2, indices=[0, 1], m=1), ) # Load double precision state dict with assign=False model_assign_false.load_state_dict( - state_dict_double, keep_transforms=True, assign=False + state_dict_double, keep_transforms=False, assign=False ) # With assign=False, the model should keep its original float32 dtype @@ -1085,19 +1095,20 @@ def test_load_state_dict_assign_parameter(self): model_assign_true = SingleTaskGP( train_X=train_X_float, train_Y=train_Y_float, + **model_specs_without_priors, **_get_input_output_transform(d=2, indices=[0, 1], m=1), ) # Load double precision state dict with assign=True model_assign_true.load_state_dict( - state_dict_double, keep_transforms=True, assign=True + state_dict_double, keep_transforms=False, assign=True ) # With assign=True, the model should adopt the state dict's double dtype self.assertEqual(model_assign_true.train_inputs[0].dtype, torch.double) self.assertEqual( model_assign_true.train_inputs[0].dtype, - state_dict_double["train_inputs.0"].dtype, + next(iter(state_dict_double.values())).dtype, ) # Verify the two models have different dtypes