1414@pytest .mark .parametrize ("drd" , [0 , 2 ])
1515@pytest .mark .parametrize ("uf" , [1 , 2 ])
1616@pytest .mark .parametrize ("ub" , [1 , 2 ])
17- def test_forward_nt (nt , mwd , mrd , dwd , drd , uf , ub , singlefile ):
17+ @pytest .mark .parametrize ("block_size" , [1 , 5 , 10 ])
18+ def test_forward_nt (nt , mwd , mrd , dwd , drd , uf , ub , singlefile , block_size ):
1819 nx = 10
1920 ny = 10
20- df = np .zeros ([nx , ny ])
21+ df = np .zeros ([block_size , nx , ny ])
2122 db = np .zeros ([nx , ny ])
2223 cp = IncrementCheckpoint ([df , db ])
2324 f = IncOperator (1 , df )
@@ -28,7 +29,8 @@ def test_forward_nt(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
2829 cp .size , nt , cp .dtype , filedir = "./" , singlefile = singlefile , wd = dwd , rd = drd
2930 )
3031 st_list = [npStorage , dkStorage ]
31- rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub )
32+ rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub ,
33+ block_size = block_size )
3234 assert f .counter == 0
3335 rev .apply_forward ()
3436 assert f .counter == nt
@@ -42,10 +44,11 @@ def test_forward_nt(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
4244@pytest .mark .parametrize ("drd" , [0 , 2 ])
4345@pytest .mark .parametrize ("uf" , [1 , 2 ])
4446@pytest .mark .parametrize ("ub" , [1 , 2 ])
45- def test_reverse_nt (nt , mwd , mrd , dwd , drd , uf , ub , singlefile ):
47+ @pytest .mark .parametrize ("block_size" , [1 , 5 , 10 ])
48+ def test_reverse_nt (nt , mwd , mrd , dwd , drd , uf , ub , singlefile , block_size ):
4649 nx = 10
4750 ny = 10
48- df = np .zeros ([nx , ny ])
51+ df = np .zeros ([block_size , nx , ny ])
4952 db = np .zeros ([nx , ny ])
5053 cp = IncrementCheckpoint ([df ])
5154 f = IncOperator (1 , df )
@@ -56,7 +59,8 @@ def test_reverse_nt(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
5659 cp .size , nt , cp .dtype , filedir = "./" , singlefile = singlefile , wd = dwd , rd = drd
5760 )
5861 st_list = [npStorage , dkStorage ]
59- rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub )
62+ rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub ,
63+ block_size = block_size )
6064
6165 rev .apply_forward ()
6266 assert f .counter == nt
@@ -73,7 +77,8 @@ def test_reverse_nt(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
7377@pytest .mark .parametrize ("drd" , [0 , 2 ])
7478@pytest .mark .parametrize ("uf" , [1 , 2 ])
7579@pytest .mark .parametrize ("ub" , [1 , 2 ])
76- def test_num_loads_and_saves (nt , mwd , mrd , dwd , drd , uf , ub , singlefile ):
80+ @pytest .mark .parametrize ("block_size" , [1 , 5 , 10 ])
81+ def test_num_loads_and_saves (nt , mwd , mrd , dwd , drd , uf , ub , singlefile , block_size ):
7782 cp = SimpleCheckpoint ()
7883 f = SimpleOperator ()
7984 b = SimpleOperator ()
@@ -83,7 +88,8 @@ def test_num_loads_and_saves(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
8388 cp .size , nt , cp .dtype , filedir = "./" , singlefile = singlefile , wd = dwd , rd = drd
8489 )
8590 st_list = [npStorage , dkStorage ]
86- rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub )
91+ rev = MultiLevelRevolver (cp , f , b , nt , storage_list = st_list , uf = uf , ub = ub ,
92+ block_size = block_size )
8793
8894 rev .apply_forward ()
8995 assert cp .load_counter == 0
@@ -99,20 +105,21 @@ def test_num_loads_and_saves(nt, mwd, mrd, dwd, drd, uf, ub, singlefile):
99105@pytest .mark .parametrize ("drd" , [0 , 2 ])
100106@pytest .mark .parametrize ("uf" , [1 , 2 ])
101107@pytest .mark .parametrize ("ub" , [1 , 2 ])
102- def test_multi_and_single_outputs (nt , mwd , mrd , dwd , drd , uf , ub ):
108+ @pytest .mark .parametrize ("block_size" , [1 , 5 , 10 ])
109+ def test_multi_and_single_outputs (nt , mwd , mrd , dwd , drd , uf , ub , block_size ):
103110 """
104111 Tests whether SingleLevelRevolver and MultilevelRevolver are producing
105112 the same outputs
106113 """
107114 nx = 10
108115 ny = 10
109116 const = 1
110- m_df = np .zeros ([nx , ny ])
117+ m_df = np .zeros ([block_size , nx , ny ])
111118 m_db = np .zeros ([nx , ny ])
112119 m_cp = IncrementCheckpoint ([m_df ])
113120 m_fwd = IncOperator (const , m_df )
114121 m_rev = IncOperator ((- 1 ) * const , m_df , m_db )
115- s_df = np .zeros ([nx , ny ])
122+ s_df = np .zeros ([block_size , nx , ny ])
116123 s_db = np .zeros ([nx , ny ])
117124 s_cp = IncrementCheckpoint ([s_df ])
118125 s_fwd = IncOperator (const , s_df )
@@ -124,10 +131,11 @@ def test_multi_and_single_outputs(nt, mwd, mrd, dwd, drd, uf, ub):
124131 )
125132 st_list = [m_npStorage , m_dkStorage ]
126133 m_wrp = MultiLevelRevolver (
127- m_cp , m_fwd , m_rev , nt , storage_list = st_list , uf = uf , ub = ub
134+ m_cp , m_fwd , m_rev , nt , storage_list = st_list , uf = uf , ub = ub ,
135+ block_size = block_size
128136 )
129137
130- s_wrp = MemoryRevolver (s_cp , s_fwd , s_rev , nt , nt )
138+ s_wrp = MemoryRevolver (s_cp , s_fwd , s_rev , nt , nt , block_size = block_size )
131139
132140 m_wrp .apply_forward ()
133141 s_wrp .apply_forward ()
0 commit comments