11import pytest
2- from typing import List
3- from data_stack .mnist .factory import MNISTFactory
42from data_stack .io .storage_connectors import StorageConnector , StorageConnectorFactory
53from data_stack .dataset .reporting import DatasetIteratorReportGenerator
64import tempfile
75import shutil
8- from data_stack .dataset .iterator import InformedDatasetIterator
9- from data_stack .dataset .meta import MetaFactory
106from data_stack .dataset .factory import InformedDatasetFactory
7+ from data_stack .dataset .meta import DatasetMeta , MetaFactory
8+ from data_stack .dataset .iterator import DatasetIteratorIF , SequenceDatasetIterator , InformedDatasetIterator
119
1210
1311class TestReporting :
@@ -22,36 +20,38 @@ def tmp_folder_path(self) -> str:
2220 def storage_connector (self , tmp_folder_path : str ) -> StorageConnector :
2321 return StorageConnectorFactory .get_file_storage_connector (tmp_folder_path )
2422
25- @pytest .fixture (scope = "session" )
26- def mnist_factory (self , storage_connector ) -> List [int ]:
27- mnist_factory = MNISTFactory (storage_connector )
28- return mnist_factory
29-
30- def test_plain_iterator_reporting (self , mnist_factory ):
31- iterator , iterator_meta = mnist_factory .get_dataset_iterator (config = {"split" : "train" })
32- dataset_meta = MetaFactory .get_dataset_meta (identifier = "id x" , dataset_name = "MNIST" ,
33- dataset_tag = "train" , iterator_meta = iterator_meta )
34-
35- informed_iterator = InformedDatasetIterator (iterator , dataset_meta )
36- report = DatasetIteratorReportGenerator .generate_report (informed_iterator )
23+ # @pytest.fixture(scope="session")
24+ # def mnist_factory(self, storage_connector) -> List[int]:
25+ # mnist_factory = MNISTFactory(storage_connector)
26+ # return mnist_factory
27+
28+ @pytest .fixture
29+ def dataset_meta (self ) -> DatasetMeta :
30+ iterator_meta = MetaFactory .get_iterator_meta (sample_pos = 0 , target_pos = 1 , tag_pos = 2 )
31+ return MetaFactory .get_dataset_meta (identifier = "identifier_1" ,
32+ dataset_name = "TEST DATASET" ,
33+ dataset_tag = "train" ,
34+ iterator_meta = iterator_meta )
35+
36+ @pytest .fixture
37+ def dataset_iterator (self ) -> DatasetIteratorIF :
38+ targets = [j for i in range (10 ) for j in range (9 )] + [10 ]* 1000
39+ samples = [0 ]* len (targets )
40+ return SequenceDatasetIterator (dataset_sequences = [samples , targets ])
41+
42+ @pytest .fixture
43+ def informed_dataset_iterator (self , dataset_iterator , dataset_meta ) -> DatasetIteratorIF :
44+ return InformedDatasetFactory .get_dataset_iterator (dataset_iterator , dataset_meta )
45+
46+ def test_plain_iterator_reporting (self , informed_dataset_iterator ):
47+ report = DatasetIteratorReportGenerator .generate_report (informed_dataset_iterator )
3748 print (report )
38- assert report .length == 60000 and not report .sub_reports
39-
40- def test_combined_iterator_reporting (self , mnist_factory ):
41-
42- iterator_train , iterator_train_meta = mnist_factory .get_dataset_iterator (config = {"split" : "train" })
43- iterator_test , iterator_test_meta = mnist_factory .get_dataset_iterator (config = {"split" : "test" })
44- meta_train = MetaFactory .get_dataset_meta (identifier = "id x" , dataset_name = "MNIST" ,
45- dataset_tag = "train" , iterator_meta = iterator_train_meta )
46- meta_test = MetaFactory .get_dataset_meta (identifier = "id x" , dataset_name = "MNIST" ,
47- dataset_tag = "train" , iterator_meta = iterator_test_meta )
48-
49- informed_iterator_train = InformedDatasetFactory .get_dataset_iterator (iterator_train , meta_train )
50- informed_iterator_test = InformedDatasetFactory .get_dataset_iterator (iterator_test , meta_test )
51-
52- meta_combined = MetaFactory .get_dataset_meta_from_existing (informed_iterator_train .dataset_meta , dataset_tag = "full" )
49+ assert report .length == 1090 and not report .sub_reports
5350
54- iterator = InformedDatasetFactory .get_combined_dataset_iterator ([informed_iterator_train , informed_iterator_test ], meta_combined )
51+ def test_combined_iterator_reporting (self , informed_dataset_iterator ):
52+ meta_combined = MetaFactory .get_dataset_meta_from_existing (informed_dataset_iterator .dataset_meta , dataset_tag = "full" )
53+ iterator = InformedDatasetFactory .get_combined_dataset_iterator (
54+ [informed_dataset_iterator , informed_dataset_iterator ], meta_combined )
5555 report = DatasetIteratorReportGenerator .generate_report (iterator )
56- assert report .length == 70000 and report .sub_reports [0 ].length == 60000 and report .sub_reports [1 ].length == 10000
56+ assert report .length == 2180 and report .sub_reports [0 ].length == 1090 and report .sub_reports [1 ].length == 1090
5757 assert not report .sub_reports [0 ].sub_reports and not report .sub_reports [1 ].sub_reports
0 commit comments