diff --git a/bayesflow/networks/fusion_network/fusion_network.py b/bayesflow/networks/fusion_network/fusion_network.py index 269408e30..dbc014571 100644 --- a/bayesflow/networks/fusion_network/fusion_network.py +++ b/bayesflow/networks/fusion_network/fusion_network.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from ..summary_network import SummaryNetwork from bayesflow.utils.serialization import deserialize, serializable, serialize from bayesflow.types import Tensor, Shape @@ -10,23 +10,32 @@ class FusionNetwork(SummaryNetwork): def __init__( self, - backbones: Mapping[str, keras.Layer], + backbones: Sequence | Mapping[str, keras.Layer], head: keras.Layer | None = None, **kwargs, ): - """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data. + """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from (optionally) + multi-modal data. - Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed - by the correct summary network. This means the "summary_variables" entry to the approximator has to be - a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method. + There are two modes of operation: + + - Identical input: each backbone receives the same input. The backbones have to be passed as a sequence. + - Multi-modal input: each backbone gets its own input, which is the usual case for multi-modal data. Networks + and inputs have to be passed as dictionaries with corresponding keys, so that each + input is processed by the correct summary network. This means the "summary_variables" entry to the + approximator has to be a dictionary, which can be achieved using the + :py:meth:`bayesflow.adapters.Adapter.group` method. This network implements _late_ fusion. 
The output of the individual summary networks is concatenated, and can be further processed by another neural network (`head`). Parameters ---------- - backbones : dict - A dictionary with names of inputs as keys and corresponding summary networks as values. + backbones : Sequence or dict + Either (see above for details): + + - a sequence, when each backbone should receive the same input. + - a dictionary with names of inputs as keys and corresponding summary networks as values. head : keras.Layer, optional A network to further process the concatenated outputs of the summary networks. By default, the concatenated outputs are returned without further processing. @@ -37,16 +46,40 @@ def __init__( super().__init__(**kwargs) self.backbones = backbones self.head = head - self._ordered_keys = sorted(list(self.backbones.keys())) + self._dict_mode = isinstance(backbones, Mapping) + if self._dict_mode: + # order keys to always concatenate in the same order + self._ordered_keys = sorted(list(self.backbones.keys())) - def build(self, inputs_shape: Mapping[str, Shape]): + def build(self, inputs_shape: Shape | Mapping[str, Shape]): + if self._dict_mode and not isinstance(inputs_shape, Mapping): + raise ValueError( + "`backbones` were passed as a dictionary, but the input shapes are not a dictionary. " + "If you want to pass the same input to each backbone, pass the backbones as a list instead of a " + "dictionary. If you want to provide each backbone with different input, please ensure that you have " + "correctly assembled the `summary_variables` to provide a dictionary using the Adapter.group method." 
+ ) if self.built: return output_shapes = [] - for k, shape in inputs_shape.items(): - if not self.backbones[k].built: - self.backbones[k].build(shape) - output_shapes.append(self.backbones[k].compute_output_shape(shape)) + if self._dict_mode: + missing_keys = list(set(inputs_shape.keys()).difference(set(self._ordered_keys))) + if len(missing_keys) > 0: + raise ValueError( + f"Expected the input to contain the following keys: {self._ordered_keys}. " + f"Missing keys: {missing_keys}" + ) + for k, shape in inputs_shape.items(): + # build each summary network with different input shape + if not self.backbones[k].built: + self.backbones[k].build(shape) + output_shapes.append(self.backbones[k].compute_output_shape(shape)) + else: + for backbone in self.backbones: + # build all summary networks with the same input shape + if not backbone.built: + backbone.build(inputs_shape) + output_shapes.append(backbone.compute_output_shape(inputs_shape)) if self.head and not self.head.built: fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes)) self.head.build(fusion_input_shape) @@ -54,8 +87,10 @@ def build(self, inputs_shape: Mapping[str, Shape]): def compute_output_shape(self, inputs_shape: Mapping[str, Shape]): output_shapes = [] - for k, shape in inputs_shape.items(): - output_shapes.append(self.backbones[k].compute_output_shape(shape)) + if self._dict_mode: + output_shapes = [self.backbones[k].compute_output_shape(shape) for k, shape in inputs_shape.items()] + else: + output_shapes = [backbone.compute_output_shape(inputs_shape) for backbone in self.backbones] output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes)) if self.head: output_shape = self.head.compute_output_shape(output_shape) @@ -65,13 +100,20 @@ def call(self, inputs: Mapping[str, Tensor], training=False): """ Parameters ---------- - inputs : dict[str, Tensor] - Each value in the dictionary is the input to the summary network with the corresponding key. 
+ inputs : Tensor | dict[str, Tensor] + Either (see above for details): + + - a tensor, when the backbones were passed as a list and should receive identical inputs + - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the + summary network with the corresponding key. training : bool, optional Whether the model is in training mode, affecting layers like dropout and batch normalization. Default is False. """ - outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys] + if self._dict_mode: + outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys] + else: + outputs = [backbone(inputs, training=training) for backbone in self.backbones] outputs = ops.concatenate(outputs, axis=-1) if self.head is None: return outputs @@ -81,8 +123,12 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", """ Parameters ---------- - inputs : dict[str, Tensor] - Each value in the dictionary is the input to the summary network with the corresponding key. + inputs : Tensor | dict[str, Tensor] + Either (see above for details): + + - a tensor, when the backbones were passed as a list and should receive identical inputs + - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the + summary network with the corresponding key. stage : bool, optional Whether the model is in training mode, affecting layers like dropout and batch normalization. Default is False. 
@@ -93,14 +139,23 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", self.build(keras.tree.map_structure(keras.ops.shape, inputs)) metrics = {"loss": [], "outputs": []} - for k in self._ordered_keys: - if isinstance(self.backbones[k], SummaryNetwork): - metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs) - metrics["outputs"].append(metrics_k["outputs"]) - if "loss" in metrics_k: - metrics["loss"].append(metrics_k["loss"]) + def process_backbone(backbone, input): + # helper function to avoid code duplication for the two modes + if isinstance(backbone, SummaryNetwork): + backbone_metrics = backbone.compute_metrics(input, stage=stage, **kwargs) + metrics["outputs"].append(backbone_metrics["outputs"]) + if "loss" in backbone_metrics: + metrics["loss"].append(backbone_metrics["loss"]) else: - metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training")) + metrics["outputs"].append(backbone(input, training=stage == "training")) + + if self._dict_mode: + for k in self._ordered_keys: + process_backbone(self.backbones[k], inputs[k]) + else: + for backbone in self.backbones: + process_backbone(backbone, inputs) + if len(metrics["loss"]) == 0: del metrics["loss"] else: diff --git a/tests/test_networks/test_fusion_network/conftest.py b/tests/test_networks/test_fusion_network/conftest.py index 184242477..287df5317 100644 --- a/tests/test_networks/test_fusion_network/conftest.py +++ b/tests/test_networks/test_fusion_network/conftest.py @@ -1,17 +1,40 @@ import pytest +@pytest.fixture(params=[True, False]) +def multimodal(request): + return request.param + + @pytest.fixture() -def multimodal_data(random_samples, random_set): - return {"x1": random_samples, "x2": random_set} +def data(random_samples, random_set, multimodal): + if multimodal: + return {"x1": random_samples, "x2": random_set} + return random_set @pytest.fixture() -def fusion_network(): +def fusion_network(multimodal): from 
bayesflow.networks import FusionNetwork, DeepSet import keras + deepset_kwargs = dict( + summary_dim=2, + mlp_widths_equivariant=(2, 2), + mlp_widths_invariant_inner=(2, 2), + mlp_widths_invariant_outer=(2, 2), + mlp_widths_invariant_last=(2, 2), + base_distribution="normal", + ) + if multimodal: + return FusionNetwork( + backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(**deepset_kwargs)}, + head=keras.layers.Dense(3), + ) return FusionNetwork( - backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(summary_dim=2, base_distribution="normal")}, + backbones=[ + DeepSet(**deepset_kwargs), + DeepSet(**deepset_kwargs), + ], head=keras.layers.Dense(3), ) diff --git a/tests/test_networks/test_fusion_network/test_fusion_network.py b/tests/test_networks/test_fusion_network/test_fusion_network.py index f1dbfa1c0..aafec92e7 100644 --- a/tests/test_networks/test_fusion_network/test_fusion_network.py +++ b/tests/test_networks/test_fusion_network/test_fusion_network.py @@ -6,16 +6,16 @@ @pytest.mark.parametrize("automatic", [True, False]) -def test_build(automatic, fusion_network, multimodal_data): +def test_build(automatic, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") assert fusion_network.built is False if automatic: - fusion_network(multimodal_data) + fusion_network(data) else: - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) assert fusion_network.built is True @@ -23,23 +23,36 @@ def test_build(automatic, fusion_network, multimodal_data): assert fusion_network.variables, "Model has no variables." 
+def test_build_failure(fusion_network, data, multimodal): + if not multimodal: + pytest.skip(reason="Nothing to do, as summary networks may consume aribrary inputs") + with pytest.raises(ValueError): + fusion_network.build((3, 2, 2)) + with pytest.raises(ValueError): + data["x3"] = data.pop("x1") + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) + + @pytest.mark.parametrize("automatic", [True, False]) -def test_build_functional_api(automatic, fusion_network, multimodal_data): +def test_build_functional_api(automatic, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") assert fusion_network.built is False - inputs = {} - for k, v in multimodal_data.items(): - inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k) + if multimodal: + inputs = {} + for k, v in data.items(): + inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k) + else: + inputs = keras.layers.Input(shape=keras.ops.shape(data)[1:]) outputs = fusion_network(inputs) model = keras.Model(inputs=inputs, outputs=outputs) if automatic: - model(multimodal_data) + model(data) else: - model.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + model.build(keras.tree.map_structure(keras.ops.shape, data)) assert model.built is True @@ -47,11 +60,11 @@ def test_build_functional_api(automatic, fusion_network, multimodal_data): assert fusion_network.variables, "Model has no variables." 
-def test_serialize_deserialize(fusion_network, multimodal_data): +def test_serialize_deserialize(fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) serialized = serialize(fusion_network) deserialized = deserialize(serialized) @@ -60,28 +73,28 @@ def test_serialize_deserialize(fusion_network, multimodal_data): assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) -def test_save_and_load(tmp_path, fusion_network, multimodal_data): +def test_save_and_load(tmp_path, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) keras.saving.save_model(fusion_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") assert_layers_equal(fusion_network, loaded) - assert allclose(fusion_network(multimodal_data), loaded(multimodal_data)) + assert allclose(fusion_network(data), loaded(data)) @pytest.mark.parametrize("stage", ["training", "validation"]) -def test_compute_metrics(stage, fusion_network, multimodal_data): +def test_compute_metrics(stage, fusion_network, data, multimodal): if fusion_network is None: pytest.skip("Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) - metrics = fusion_network.compute_metrics(multimodal_data, stage=stage) - outputs_via_call = fusion_network(multimodal_data, training=stage == "training") + metrics = fusion_network.compute_metrics(data, stage=stage) + 
outputs_via_call = fusion_network(data, training=stage == "training") assert "outputs" in metrics @@ -90,11 +103,9 @@ def test_compute_metrics(stage, fusion_network, multimodal_data): assert allclose(metrics["outputs"], outputs_via_call) # check that the batch dimension is preserved - assert ( - keras.ops.shape(metrics["outputs"])[0] - == keras.ops.shape(multimodal_data[next(iter(multimodal_data.keys()))])[0] - ) + batch_size = keras.ops.shape(data)[0] if not multimodal else keras.ops.shape(data[next(iter(data.keys()))])[0] + assert keras.ops.shape(metrics["outputs"])[0] == batch_size assert "loss" in metrics assert keras.ops.shape(metrics["loss"]) == ()