11import unittest
22import torch
3-
4- # import numpy as np
3+ import numpy as np
54from maps import Map , CompositeMap , Vegas , Configuration
65from base import LinearMap
76
87
8+ class TestConfiguration (unittest .TestCase ):
9+ def setUp (self ):
10+ self .batch_size = 5
11+ self .dim = 3
12+ self .f_dim = 2
13+ self .device = "cpu"
14+ self .dtype = torch .float64
15+
16+ def test_configuration_initialization (self ):
17+ config = Configuration (
18+ batch_size = self .batch_size ,
19+ dim = self .dim ,
20+ f_dim = self .f_dim ,
21+ device = self .device ,
22+ dtype = self .dtype ,
23+ )
24+
25+ self .assertEqual (config .batch_size , self .batch_size )
26+ self .assertEqual (config .dim , self .dim )
27+ self .assertEqual (config .f_dim , self .f_dim )
28+ self .assertEqual (config .device , self .device )
29+
30+ self .assertEqual (config .u .shape , (self .batch_size , self .dim ))
31+ self .assertEqual (config .x .shape , (self .batch_size , self .dim ))
32+ self .assertEqual (config .fx .shape , (self .batch_size , self .f_dim ))
33+ self .assertEqual (config .weight .shape , (self .batch_size ,))
34+ self .assertEqual (config .detJ .shape , (self .batch_size ,))
35+
36+ self .assertEqual (config .u .dtype , self .dtype )
37+ self .assertEqual (config .x .dtype , self .dtype )
38+ self .assertEqual (config .fx .dtype , self .dtype )
39+ self .assertEqual (config .weight .dtype , self .dtype )
40+ self .assertEqual (config .detJ .dtype , self .dtype )
41+
42+
943class TestMap (unittest .TestCase ):
1044 def setUp (self ):
1145 self .device = "cpu"
@@ -24,6 +58,35 @@ def test_inverse_not_implemented(self):
2458 with self .assertRaises (NotImplementedError ):
2559 self .map .inverse (torch .tensor ([0.5 , 0.5 ], dtype = self .dtype ))
2660
61+ def test_forward_with_detJ (self ):
62+ # Create a simple linear map for testing: x = u * A + b
63+ # With A=[1, 1] and b=[0, 0], we have x = u
64+ linear_map = LinearMap ([1 , 1 ], [0 , 0 ], device = self .device )
65+
66+ # Test forward_with_detJ method
67+ u = torch .tensor ([[0.5 , 0.5 ]], dtype = torch .float64 , device = self .device )
68+ x , detJ = linear_map .forward_with_detJ (u )
69+
70+ # Since it's a linear map from [0,0] to [1,1], x should equal u
71+ self .assertTrue (torch .allclose (x , u ))
72+
73+ # Determinant of Jacobian should be 1 for linear map with slope 1
74+ # forward_with_detJ returns actual determinant, not log
75+ self .assertAlmostEqual (detJ .item (), 1.0 )
76+
77+ # Test with a different linear map: x = u * [2, 3] + [1, 1]
78+ # So u = [0.5, 0.5] should give x = [0.5*2+1, 0.5*3+1] = [2, 2.5]
79+ linear_map2 = LinearMap ([2 , 3 ], [1 , 1 ], device = self .device )
80+ u2 = torch .tensor ([[0.5 , 0.5 ]], dtype = torch .float64 , device = self .device )
81+ x2 , detJ2 = linear_map2 .forward_with_detJ (u2 )
82+ expected_x2 = torch .tensor (
83+ [[2.0 , 2.5 ]], dtype = torch .float64 , device = self .device
84+ )
85+ self .assertTrue (torch .allclose (x2 , expected_x2 ))
86+
87+ # Determinant should be 2 * 3 = 6
88+ self .assertAlmostEqual (detJ2 .item (), 6.0 )
89+
2790
2891class TestCompositeMap (unittest .TestCase ):
2992 def setUp (self ):
@@ -99,6 +162,32 @@ def test_initialization(self):
99162 self .assertTrue (torch .equal (self .vegas .grid , self .init_grid ))
100163 self .assertEqual (self .vegas .inc .shape , (2 , self .ninc ))
101164
165+ def test_ninc_initialization_types (self ):
166+ # Test ninc initialization with int
167+ vegas_int = Vegas (self .dim , ninc = 5 )
168+ self .assertEqual (vegas_int .ninc .tolist (), [5 , 5 ])
169+
170+ # Test ninc initialization with list
171+ vegas_list = Vegas (self .dim , ninc = [5 , 10 ])
172+ self .assertEqual (vegas_list .ninc .tolist (), [5 , 10 ])
173+
174+ # Test ninc initialization with numpy array
175+ vegas_np = Vegas (self .dim , ninc = np .array ([3 , 7 ]))
176+ self .assertEqual (vegas_np .ninc .tolist (), [3 , 7 ])
177+
178+ # Test ninc initialization with torch tensor
179+ vegas_tensor = Vegas (self .dim , ninc = torch .tensor ([4 , 6 ]))
180+ self .assertEqual (vegas_tensor .ninc .tolist (), [4 , 6 ])
181+
182+ # Test ninc initialization with invalid type
183+ with self .assertRaises (ValueError ):
184+ Vegas (self .dim , ninc = "invalid" )
185+
186+ def test_ninc_shape_validation (self ):
187+ # Test ninc shape validation
188+ with self .assertRaises (ValueError ):
189+ Vegas (self .dim , ninc = [1 , 2 , 3 ]) # Wrong length
190+
102191 def test_add_training_data (self ):
103192 # Test adding training data
104193 self .vegas .add_training_data (self .sample )
@@ -137,6 +226,16 @@ def test_forward(self):
137226 self .assertEqual (x .shape , u .shape )
138227 self .assertEqual (log_jac .shape , (u .shape [0 ],))
139228
229+ def test_forward_with_detJ (self ):
230+ # Test forward_with_detJ transformation
231+ u = torch .tensor ([[0.1 , 0.2 ], [0.3 , 0.4 ]], dtype = torch .float64 )
232+ x , det_jac = self .vegas .forward_with_detJ (u )
233+ self .assertEqual (x .shape , u .shape )
234+ self .assertEqual (det_jac .shape , (u .shape [0 ],))
235+
236+ # Determinant should be positive
237+ self .assertTrue (torch .all (det_jac > 0 ))
238+
140239 def test_forward_out_of_bounds (self ):
141240 # Test forward transformation with out-of-bounds u values
142241 u = torch .tensor (
@@ -220,4 +319,4 @@ def test_edge_cases(self):
220319
221320
222321if __name__ == "__main__" :
223- unittest .main ()
322+ unittest .main ()
0 commit comments