Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 83 additions & 28 deletions bayesflow/networks/fusion_network/fusion_network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Mapping
from collections.abc import Mapping, Sequence
from ..summary_network import SummaryNetwork
from bayesflow.utils.serialization import deserialize, serializable, serialize
from bayesflow.types import Tensor, Shape
Expand All @@ -10,23 +10,32 @@
class FusionNetwork(SummaryNetwork):
def __init__(
self,
backbones: Mapping[str, keras.Layer],
backbones: Sequence | Mapping[str, keras.Layer],
head: keras.Layer | None = None,
**kwargs,
):
"""(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from multi-modal data.
"""(SN) Wraps multiple summary networks (`backbones`) to learn summary statistics from (optionally)
multi-modal data.

Networks and inputs are passed as dictionaries with corresponding keys, so that each input is processed
by the correct summary network. This means the "summary_variables" entry to the approximator has to be
a dictionary, which can be achieved using the :py:meth:`bayesflow.adapters.Adapter.group` method.
There are two modes of operation:

- Identical input: each backbone receives the same input. The backbones have to be passed as a sequence.
- Multi-modal input: each backbone gets its own input, which is the usual case for multi-modal data. Networks
and inputs have to be passed as dictionaries with corresponding keys, so that each
input is processed by the correct summary network. This means the "summary_variables" entry to the
approximator has to be a dictionary, which can be achieved using the
:py:meth:`bayesflow.adapters.Adapter.group` method.

This network implements _late_ fusion. The output of the individual summary networks is concatenated, and
can be further processed by another neural network (`head`).

Parameters
----------
backbones : dict
A dictionary with names of inputs as keys and corresponding summary networks as values.
backbones : Sequence or dict
Either (see above for details):

- a sequence, when each backbone should receive the same input.
- a dictionary with names of inputs as keys and corresponding summary networks as values.
head : keras.Layer, optional
A network to further process the concatenated outputs of the summary networks. By default,
the concatenated outputs are returned without further processing.
Expand All @@ -37,25 +46,51 @@ def __init__(
super().__init__(**kwargs)
self.backbones = backbones
self.head = head
self._ordered_keys = sorted(list(self.backbones.keys()))
self._dict_mode = isinstance(backbones, Mapping)
if self._dict_mode:
# order keys to always concatenate in the same order
self._ordered_keys = sorted(list(self.backbones.keys()))

def build(self, inputs_shape: Mapping[str, Shape]):
def build(self, inputs_shape: Shape | Mapping[str, Shape]):
if self._dict_mode and not isinstance(inputs_shape, Mapping):
raise ValueError(
"`backbones` were passed as a dictionary, but the input shapes are not a dictionary. "
"If you want to pass the same input to each backbone, pass the backbones as a list instead of a "
"dictionary. If you want to provide each backbone with different input, please ensure that you have "
"correctly assembled the `summary_variables` to provide a dictionary using the Adapter.group method."
)
if self.built:
return
output_shapes = []
for k, shape in inputs_shape.items():
if not self.backbones[k].built:
self.backbones[k].build(shape)
output_shapes.append(self.backbones[k].compute_output_shape(shape))
if self._dict_mode:
missing_keys = list(set(inputs_shape.keys()).difference(set(self._ordered_keys)))
if len(missing_keys) > 0:
raise ValueError(
f"Expected the input to contain the following keys: {self._ordered_keys}. "
f"Missing keys: {missing_keys}"
)
for k, shape in inputs_shape.items():
# build each summary network with different input shape
if not self.backbones[k].built:
self.backbones[k].build(shape)
output_shapes.append(self.backbones[k].compute_output_shape(shape))
else:
for backbone in self.backbones:
# build all summary networks with the same input shape
if not backbone.built:
backbone.build(inputs_shape)
output_shapes.append(backbone.compute_output_shape(inputs_shape))
if self.head and not self.head.built:
fusion_input_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
self.head.build(fusion_input_shape)
self.built = True

def compute_output_shape(self, inputs_shape: Mapping[str, Shape]):
output_shapes = []
for k, shape in inputs_shape.items():
output_shapes.append(self.backbones[k].compute_output_shape(shape))
if self._dict_mode:
output_shapes = [self.backbones[k].compute_output_shape(shape) for k, shape in inputs_shape.items()]
else:
output_shapes = [backbone.compute_output_shape(inputs_shape) for backbone in self.backbones]
output_shape = (*output_shapes[0][:-1], sum(shape[-1] for shape in output_shapes))
if self.head:
output_shape = self.head.compute_output_shape(output_shape)
Expand All @@ -65,13 +100,20 @@ def call(self, inputs: Mapping[str, Tensor], training=False):
"""
Parameters
----------
inputs : dict[str, Tensor]
Each value in the dictionary is the input to the summary network with the corresponding key.
inputs : Tensor | dict[str, Tensor]
Either (see above for details):

- a tensor, when the backbones where passed as a list and should receive identical inputs
- a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
summary network with the corresponding key.
training : bool, optional
Whether the model is in training mode, affecting layers like dropout and
batch normalization. Default is False.
"""
outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
if self._dict_mode:
outputs = [self.backbones[k](inputs[k], training=training) for k in self._ordered_keys]
else:
outputs = [backbone(inputs, training=training) for backbone in self.backbones]
outputs = ops.concatenate(outputs, axis=-1)
if self.head is None:
return outputs
Expand All @@ -81,8 +123,12 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
"""
Parameters
----------
inputs : dict[str, Tensor]
Each value in the dictionary is the input to the summary network with the corresponding key.
inputs : Tensor | dict[str, Tensor]
Either (see above for details):

- a tensor, when the backbones where passed as a list and should receive identical inputs
- a dictionary, when the backbones were passed as a dictionary, where each value is the input to the
summary network with the corresponding key.
stage : bool, optional
Whether the model is in training mode, affecting layers like dropout and
batch normalization. Default is False.
Expand All @@ -93,14 +139,23 @@ def compute_metrics(self, inputs: Mapping[str, Tensor], stage: str = "training",
self.build(keras.tree.map_structure(keras.ops.shape, inputs))
metrics = {"loss": [], "outputs": []}

for k in self._ordered_keys:
if isinstance(self.backbones[k], SummaryNetwork):
metrics_k = self.backbones[k].compute_metrics(inputs[k], stage=stage, **kwargs)
metrics["outputs"].append(metrics_k["outputs"])
if "loss" in metrics_k:
metrics["loss"].append(metrics_k["loss"])
def process_backbone(backbone, input):
# helper function to avoid code duplication for the two modes
if isinstance(backbone, SummaryNetwork):
backbone_metrics = backbone.compute_metrics(input, stage=stage, **kwargs)
metrics["outputs"].append(backbone_metrics["outputs"])
if "loss" in backbone_metrics:
metrics["loss"].append(backbone_metrics["loss"])
else:
metrics["outputs"].append(self.backbones[k](inputs[k], training=stage == "training"))
metrics["outputs"].append(backbone(input, training=stage == "training"))

if self._dict_mode:
for k in self._ordered_keys:
process_backbone(self.backbones[k], inputs[k])
else:
for backbone in self.backbones:
process_backbone(backbone, inputs)

if len(metrics["loss"]) == 0:
del metrics["loss"]
else:
Expand Down
31 changes: 27 additions & 4 deletions tests/test_networks/test_fusion_network/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,40 @@
import pytest


@pytest.fixture(params=[True, False])
def multimodal(request):
return request.param


@pytest.fixture()
def multimodal_data(random_samples, random_set):
return {"x1": random_samples, "x2": random_set}
def data(random_samples, random_set, multimodal):
if multimodal:
return {"x1": random_samples, "x2": random_set}
return random_set


@pytest.fixture()
def fusion_network():
def fusion_network(multimodal):
from bayesflow.networks import FusionNetwork, DeepSet
import keras

deepset_kwargs = dict(
summary_dim=2,
mlp_widths_equivariant=(2, 2),
mlp_widths_invariant_inner=(2, 2),
mlp_widths_invariant_outer=(2, 2),
mlp_widths_invariant_last=(2, 2),
base_distribution="normal",
)
if multimodal:
return FusionNetwork(
backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(**deepset_kwargs)},
head=keras.layers.Dense(3),
)
return FusionNetwork(
backbones={"x1": keras.layers.Dense(3), "x2": DeepSet(summary_dim=2, base_distribution="normal")},
backbones=[
DeepSet(**deepset_kwargs),
DeepSet(**deepset_kwargs),
],
head=keras.layers.Dense(3),
)
55 changes: 33 additions & 22 deletions tests/test_networks/test_fusion_network/test_fusion_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,65 @@


@pytest.mark.parametrize("automatic", [True, False])
def test_build(automatic, fusion_network, multimodal_data):
def test_build(automatic, fusion_network, data, multimodal):
if fusion_network is None:
pytest.skip(reason="Nothing to do, because there is no summary network.")

assert fusion_network.built is False

if automatic:
fusion_network(multimodal_data)
fusion_network(data)
else:
fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))

assert fusion_network.built is True

# check the model has variables
assert fusion_network.variables, "Model has no variables."


def test_build_failure(fusion_network, data, multimodal):
if not multimodal:
pytest.skip(reason="Nothing to do, as summary networks may consume aribrary inputs")
with pytest.raises(ValueError):
fusion_network.build((3, 2, 2))
with pytest.raises(ValueError):
data["x3"] = data.pop("x1")
fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))


@pytest.mark.parametrize("automatic", [True, False])
def test_build_functional_api(automatic, fusion_network, multimodal_data):
def test_build_functional_api(automatic, fusion_network, data, multimodal):
if fusion_network is None:
pytest.skip(reason="Nothing to do, because there is no summary network.")

assert fusion_network.built is False

inputs = {}
for k, v in multimodal_data.items():
inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k)
if multimodal:
inputs = {}
for k, v in data.items():
inputs[k] = keras.layers.Input(shape=keras.ops.shape(v)[1:], name=k)
else:
inputs = keras.layers.Input(shape=keras.ops.shape(data)[1:])
outputs = fusion_network(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

if automatic:
model(multimodal_data)
model(data)
else:
model.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
model.build(keras.tree.map_structure(keras.ops.shape, data))

assert model.built is True

# check the model has variables
assert fusion_network.variables, "Model has no variables."


def test_serialize_deserialize(fusion_network, multimodal_data):
def test_serialize_deserialize(fusion_network, data, multimodal):
if fusion_network is None:
pytest.skip(reason="Nothing to do, because there is no summary network.")

fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))

serialized = serialize(fusion_network)
deserialized = deserialize(serialized)
Expand All @@ -60,28 +73,28 @@ def test_serialize_deserialize(fusion_network, multimodal_data):
assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)


def test_save_and_load(tmp_path, fusion_network, multimodal_data):
def test_save_and_load(tmp_path, fusion_network, data, multimodal):
if fusion_network is None:
pytest.skip(reason="Nothing to do, because there is no summary network.")

fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))

keras.saving.save_model(fusion_network, tmp_path / "model.keras")
loaded = keras.saving.load_model(tmp_path / "model.keras")

assert_layers_equal(fusion_network, loaded)
assert allclose(fusion_network(multimodal_data), loaded(multimodal_data))
assert allclose(fusion_network(data), loaded(data))


@pytest.mark.parametrize("stage", ["training", "validation"])
def test_compute_metrics(stage, fusion_network, multimodal_data):
def test_compute_metrics(stage, fusion_network, data, multimodal):
if fusion_network is None:
pytest.skip("Nothing to do, because there is no summary network.")

fusion_network.build(keras.tree.map_structure(keras.ops.shape, multimodal_data))
fusion_network.build(keras.tree.map_structure(keras.ops.shape, data))

metrics = fusion_network.compute_metrics(multimodal_data, stage=stage)
outputs_via_call = fusion_network(multimodal_data, training=stage == "training")
metrics = fusion_network.compute_metrics(data, stage=stage)
outputs_via_call = fusion_network(data, training=stage == "training")

assert "outputs" in metrics

Expand All @@ -90,11 +103,9 @@ def test_compute_metrics(stage, fusion_network, multimodal_data):
assert allclose(metrics["outputs"], outputs_via_call)

# check that the batch dimension is preserved
assert (
keras.ops.shape(metrics["outputs"])[0]
== keras.ops.shape(multimodal_data[next(iter(multimodal_data.keys()))])[0]
)
batch_size = keras.ops.shape(data)[0] if not multimodal else keras.ops.shape(data[next(iter(data.keys()))])[0]

assert keras.ops.shape(metrics["outputs"])[0] == batch_size
assert "loss" in metrics
assert keras.ops.shape(metrics["loss"]) == ()

Expand Down
Loading