From 63a4a4e18b3c01e421080cd9251a9a2547183751 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sat, 23 Aug 2025 12:08:11 +0000 Subject: [PATCH] FusionNetwork: handle multiple networks with identical inputs Up to now, the network strictly required passing the input to each summary network, which required duplication of the data somewhere upstream if the same data should be used for all backbones, which might become a more common use case. For this reason, I extended the fusion network with a second mode where all backbones receive the same input. This is one possible implementation, but we might also outsource this functionality into a separate class. --- .../networks/fusion_network/fusion_network.py | 111 +++++++++++++----- .../test_fusion_network/conftest.py | 31 ++++- .../test_fusion_network.py | 55 +++++---- 3 files changed, 143 insertions(+), 54 deletions(-) diff --git a/bayesflow/networks/fusion_network/fusion_network.py b/bayesflow/networks/fusion_network/fusion_network.py index 269408e30..dbc014571 100644 --- a/bayesflow/networks/fusion_network/fusion_network.py +++ b/bayesflow/networks/fusion_network/fusion_network.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from ..summary_network import SummaryNetwork from bayesflow.utils.serialization import deserialize, serializable, serialize from bayesflow.types import Tensor, Shape @@ -10,23 +10,32 @@ class FusionNetwork(SummaryNetwork): def __init__( self, - backbones: Mapping[str, keras.Layer], + backbones: Sequence | Mapping[str, keras.Layer], head: keras.Layer | None = None, **kwargs, ): - """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data. + """(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from (optionally) + multi-modal data. - Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed - by the correct summary network. This means the "summary_variables" entry to the approximator has to be - a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method. + There are two modes of operation: + + - Identical input: each backbone receives the same input. The backbones have to be passed as a sequence. + - Multi-modal input: each backbone gets its own input, which is the usual case for multi-modal data. Networks + and inputs have to be passed as dictionaries with corresponding keys, so that each + input is processed by the correct summary network. This means the "summary_variables" entry to the + approximator has to be a dictionary, which can be achieved using the + :py:meth:`bayesflow.adapters.Adapter.group` method. This network implements _late_ fusion. The output of the individual summary networks is concatenated, and can be further processed by another neural network (`head`). Parameters ---------- - backbones : dict - A dictionary with names of inputs as keys and corresponding summary networks as values. + backbones : Sequence or dict + Either (see above for details): + + - a sequence, when each backbone should receive the same input. + - a dictionary with names of inputs as keys and corresponding summary networks as values. head : keras.Layer, optional A network to further process the concatenated outputs of the summary networks. By default, the concatenated outputs are returned without further processing. @@ -37,16 +46,40 @@ def __init__( super().__init__(**kwargs) self.backbones = backbones self.head = head - self._ordered_keys = sorted(list(self.backbones.keys())) + self._dict_mode = isinstance(backbones, Mapping) + if self._dict_mode: + # order keys to always concatenate in the same order + self._ordered_keys = sorted(list(self.backbones.keys())) - def build(self, inputs_shape: Mapping[str, Shape]): + def build(self, inputs_shape: Shape | Mapping[str, Shape]): + if self._dict_mode and not isinstance(inputs_shape, Mapping): + raise ValueError( + "`backbones` were passed as a dictionary, but the input shapes are not a dictionary. " + "If you want to pass the same input to each backbone, pass the backbones as a list instead of a " + "dictionary. If you want to provide each backbone with different input, please ensure that you have " + "correctly assembled the `summary_variables` to provide a dictionary using the Adapter.group method." + ) if self.built: return output_shapes = [] - for k, shape in inputs_shape.items(): - if not self.backbones[k].built: - self.backbones[k].build(shape) - output_shapes.append(self.backbones[k].compute_output_shape(shape)) + if self._dict_mode: + missing_keys = list(set(inputs_shape.keys()).difference(set(self._ordered_keys))) + if len(missing_keys) > 0: + raise ValueError( + f"Expected the input to contain the following keys: {self._ordered_keys}. " + f"Missing keys: {missing_keys}" + ) + for k, shape in inputs_shape.items(): + # build each summary network with different input shape + if not self.backbones[k].built: + self.backbones[k].build(shape) + output_shapes.append(self.backbones[k].compute_output_shape(shape)) + else: + for backbone in self.backbones: + # build all summary networks with the same input shape + if not backbone.built: + backbone.build(inputs_shape) + output_shapes.append(backbone.compute_output_shape(inputs_shape)) if self.head and not self.head.built: fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes)) self.head.build(fusion_input_shape) @@ -54,8 +87,10 @@ def build(self, inputs_shape: Mapping[str, Shape]): def compute_output_shape(self, inputs_shape: Mapping[str, Shape]): output_shapes = [] - for k, shape in inputs_shape.items(): - output_shapes.append(self.backbones[k].compute_output_shape(shape)) + if self._dict_mode: + output_shapes = [self.backbones[k].compute_output_shape(shape) for k, shape in inputs_shape.items()] + else: + output_shapes = [backbone.compute_output_shape(inputs_shape) for backbone in self.backbones] output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes)) if self.head: output_shape = self.head.compute_output_shape(output_shape) @@ -65,13 +100,20 @@ def call(self, inputs: Mapping[str, Tensor], training=False): """ Parameters ---------- - inputs : dict[str, Tensor] - Each value in the dictionary is the input to the summary network with the corresponding key. + inputs : Tensor | dict[str, Tensor] + Either (see above for details): + + - a tensor, when the backbones where passed as a list and should receive identical inputs + - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the + summary network with the corresponding key. training : bool, optional Whether the model is in training mode, affecting layers like dropout and batch normalization. Default is False. """ - outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys] + if self._dict_mode: + outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys] + else: + outputs = [backbone(inputs, training=training) for backbone in self.backbones] outputs = ops.concatenate(outputs, axis=-1) if self.head is None: return outputs @@ -81,8 +123,12 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", """ Parameters ---------- - inputs : dict[str, Tensor] - Each value in the dictionary is the input to the summary network with the corresponding key. + inputs : Tensor | dict[str, Tensor] + Either (see above for details): + + - a tensor, when the backbones where passed as a list and should receive identical inputs + - a dictionary, when the backbones were passed as a dictionary, where each value is the input to the + summary network with the corresponding key. stage : bool, optional Whether the model is in training mode, affecting layers like dropout and batch normalization. Default is False. @@ -93,14 +139,23 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training", self.build(keras.tree.map_structure(keras.ops.shape, inputs)) metrics = {"loss": [], "outputs": []} - for k in self._ordered_keys: - if isinstance(self.backbones[k], SummaryNetwork): - metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs) - metrics["outputs"].append(metrics_k["outputs"]) - if "loss" in metrics_k: - metrics["loss"].append(metrics_k["loss"]) + def process_backbone(backbone, input): + # helper function to avoid code duplication for the two modes + if isinstance(backbone, SummaryNetwork): + backbone_metrics = backbone.compute_metrics(input, stage=stage, **kwargs) + metrics["outputs"].append(backbone_metrics["outputs"]) + if "loss" in backbone_metrics: + metrics["loss"].append(backbone_metrics["loss"]) else: - metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training")) + metrics["outputs"].append(backbone(input, training=stage == "training")) + + if self._dict_mode: + for k in self._ordered_keys: + process_backbone(self.backbones[k], inputs[k]) + else: + for backbone in self.backbones: + process_backbone(backbone, inputs) + if len(metrics["loss"]) == 0: del metrics["loss"] else: diff --git a/tests/test_networks/test_fusion_network/conftest.py b/tests/test_networks/test_fusion_network/conftest.py index 184242477..287df5317 100644 --- a/tests/test_networks/test_fusion_network/conftest.py +++ b/tests/test_networks/test_fusion_network/conftest.py @@ -1,17 +1,40 @@ import pytest +@pytest.fixture(params=[True, False]) +def multimodal(request): + return request.param + + @pytest.fixture() -def multimodal_data(random_samples, random_set): - return {"x1": random_samples, "x2": random_set} +def data(random_samples, random_set, multimodal): + if multimodal: + return {"x1": random_samples, "x2": random_set} + return random_set @pytest.fixture() -def fusion_network(): +def fusion_network(multimodal): from bayesflow.networks import FusionNetwork, DeepSet import keras + deepset_kwargs = dict( + summary_dim=2, + mlp_widths_equivariant=(2, 2), + mlp_widths_invariant_inner=(2, 2), + mlp_widths_invariant_outer=(2, 2), + mlp_widths_invariant_last=(2, 2), + base_distribution="normal", + ) + if multimodal: + return FusionNetwork( + backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(**deepset_kwargs)}, + head=keras.layers.Dense(3), + ) return FusionNetwork( - backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(summary_dim=2, base_distribution="normal")}, + backbones=[ + DeepSet(**deepset_kwargs), + DeepSet(**deepset_kwargs), + ], head=keras.layers.Dense(3), ) diff --git a/tests/test_networks/test_fusion_network/test_fusion_network.py b/tests/test_networks/test_fusion_network/test_fusion_network.py index f1dbfa1c0..aafec92e7 100644 --- a/tests/test_networks/test_fusion_network/test_fusion_network.py +++ b/tests/test_networks/test_fusion_network/test_fusion_network.py @@ -6,16 +6,16 @@ @pytest.mark.parametrize("automatic", [True, False]) -def test_build(automatic, fusion_network, multimodal_data): +def test_build(automatic, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") assert fusion_network.built is False if automatic: - fusion_network(multimodal_data) + fusion_network(data) else: - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) assert fusion_network.built is True @@ -23,23 +23,36 @@ def test_build(automatic, fusion_network, multimodal_data): assert fusion_network.variables, "Model has no variables." +def test_build_failure(fusion_network, data, multimodal): + if not multimodal: + pytest.skip(reason="Nothing to do, as summary networks may consume aribrary inputs") + with pytest.raises(ValueError): + fusion_network.build((3, 2, 2)) + with pytest.raises(ValueError): + data["x3"] = data.pop("x1") + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) + + @pytest.mark.parametrize("automatic", [True, False]) -def test_build_functional_api(automatic, fusion_network, multimodal_data): +def test_build_functional_api(automatic, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") assert fusion_network.built is False - inputs = {} - for k, v in multimodal_data.items(): - inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k) + if multimodal: + inputs = {} + for k, v in data.items(): + inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k) + else: + inputs = keras.layers.Input(shape=keras.ops.shape(data)[1:]) outputs = fusion_network(inputs) model = keras.Model(inputs=inputs, outputs=outputs) if automatic: - model(multimodal_data) + model(data) else: - model.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + model.build(keras.tree.map_structure(keras.ops.shape, data)) assert model.built is True @@ -47,11 +60,11 @@ def test_build_functional_api(automatic, fusion_network, multimodal_data): assert fusion_network.variables, "Model has no variables." -def test_serialize_deserialize(fusion_network, multimodal_data): +def test_serialize_deserialize(fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) serialized = serialize(fusion_network) deserialized = deserialize(serialized) @@ -60,28 +73,28 @@ def test_serialize_deserialize(fusion_network, multimodal_data): assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized) -def test_save_and_load(tmp_path, fusion_network, multimodal_data): +def test_save_and_load(tmp_path, fusion_network, data, multimodal): if fusion_network is None: pytest.skip(reason="Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) keras.saving.save_model(fusion_network, tmp_path / "model.keras") loaded = keras.saving.load_model(tmp_path / "model.keras") assert_layers_equal(fusion_network, loaded) - assert allclose(fusion_network(multimodal_data), loaded(multimodal_data)) + assert allclose(fusion_network(data), loaded(data)) @pytest.mark.parametrize("stage", ["training", "validation"]) -def test_compute_metrics(stage, fusion_network, multimodal_data): +def test_compute_metrics(stage, fusion_network, data, multimodal): if fusion_network is None: pytest.skip("Nothing to do, because there is no summary network.") - fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data)) + fusion_network.build(keras.tree.map_structure(keras.ops.shape, data)) - metrics = fusion_network.compute_metrics(multimodal_data, stage=stage) - outputs_via_call = fusion_network(multimodal_data, training=stage == "training") + metrics = fusion_network.compute_metrics(data, stage=stage) + outputs_via_call = fusion_network(data, training=stage == "training") assert "outputs" in metrics @@ -90,11 +103,9 @@ def test_compute_metrics(stage, fusion_network, multimodal_data): assert allclose(metrics["outputs"], outputs_via_call) # check that the batch dimension is preserved - assert ( - keras.ops.shape(metrics["outputs"])[0] - == keras.ops.shape(multimodal_data[next(iter(multimodal_data.keys()))])[0] - ) + batch_size = keras.ops.shape(data)[0] if not multimodal else keras.ops.shape(data[next(iter(data.keys()))])[0] + assert keras.ops.shape(metrics["outputs"])[0] == batch_size assert "loss" in metrics assert keras.ops.shape(metrics["loss"]) == ()