From 5b15df66d78934762a22d633820dd93072f95eae Mon Sep 17 00:00:00 2001 From: han-ol Date: Tue, 22 Jul 2025 13:18:51 +0200 Subject: [PATCH 1/5] Support add_loss (works currently for torch and tf, does NOT for jax) --- bayesflow/approximators/continuous_approximator.py | 9 ++++++++- bayesflow/approximators/model_comparison_approximator.py | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index a5dbf12a3..bba566589 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -230,10 +230,17 @@ def compute_metrics( else: loss = inference_metrics.pop("loss") + if len(self.losses) > 0: + layer_loss = keras.ops.sum(self.losses) + loss += layer_loss + layer_loss_metrics = {"layer_loss": layer_loss} + else: + layer_loss_metrics = {} + inference_metrics = {f"{key}/inference_{key}": value for key, value in inference_metrics.items()} summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} - metrics = {"loss": loss} | inference_metrics | summary_metrics + metrics = {"loss": loss} | layer_loss_metrics | inference_metrics | summary_metrics return metrics def _compute_summary_metrics(self, summary_variables: Tensor | None, stage: str) -> tuple[dict, Tensor | None]: diff --git a/bayesflow/approximators/model_comparison_approximator.py b/bayesflow/approximators/model_comparison_approximator.py index d71aafdaf..7d446e313 100644 --- a/bayesflow/approximators/model_comparison_approximator.py +++ b/bayesflow/approximators/model_comparison_approximator.py @@ -233,10 +233,17 @@ def compute_metrics( else: loss = classifier_metrics.pop("loss") + if len(self.losses) > 0: + layer_loss = keras.ops.sum(self.losses) + loss += layer_loss + layer_loss_metrics = {"layer_loss": layer_loss} + else: + layer_loss_metrics = {} + classifier_metrics = {f"{key}/classifier_{key}": value for key, value in classifier_metrics.items()} summary_metrics = {f"{key}/summary_{key}": value for key, value in summary_metrics.items()} - metrics = {"loss": loss} | classifier_metrics | summary_metrics + metrics = {"loss": loss} | layer_loss_metrics | classifier_metrics | summary_metrics return metrics def fit( From 921514f4bf038db9acd65ce7b312cf81067044a3 Mon Sep 17 00:00:00 2001 From: han-ol Date: Tue, 22 Jul 2025 14:08:15 +0200 Subject: [PATCH 2/5] Test for add_loss support --- examples/Custom_losses_with_add_loss.ipynb | 148 +++++++++++++++++++++ tests/test_approximators/test_add_loss.py | 43 ++++++ 2 files changed, 191 insertions(+) create mode 100644 examples/Custom_losses_with_add_loss.ipynb create mode 100644 tests/test_approximators/test_add_loss.py diff --git a/examples/Custom_losses_with_add_loss.ipynb b/examples/Custom_losses_with_add_loss.ipynb new file mode 100644 index 000000000..7582b4ac6 --- /dev/null +++ b/examples/Custom_losses_with_add_loss.ipynb @@ -0,0 +1,148 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "c0545c7e-d9b0-4e1d-98b9-199afe1bcc31", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"jax\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a776c519-c0ac-4a14-8841-e3e64d2b1716", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:bayesflow:Using backend 'jax'\n", + "2025-07-22 13:16:17.567566: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1753182977.583159 574880 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1753182977.586855 574880 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "INFO:bayesflow:Fitting on dataset instance of OnlineDataset.\n", + "INFO:bayesflow:Building on a test batch.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/15\n", + "\u001b[1m 52/200\u001b[0m \u001b[32m━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━\u001b[0m \u001b[1m12s\u001b[0m 86ms/step - loss: 2.8519" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 18\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m 10\u001b[0m workflow \u001b[38;5;241m=\u001b[39m bf\u001b[38;5;241m.\u001b[39mBasicWorkflow(\n\u001b[1;32m 11\u001b[0m inference_network\u001b[38;5;241m=\u001b[39mbf\u001b[38;5;241m.\u001b[39mnetworks\u001b[38;5;241m.\u001b[39mCouplingFlow(),\n\u001b[1;32m 12\u001b[0m summary_network\u001b[38;5;241m=\u001b[39mCustomTimeSeriesNetwork(),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m simulator\u001b[38;5;241m=\u001b[39mbf\u001b[38;5;241m.\u001b[39msimulators\u001b[38;5;241m.\u001b[39mSIR()\n\u001b[1;32m 16\u001b[0m )\n\u001b[0;32m---> 18\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mworkflow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_online\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m15\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_batches_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m diagnostics \u001b[38;5;241m=\u001b[39m workflow\u001b[38;5;241m.\u001b[39mplot_default_diagnostics(test_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m300\u001b[39m)\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/workflows/basic_workflow.py:789\u001b[0m, in \u001b[0;36mBasicWorkflow.fit_online\u001b[0;34m(self, epochs, num_batches_per_epoch, batch_size, keep_optimizer, validation_data, augmentations, **kwargs)\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;124;03mTrain the approximator using an online data-generating process. The dataset is dynamically generated during\u001b[39;00m\n\u001b[1;32m 743\u001b[0m \u001b[38;5;124;03mtraining, making this approach suitable for scenarios where generating new simulations is computationally cheap.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;124;03m metric evolution over epochs.\u001b[39;00m\n\u001b[1;32m 779\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 781\u001b[0m dataset \u001b[38;5;241m=\u001b[39m OnlineDataset(\n\u001b[1;32m 782\u001b[0m simulator\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator,\n\u001b[1;32m 783\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 786\u001b[0m augmentations\u001b[38;5;241m=\u001b[39maugmentations,\n\u001b[1;32m 787\u001b[0m )\n\u001b[0;32m--> 789\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 790\u001b[0m \u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43monline\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_optimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 791\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/workflows/basic_workflow.py:964\u001b[0m, in \u001b[0;36mBasicWorkflow._fit\u001b[0;34m(self, dataset, epochs, strategy, keep_optimizer, validation_data, **kwargs)\u001b[0m\n\u001b[1;32m 961\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapproximator\u001b[38;5;241m.\u001b[39mcompile(optimizer\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer, metrics\u001b[38;5;241m=\u001b[39mkwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmetrics\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 963\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 964\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhistory \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapproximator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 965\u001b[0m \u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 966\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 967\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_training_finished()\n\u001b[1;32m 968\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhistory\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/continuous_approximator.py:322\u001b[0m, in \u001b[0;36mContinuousApproximator.fit\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 271\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;124;03m Trains the approximator on the provided dataset or on-demand data generated from the given simulator.\u001b[39;00m\n\u001b[1;32m 273\u001b[0m \u001b[38;5;124;03m If `dataset` is not provided, a dataset is built from the `simulator`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124;03m If both `dataset` and `simulator` are provided or neither is provided.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madapter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madapter\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/approximator.py:139\u001b[0m, in \u001b[0;36mApproximator.fit\u001b[0;34m(self, dataset, simulator, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m mock_data_shapes \u001b[38;5;241m=\u001b[39m keras\u001b[38;5;241m.\u001b[39mtree\u001b[38;5;241m.\u001b[39mmap_structure(keras\u001b[38;5;241m.\u001b[39mops\u001b[38;5;241m.\u001b[39mshape, mock_data)\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuild(mock_data_shapes)\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/backend_approximators/backend_approximator.py:20\u001b[0m, in \u001b[0;36mBackendApproximator.fit\u001b[0;34m(self, dataset, **kwargs)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m, dataset: keras\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mPyDataset, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfilter_kwargs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:117\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:418\u001b[0m, in \u001b[0;36mJAXTrainer.fit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq)\u001b[0m\n\u001b[1;32m 409\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_jax_state(\n\u001b[1;32m 410\u001b[0m trainable_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 411\u001b[0m non_trainable_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 414\u001b[0m purge_model_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 415\u001b[0m )\n\u001b[1;32m 416\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jax_state_synced \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 418\u001b[0m logs, state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 419\u001b[0m (\n\u001b[1;32m 420\u001b[0m trainable_variables,\n\u001b[1;32m 421\u001b[0m non_trainable_variables,\n\u001b[1;32m 422\u001b[0m optimizer_variables,\n\u001b[1;32m 423\u001b[0m metrics_variables,\n\u001b[1;32m 424\u001b[0m ) \u001b[38;5;241m=\u001b[39m state\n\u001b[1;32m 426\u001b[0m \u001b[38;5;66;03m# Setting _jax_state enables callbacks to force a state sync\u001b[39;00m\n\u001b[1;32m 427\u001b[0m \u001b[38;5;66;03m# if they need to.\u001b[39;00m\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:266\u001b[0m, in \u001b[0;36mJAXTrainer._make_function..iterator_step\u001b[0;34m(state, iterator)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21miterator_step\u001b[39m(state, iterator):\n\u001b[0;32m--> 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_function(state, \u001b[38;5;28mnext\u001b[39m(iterator))\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:1059\u001b[0m, in \u001b[0;36mJAXEpochIterator._prefetch_numpy_iterator\u001b[0;34m(self, numpy_iterator)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m queue:\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m queue\u001b[38;5;241m.\u001b[39mpopleft()\n\u001b[0;32m-> 1059\u001b[0m \u001b[43menqueue\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:1053\u001b[0m, in \u001b[0;36mJAXEpochIterator._prefetch_numpy_iterator..enqueue\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 1052\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21menqueue\u001b[39m(n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m):\n\u001b[0;32m-> 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mitertools\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mislice\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnumpy_iterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43mqueue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mappend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_distribute_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py:198\u001b[0m, in \u001b[0;36mget_jax_iterator\u001b[0;34m(iterable)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39masarray(x)\n\u001b[0;32m--> 198\u001b[0m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtree\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap_structure\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert_to_jax_compatible\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:248\u001b[0m, in \u001b[0;36mPyDatasetAdapter._finite_generator\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 245\u001b[0m random\u001b[38;5;241m.\u001b[39mshuffle(indices)\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m indices:\n\u001b[0;32m--> 248\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_standardize_batch(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpy_dataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m)\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/datasets/online_dataset.py:74\u001b[0m, in \u001b[0;36mOnlineDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, item: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 61\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;124;03m Generate one batch of data.\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m A batch of simulated (and optionally augmented/adapted) data.\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msimulator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maugmentations \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/decorators.py:63\u001b[0m, in \u001b[0;36malias..alias_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 60\u001b[0m matches \u001b[38;5;241m=\u001b[39m [name \u001b[38;5;28;01mfor\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m kwargs \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m aliases]\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m matches:\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(matches) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(matches) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m>\u001b[39m argpos):\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 67\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m() got multiple values for argument \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margname\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis argument is also aliased as \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maliases\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 69\u001b[0m )\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/decorators.py:95\u001b[0m, in \u001b[0;36margument_callback..callback_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 92\u001b[0m args \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(args)\n\u001b[1;32m 93\u001b[0m args[argpos] \u001b[38;5;241m=\u001b[39m callback(args[argpos])\n\u001b[0;32m---> 95\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/benchmark_simulator.py:27\u001b[0m, in \u001b[0;36mBenchmarkSimulator.sample\u001b[0;34m(self, batch_shape, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@allow_batch_size\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch_shape: Shape, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 13\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Runs simulated benchmark and returns `batch_size` parameter\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;124;03m and observation batches\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;124;03m with shapes (`batch_size`, ...)\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mbatched_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflatten\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m data \u001b[38;5;241m=\u001b[39m tree_stack(data, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, numpy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/functional.py:54\u001b[0m, in \u001b[0;36mbatched_call\u001b[0;34m(f, batch_shape, args, kwargs, map_predicate, flatten)\u001b[0m\n\u001b[1;32m 51\u001b[0m map_args \u001b[38;5;241m=\u001b[39m [arg[index] \u001b[38;5;28;01mif\u001b[39;00m map_predicate(arg) \u001b[38;5;28;01melse\u001b[39;00m arg \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args]\n\u001b[1;32m 52\u001b[0m map_kwargs \u001b[38;5;241m=\u001b[39m {key: value[index] \u001b[38;5;28;01mif\u001b[39;00m map_predicate(value) \u001b[38;5;28;01melse\u001b[39;00m value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 54\u001b[0m outputs[index] \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmap_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmap_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flatten:\n\u001b[1;32m 57\u001b[0m outputs \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mflatten()\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/benchmark_simulator.py:39\u001b[0m, in \u001b[0;36mBenchmarkSimulator.__call__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 38\u001b[0m prior_draws \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprior()\n\u001b[0;32m---> 39\u001b[0m observables \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobservation_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprior_draws\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mdict\u001b[39m(parameters\u001b[38;5;241m=\u001b[39mprior_draws\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32), observables\u001b[38;5;241m=\u001b[39mobservables\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32))\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/sir.py:107\u001b[0m, in \u001b[0;36mSIR.observation_model\u001b[0;34m(self, params)\u001b[0m\n\u001b[1;32m 104\u001b[0m t_vec \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mlinspace(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT)\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# Integrate using scipy and retain only infected (2-nd dimension)\u001b[39;00m\n\u001b[0;32m--> 107\u001b[0m irt \u001b[38;5;241m=\u001b[39m \u001b[43modeint\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_deriv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt_vec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgamma\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m[:, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 109\u001b[0m \u001b[38;5;66;03m# Subsample evenly the specified number of points, if specified\u001b[39;00m\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msubsample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/scipy/integrate/_odepack_py.py:247\u001b[0m, in \u001b[0;36modeint\u001b[0;34m(func, y0, t, args, Dfun, col_deriv, full_output, ml, mu, rtol, atol, tcrit, h0, hmax, hmin, ixpr, mxstep, mxhnil, mxordn, mxords, printmessg, tfirst)\u001b[0m\n\u001b[1;32m 245\u001b[0m t \u001b[38;5;241m=\u001b[39m copy(t)\n\u001b[1;32m 246\u001b[0m y0 \u001b[38;5;241m=\u001b[39m copy(y0)\n\u001b[0;32m--> 247\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43m_odepack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43modeint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcol_deriv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mml\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmu\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 248\u001b[0m \u001b[43m \u001b[49m\u001b[43mfull_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrtol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtcrit\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mh0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhmax\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhmin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[43m \u001b[49m\u001b[43mixpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxstep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxhnil\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxordn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxords\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 250\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtfirst\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 252\u001b[0m warning_msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m_msgs[output[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Run with full_output = 1 to \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 253\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mget quantitative information.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/sir.py:60\u001b[0m, in \u001b[0;36mSIR._deriv\u001b[0;34m(self, x, t, N, beta, gamma)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mdefault_rng()\n\u001b[0;32m---> 60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_deriv\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, t, N, beta, gamma):\n\u001b[1;32m 61\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Helper function for scipy.integrate.odeint.\"\"\"\u001b[39;00m\n\u001b[1;32m 63\u001b[0m s, i, r \u001b[38;5;241m=\u001b[39m x\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "import bayesflow as bf\n", + "import keras\n", + "\n", + "class CustomTimeSeriesNetwork(bf.networks.TimeSeriesNetwork):\n", + " def call(self, x, training=False, **kwargs):\n", + " x = super().call(x, training=training, **kwargs)\n", + " self.add_loss(keras.ops.sum(x**2))\n", + " return x\n", + "\n", + "workflow = bf.BasicWorkflow(\n", + " inference_network=bf.networks.CouplingFlow(),\n", + " summary_network=CustomTimeSeriesNetwork(),\n", + " inference_variables=[\"parameters\"],\n", + " summary_variables=[\"observables\"],\n", + " simulator=bf.simulators.SIR()\n", + ")\n", + "\n", + "history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)\n", + "\n", + "diagnostics = workflow.plot_default_diagnostics(test_data=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66900aa0-99e8-41ee-a08d-5e10b946deb9", + "metadata": {}, + "outputs": [], + "source": [ + "workflow.approximator.summary_network.losses" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d745dd2d-d5fc-4085-890a-4352886945ca", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_approximators/test_add_loss.py b/tests/test_approximators/test_add_loss.py new file mode 100644 index 000000000..833bb3723 --- /dev/null +++ b/tests/test_approximators/test_add_loss.py @@ -0,0 +1,43 @@ +import pytest +import keras +import io +from contextlib import redirect_stdout + + +@pytest.fixture() +def approximator_using_add_loss(adapter): + from bayesflow import ContinuousApproximator + from bayesflow.networks import CouplingFlow, MLP + + class MLPAddedLoss(MLP): + def call(self, x, training=False, **kwargs): + x = super().call(x, training=training, **kwargs) + self.add_loss(keras.ops.sum(x**2)) + return x + + return ContinuousApproximator( + adapter=adapter, + inference_network=CouplingFlow(subnet=MLPAddedLoss), + summary_network=None, + ) + + +def test_layer_loss_reported(approximator_using_add_loss, train_dataset, validation_dataset): + approximator = approximator_using_add_loss + approximator.compile(optimizer="AdamW") + num_epochs = 3 + + # Capture ostream and train model + with io.StringIO() as stream: + with redirect_stdout(stream): + approximator.fit(dataset=train_dataset, validation_data=validation_dataset, epochs=num_epochs) + + output = stream.getvalue() + + print(output) + + # check that there is a progress bar + assert "━" in output, "no progress bar" + + # check that layer_loss is reported + assert "layer_loss" in output, "no layer_loss" From e59b30b35f09db5ec6e0b5d20926326ac3a1d6ad Mon Sep 17 00:00:00 2001 From: han-ol Date: Tue, 22 Jul 2025 14:09:13 +0200 Subject: [PATCH 3/5] Example notebook regularizing the summary space using layer.add_loss --- examples/Custom_losses_with_add_loss.ipynb | 99 +++++++++++----------- 1 file changed, 51 insertions(+), 48 deletions(-) diff --git a/examples/Custom_losses_with_add_loss.ipynb b/examples/Custom_losses_with_add_loss.ipynb index 7582b4ac6..3a4084ff9 100644 --- a/examples/Custom_losses_with_add_loss.ipynb +++ b/examples/Custom_losses_with_add_loss.ipynb @@ -12,7 +12,7 @@ "\n", "import os\n", "\n", - "os.environ[\"KERAS_BACKEND\"] = \"jax\"" + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"" ] }, { @@ -25,11 +25,16 @@ "name": "stderr", "output_type": "stream", "text": [ - "INFO:bayesflow:Using backend 'jax'\n", - "2025-07-22 13:16:17.567566: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2025-07-22 14:07:15.076444: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2025-07-22 14:07:15.079630: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2025-07-22 14:07:15.087697: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1753182977.583159 574880 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1753182977.586855 574880 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "E0000 00:00:1753186035.101632 583449 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1753186035.105502 583449 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2025-07-22 14:07:15.121221: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2025-07-22 14:07:17.542250: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)\n", + "INFO:bayesflow:Using backend 'tensorflow'\n", "INFO:bayesflow:Fitting on dataset instance of OnlineDataset.\n", "INFO:bayesflow:Building on a test batch.\n" ] @@ -38,40 +43,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 1/15\n", - "\u001b[1m 52/200\u001b[0m \u001b[32m━━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━\u001b[0m \u001b[1m12s\u001b[0m 86ms/step - loss: 2.8519" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 18\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n\u001b[1;32m 10\u001b[0m workflow \u001b[38;5;241m=\u001b[39m bf\u001b[38;5;241m.\u001b[39mBasicWorkflow(\n\u001b[1;32m 11\u001b[0m inference_network\u001b[38;5;241m=\u001b[39mbf\u001b[38;5;241m.\u001b[39mnetworks\u001b[38;5;241m.\u001b[39mCouplingFlow(),\n\u001b[1;32m 12\u001b[0m summary_network\u001b[38;5;241m=\u001b[39mCustomTimeSeriesNetwork(),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m simulator\u001b[38;5;241m=\u001b[39mbf\u001b[38;5;241m.\u001b[39msimulators\u001b[38;5;241m.\u001b[39mSIR()\n\u001b[1;32m 16\u001b[0m )\n\u001b[0;32m---> 18\u001b[0m history \u001b[38;5;241m=\u001b[39m \u001b[43mworkflow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_online\u001b[49m\u001b[43m(\u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m15\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_batches_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 20\u001b[0m diagnostics \u001b[38;5;241m=\u001b[39m workflow\u001b[38;5;241m.\u001b[39mplot_default_diagnostics(test_data\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m300\u001b[39m)\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/workflows/basic_workflow.py:789\u001b[0m, in \u001b[0;36mBasicWorkflow.fit_online\u001b[0;34m(self, epochs, num_batches_per_epoch, batch_size, keep_optimizer, validation_data, augmentations, **kwargs)\u001b[0m\n\u001b[1;32m 741\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;124;03mTrain the approximator using an online data-generating process. The dataset is dynamically generated during\u001b[39;00m\n\u001b[1;32m 743\u001b[0m \u001b[38;5;124;03mtraining, making this approach suitable for scenarios where generating new simulations is computationally cheap.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 778\u001b[0m \u001b[38;5;124;03m metric evolution over epochs.\u001b[39;00m\n\u001b[1;32m 779\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 781\u001b[0m dataset \u001b[38;5;241m=\u001b[39m OnlineDataset(\n\u001b[1;32m 782\u001b[0m simulator\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msimulator,\n\u001b[1;32m 783\u001b[0m batch_size\u001b[38;5;241m=\u001b[39mbatch_size,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 786\u001b[0m augmentations\u001b[38;5;241m=\u001b[39maugmentations,\n\u001b[1;32m 787\u001b[0m )\n\u001b[0;32m--> 789\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 790\u001b[0m \u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43monline\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeep_optimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 791\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/workflows/basic_workflow.py:964\u001b[0m, in \u001b[0;36mBasicWorkflow._fit\u001b[0;34m(self, dataset, epochs, strategy, keep_optimizer, validation_data, **kwargs)\u001b[0m\n\u001b[1;32m 961\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mapproximator\u001b[38;5;241m.\u001b[39mcompile(optimizer\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptimizer, metrics\u001b[38;5;241m=\u001b[39mkwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmetrics\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m))\n\u001b[1;32m 963\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 964\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhistory \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapproximator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 965\u001b[0m \u001b[43m \u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 966\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 967\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_training_finished()\n\u001b[1;32m 968\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhistory\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/continuous_approximator.py:322\u001b[0m, in \u001b[0;36mContinuousApproximator.fit\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 271\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;124;03m Trains the approximator on the provided dataset or on-demand data generated from the given simulator.\u001b[39;00m\n\u001b[1;32m 273\u001b[0m \u001b[38;5;124;03m If `dataset` is not provided, a dataset is built from the `simulator`.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;124;03m If both `dataset` and `simulator` are provided or neither is provided.\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 322\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43madapter\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madapter\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/approximator.py:139\u001b[0m, in \u001b[0;36mApproximator.fit\u001b[0;34m(self, dataset, simulator, **kwargs)\u001b[0m\n\u001b[1;32m 136\u001b[0m mock_data_shapes \u001b[38;5;241m=\u001b[39m keras\u001b[38;5;241m.\u001b[39mtree\u001b[38;5;241m.\u001b[39mmap_structure(keras\u001b[38;5;241m.\u001b[39mops\u001b[38;5;241m.\u001b[39mshape, mock_data)\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuild(mock_data_shapes)\n\u001b[0;32m--> 139\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdataset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/approximators/backend_approximators/backend_approximator.py:20\u001b[0m, in \u001b[0;36mBackendApproximator.fit\u001b[0;34m(self, dataset, **kwargs)\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m, dataset: keras\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mPyDataset, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m---> 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfilter_kwargs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py:117\u001b[0m, in \u001b[0;36mfilter_traceback..error_handler\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 115\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 116\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 117\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 119\u001b[0m filtered_tb \u001b[38;5;241m=\u001b[39m _process_traceback_frames(e\u001b[38;5;241m.\u001b[39m__traceback__)\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:418\u001b[0m, in \u001b[0;36mJAXTrainer.fit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq)\u001b[0m\n\u001b[1;32m 409\u001b[0m state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_jax_state(\n\u001b[1;32m 410\u001b[0m trainable_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 411\u001b[0m non_trainable_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 414\u001b[0m purge_model_variables\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 415\u001b[0m )\n\u001b[1;32m 416\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_jax_state_synced \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m--> 418\u001b[0m logs, state \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 419\u001b[0m (\n\u001b[1;32m 420\u001b[0m trainable_variables,\n\u001b[1;32m 421\u001b[0m non_trainable_variables,\n\u001b[1;32m 422\u001b[0m optimizer_variables,\n\u001b[1;32m 423\u001b[0m metrics_variables,\n\u001b[1;32m 424\u001b[0m ) \u001b[38;5;241m=\u001b[39m state\n\u001b[1;32m 426\u001b[0m \u001b[38;5;66;03m# Setting _jax_state enables callbacks to force a state sync\u001b[39;00m\n\u001b[1;32m 427\u001b[0m \u001b[38;5;66;03m# if they need to.\u001b[39;00m\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:266\u001b[0m, in \u001b[0;36mJAXTrainer._make_function..iterator_step\u001b[0;34m(state, iterator)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21miterator_step\u001b[39m(state, iterator):\n\u001b[0;32m--> 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_function(state, \u001b[38;5;28mnext\u001b[39m(iterator))\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:1059\u001b[0m, in \u001b[0;36mJAXEpochIterator._prefetch_numpy_iterator\u001b[0;34m(self, numpy_iterator)\u001b[0m\n\u001b[1;32m 1057\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m queue:\n\u001b[1;32m 1058\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m queue\u001b[38;5;241m.\u001b[39mpopleft()\n\u001b[0;32m-> 1059\u001b[0m \u001b[43menqueue\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/backend/jax/trainer.py:1053\u001b[0m, in \u001b[0;36mJAXEpochIterator._prefetch_numpy_iterator..enqueue\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 1052\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21menqueue\u001b[39m(n\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m):\n\u001b[0;32m-> 1053\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mitertools\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mislice\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnumpy_iterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43mqueue\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mappend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_distribute_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py:198\u001b[0m, in \u001b[0;36mget_jax_iterator\u001b[0;34m(iterable)\u001b[0m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 196\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39masarray(x)\n\u001b[0;32m--> 198\u001b[0m \u001b[43m\u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43miterable\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01myield\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtree\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap_structure\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert_to_jax_compatible\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py:248\u001b[0m, in \u001b[0;36mPyDatasetAdapter._finite_generator\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 245\u001b[0m random\u001b[38;5;241m.\u001b[39mshuffle(indices)\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m indices:\n\u001b[0;32m--> 248\u001b[0m \u001b[38;5;28;01myield\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_standardize_batch(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpy_dataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m)\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/datasets/online_dataset.py:74\u001b[0m, in \u001b[0;36mOnlineDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, item: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 61\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;124;03m Generate one batch of data.\u001b[39;00m\n\u001b[1;32m 63\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;124;03m A batch of simulated (and optionally augmented/adapted) data.\u001b[39;00m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 74\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msimulator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maugmentations \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 77\u001b[0m \u001b[38;5;28;01mpass\u001b[39;00m\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/decorators.py:63\u001b[0m, in \u001b[0;36malias..alias_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 60\u001b[0m matches \u001b[38;5;241m=\u001b[39m [name \u001b[38;5;28;01mfor\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m kwargs \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m aliases]\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m matches:\n\u001b[0;32m---> 63\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 65\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(matches) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m (\u001b[38;5;28mlen\u001b[39m(matches) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(args) \u001b[38;5;241m>\u001b[39m argpos):\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[1;32m 67\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m() got multiple values for argument \u001b[39m\u001b[38;5;132;01m{\u001b[39;00margname\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThis argument is also aliased as \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maliases\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 69\u001b[0m )\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/decorators.py:95\u001b[0m, in \u001b[0;36margument_callback..callback_wrapper..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 92\u001b[0m args \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(args)\n\u001b[1;32m 93\u001b[0m args[argpos] \u001b[38;5;241m=\u001b[39m callback(args[argpos])\n\u001b[0;32m---> 95\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/benchmark_simulator.py:27\u001b[0m, in \u001b[0;36mBenchmarkSimulator.sample\u001b[0;34m(self, batch_shape, **kwargs)\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;129m@allow_batch_size\u001b[39m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msample\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch_shape: Shape, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 13\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Runs simulated benchmark and returns `batch_size` parameter\u001b[39;00m\n\u001b[1;32m 14\u001b[0m \u001b[38;5;124;03m and observation batches\u001b[39;00m\n\u001b[1;32m 15\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[38;5;124;03m with shapes (`batch_size`, ...)\u001b[39;00m\n\u001b[1;32m 25\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 27\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mbatched_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflatten\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 28\u001b[0m data \u001b[38;5;241m=\u001b[39m tree_stack(data, axis\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, numpy\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 29\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m data\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/utils/functional.py:54\u001b[0m, in \u001b[0;36mbatched_call\u001b[0;34m(f, batch_shape, args, kwargs, map_predicate, flatten)\u001b[0m\n\u001b[1;32m 51\u001b[0m map_args \u001b[38;5;241m=\u001b[39m [arg[index] \u001b[38;5;28;01mif\u001b[39;00m map_predicate(arg) \u001b[38;5;28;01melse\u001b[39;00m arg \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args]\n\u001b[1;32m 52\u001b[0m map_kwargs \u001b[38;5;241m=\u001b[39m {key: value[index] \u001b[38;5;28;01mif\u001b[39;00m map_predicate(value) \u001b[38;5;28;01melse\u001b[39;00m value \u001b[38;5;28;01mfor\u001b[39;00m key, value \u001b[38;5;129;01min\u001b[39;00m kwargs\u001b[38;5;241m.\u001b[39mitems()}\n\u001b[0;32m---> 54\u001b[0m outputs[index] \u001b[38;5;241m=\u001b[39m \u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmap_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmap_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 56\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m flatten:\n\u001b[1;32m 57\u001b[0m outputs \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mflatten()\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/benchmark_simulator.py:39\u001b[0m, in \u001b[0;36mBenchmarkSimulator.__call__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mstr\u001b[39m, np\u001b[38;5;241m.\u001b[39mndarray]:\n\u001b[1;32m 38\u001b[0m prior_draws \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprior()\n\u001b[0;32m---> 39\u001b[0m observables \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mobservation_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprior_draws\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mdict\u001b[39m(parameters\u001b[38;5;241m=\u001b[39mprior_draws\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32), observables\u001b[38;5;241m=\u001b[39mobservables\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32))\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/sir.py:107\u001b[0m, in \u001b[0;36mSIR.observation_model\u001b[0;34m(self, params)\u001b[0m\n\u001b[1;32m 104\u001b[0m t_vec \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mlinspace(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT)\n\u001b[1;32m 106\u001b[0m \u001b[38;5;66;03m# Integrate using scipy and retain only infected (2-nd dimension)\u001b[39;00m\n\u001b[0;32m--> 107\u001b[0m irt \u001b[38;5;241m=\u001b[39m \u001b[43modeint\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_deriv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt_vec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgamma\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m[:, \u001b[38;5;241m1\u001b[39m]\n\u001b[1;32m 109\u001b[0m \u001b[38;5;66;03m# Subsample evenly the specified number of points, if specified\u001b[39;00m\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msubsample \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/programs/anaconda3/envs/bayesflow/lib/python3.11/site-packages/scipy/integrate/_odepack_py.py:247\u001b[0m, in \u001b[0;36modeint\u001b[0;34m(func, y0, t, args, Dfun, col_deriv, full_output, ml, mu, rtol, atol, tcrit, h0, hmax, hmin, ixpr, mxstep, mxhnil, mxordn, mxords, printmessg, tfirst)\u001b[0m\n\u001b[1;32m 245\u001b[0m t \u001b[38;5;241m=\u001b[39m copy(t)\n\u001b[1;32m 246\u001b[0m y0 \u001b[38;5;241m=\u001b[39m copy(y0)\n\u001b[0;32m--> 247\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43m_odepack\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43modeint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfunc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mDfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcol_deriv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mml\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmu\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 248\u001b[0m \u001b[43m \u001b[49m\u001b[43mfull_output\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrtol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtcrit\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mh0\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhmax\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhmin\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 249\u001b[0m \u001b[43m \u001b[49m\u001b[43mixpr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxstep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxhnil\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxordn\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmxords\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 250\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtfirst\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m output[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 252\u001b[0m warning_msg \u001b[38;5;241m=\u001b[39m (\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m_msgs[output[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m Run with full_output = 1 to \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 253\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mget quantitative information.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/code/bayesflow/bayesflow/simulators/benchmark_simulators/sir.py:60\u001b[0m, in \u001b[0;36mSIR._deriv\u001b[0;34m(self, x, t, N, beta, gamma)\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 58\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mdefault_rng()\n\u001b[0;32m---> 60\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_deriv\u001b[39m(\u001b[38;5;28mself\u001b[39m, x, t, N, beta, gamma):\n\u001b[1;32m 61\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Helper function for scipy.integrate.odeint.\"\"\"\u001b[39;00m\n\u001b[1;32m 63\u001b[0m s, i, r \u001b[38;5;241m=\u001b[39m x\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "Epoch 1/5\n", + "\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m30s\u001b[0m 64ms/step - layer_loss: 4.1039e-04 - loss: 2.8483\n", + "Epoch 2/5\n", + "\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.0456 - loss: 2.5975 \n", + "Epoch 3/5\n", + "\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2125 - loss: 0.8778\n", + "Epoch 4/5\n", + "\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2350 - loss: 0.3789\n", + "Epoch 5/5\n", + "\u001b[1m200/200\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m13s\u001b[0m 66ms/step - layer_loss: 0.2238 - loss: 0.1804\n" ] } ], @@ -93,28 +74,50 @@ " simulator=bf.simulators.SIR()\n", ")\n", "\n", - "history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)\n", - "\n", - "diagnostics = workflow.plot_default_diagnostics(test_data=300)" + "history = workflow.fit_online(epochs=5, batch_size=32, num_batches_per_epoch=200)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "66900aa0-99e8-41ee-a08d-5e10b946deb9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "workflow.approximator.summary_network.losses" ] }, { "cell_type": "code", - "execution_count": null, - "id": "d745dd2d-d5fc-4085-890a-4352886945ca", + "execution_count": 5, + "id": "45e835af-8d1e-4d63-a349-f98e68a02667", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "workflow.approximator.losses" + ] } ], "metadata": { From 1e88ba018326c697aace3a22214ee5facfd55d93 Mon Sep 17 00:00:00 2001 From: han-ol Date: Tue, 22 Jul 2025 14:42:02 +0200 Subject: [PATCH 4/5] Skip add_loss test for jax backend --- tests/test_approximators/test_add_loss.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_approximators/test_add_loss.py b/tests/test_approximators/test_add_loss.py index 833bb3723..e4d56a893 100644 --- a/tests/test_approximators/test_add_loss.py +++ b/tests/test_approximators/test_add_loss.py @@ -23,6 +23,11 @@ def call(self, x, training=False, **kwargs): def test_layer_loss_reported(approximator_using_add_loss, train_dataset, validation_dataset): + from bayesflow.approximators.backend_approximators.jax_approximator import JAXApproximator + + if isinstance(approximator_using_add_loss, JAXApproximator): + pytest.skip(reason="With JAX backend, the compute_metrics method currently fails to consider self.losses.") + approximator = approximator_using_add_loss approximator.compile(optimizer="AdamW") num_epochs = 3 From 4877f6b3a31876ddbdf8469a7a381cf9a57b5f10 Mon Sep 17 00:00:00 2001 From: han-ol Date: Tue, 22 Jul 2025 14:49:28 +0200 Subject: [PATCH 5/5] use os.environ to find out whether jax is used --- tests/test_approximators/test_add_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_approximators/test_add_loss.py b/tests/test_approximators/test_add_loss.py index e4d56a893..daaed236c 100644 --- a/tests/test_approximators/test_add_loss.py +++ b/tests/test_approximators/test_add_loss.py @@ -23,9 +23,9 @@ def call(self, x, training=False, **kwargs): def test_layer_loss_reported(approximator_using_add_loss, train_dataset, validation_dataset): - from bayesflow.approximators.backend_approximators.jax_approximator import JAXApproximator + import os - if isinstance(approximator_using_add_loss, JAXApproximator): + if os.environ["KERAS_BACKEND"] == "jax": pytest.skip(reason="With JAX backend, the compute_metrics method currently fails to consider self.losses.") approximator = approximator_using_add_loss