Description
🐛 Describe the bug
I am using benchmark_with_validation_stream to derive a benchmark whose original train_stream is split into a train_stream and a valid_stream. However, training fails as soon as I pass eval_streams=[cl_val_stream] to the strategy's train() call.
Error: AttributeError: 'DatasetExperience' object has no attribute 'benchmark'
When inspecting the streams after calling benchmark_with_validation_stream, the experiences in train_stream and valid_stream no longer have the benchmark attribute, whereas those in test_stream still have it.
I think the cause is that the streams generated by benchmark_with_validation_stream (train_stream and valid_stream) are EagerCLStream instances, while test_stream is still an NCStream.
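The inspection described above can be reproduced with a small helper. This is only a sketch: missing_benchmark_report is a hypothetical name, not part of Avalanche, and the stand-in objects below merely imitate experiences with and without a benchmark attribute so that the snippet runs without the library installed.

```python
from types import SimpleNamespace


def missing_benchmark_report(streams):
    """Map each stream name to the indices of experiences lacking `.benchmark`."""
    report = {}
    for name, stream in streams.items():
        report[name] = [
            i for i, exp in enumerate(stream) if not hasattr(exp, "benchmark")
        ]
    return report


# Stand-in data (hypothetical, NOT Avalanche classes): plain lists play the
# role of streams, and SimpleNamespace objects play the role of experiences.
bench = object()
fake_streams = {
    "train": [SimpleNamespace(), SimpleNamespace()],
    "valid": [SimpleNamespace()],
    "test": [SimpleNamespace(benchmark=bench), SimpleNamespace(benchmark=bench)],
}

print(missing_benchmark_report(fake_streams))
# -> {'train': [0, 1], 'valid': [0], 'test': []}
```

With the real benchmark, the same helper could be called with {"train": cl_mnist_with_val.train_stream, "valid": cl_mnist_with_val.valid_stream, "test": cl_mnist_with_val.test_stream}. Given the behavior reported here, the train and valid entries should list every experience, and the test entry should be empty.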
🐜 To Reproduce
cl_mnist = SplitMNIST(
    n_experiences=5,
    return_task_id=False,
    seed=42,
    fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
)

cl_mnist_with_val = benchmark_with_validation_stream(
    cl_mnist,
    validation_size=0.1,
    shuffle=True,
    seed=42
)

cl_train_stream = cl_mnist_with_val.train_stream
cl_test_stream = cl_mnist_with_val.test_stream
cl_val_stream = cl_mnist_with_val.valid_stream

baseline_model = SimpleMLP(num_classes=10)
baseline_optimizer = SGD(baseline_model.parameters(), lr=0.001, momentum=0.9)
baseline_criterion = CrossEntropyLoss()

baseline_naive_strategy = Naive(
    model=baseline_model,
    optimizer=baseline_optimizer,
    criterion=baseline_criterion,
    train_mb_size=64,
    train_epochs=5,
    eval_mb_size=64,
    eval_every=0,
    evaluator=baseline_eval_plugin
)

baseline_results = []
for exp in cl_train_stream:
    res = baseline_naive_strategy.train(exp, eval_streams=[cl_val_stream])  # <- Error is happening here
    baseline_results.append(res)
🐝 Expected behavior
I should be able to train the model on the train set, validate it on the validation set during training, and run inference on the test set afterwards. However, since the experiences in cl_val_stream lack the benchmark attribute, training fails when eval_streams=[cl_val_stream] is used.