Skip to content

Commit 6ef8bce

Browse files
committed
add test for maps and utils
1 parent 4fce731 commit 6ef8bce

File tree

2 files changed

+162
-24
lines changed

2 files changed

+162
-24
lines changed

MCintegration/maps_test.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,45 @@
11
import unittest
22
import torch
3-
4-
# import numpy as np
3+
import numpy as np
54
from maps import Map, CompositeMap, Vegas, Configuration
65
from base import LinearMap
76

87

8+
class TestConfiguration(unittest.TestCase):
9+
def setUp(self):
10+
self.batch_size = 5
11+
self.dim = 3
12+
self.f_dim = 2
13+
self.device = "cpu"
14+
self.dtype = torch.float64
15+
16+
def test_configuration_initialization(self):
17+
config = Configuration(
18+
batch_size=self.batch_size,
19+
dim=self.dim,
20+
f_dim=self.f_dim,
21+
device=self.device,
22+
dtype=self.dtype,
23+
)
24+
25+
self.assertEqual(config.batch_size, self.batch_size)
26+
self.assertEqual(config.dim, self.dim)
27+
self.assertEqual(config.f_dim, self.f_dim)
28+
self.assertEqual(config.device, self.device)
29+
30+
self.assertEqual(config.u.shape, (self.batch_size, self.dim))
31+
self.assertEqual(config.x.shape, (self.batch_size, self.dim))
32+
self.assertEqual(config.fx.shape, (self.batch_size, self.f_dim))
33+
self.assertEqual(config.weight.shape, (self.batch_size,))
34+
self.assertEqual(config.detJ.shape, (self.batch_size,))
35+
36+
self.assertEqual(config.u.dtype, self.dtype)
37+
self.assertEqual(config.x.dtype, self.dtype)
38+
self.assertEqual(config.fx.dtype, self.dtype)
39+
self.assertEqual(config.weight.dtype, self.dtype)
40+
self.assertEqual(config.detJ.dtype, self.dtype)
41+
42+
943
class TestMap(unittest.TestCase):
1044
def setUp(self):
1145
self.device = "cpu"
@@ -24,6 +58,35 @@ def test_inverse_not_implemented(self):
2458
with self.assertRaises(NotImplementedError):
2559
self.map.inverse(torch.tensor([0.5, 0.5], dtype=self.dtype))
2660

61+
def test_forward_with_detJ(self):
62+
# Create a simple linear map for testing: x = u * A + b
63+
# With A=[1, 1] and b=[0, 0], we have x = u
64+
linear_map = LinearMap([1, 1], [0, 0], device=self.device)
65+
66+
# Test forward_with_detJ method
67+
u = torch.tensor([[0.5, 0.5]], dtype=torch.float64, device=self.device)
68+
x, detJ = linear_map.forward_with_detJ(u)
69+
70+
# Since it's a linear map from [0,0] to [1,1], x should equal u
71+
self.assertTrue(torch.allclose(x, u))
72+
73+
# Determinant of Jacobian should be 1 for linear map with slope 1
74+
# forward_with_detJ returns actual determinant, not log
75+
self.assertAlmostEqual(detJ.item(), 1.0)
76+
77+
# Test with a different linear map: x = u * [2, 3] + [1, 1]
78+
# So u = [0.5, 0.5] should give x = [0.5*2+1, 0.5*3+1] = [2, 2.5]
79+
linear_map2 = LinearMap([2, 3], [1, 1], device=self.device)
80+
u2 = torch.tensor([[0.5, 0.5]], dtype=torch.float64, device=self.device)
81+
x2, detJ2 = linear_map2.forward_with_detJ(u2)
82+
expected_x2 = torch.tensor(
83+
[[2.0, 2.5]], dtype=torch.float64, device=self.device
84+
)
85+
self.assertTrue(torch.allclose(x2, expected_x2))
86+
87+
# Determinant should be 2 * 3 = 6
88+
self.assertAlmostEqual(detJ2.item(), 6.0)
89+
2790

2891
class TestCompositeMap(unittest.TestCase):
2992
def setUp(self):
@@ -99,6 +162,32 @@ def test_initialization(self):
99162
self.assertTrue(torch.equal(self.vegas.grid, self.init_grid))
100163
self.assertEqual(self.vegas.inc.shape, (2, self.ninc))
101164

165+
def test_ninc_initialization_types(self):
166+
# Test ninc initialization with int
167+
vegas_int = Vegas(self.dim, ninc=5)
168+
self.assertEqual(vegas_int.ninc.tolist(), [5, 5])
169+
170+
# Test ninc initialization with list
171+
vegas_list = Vegas(self.dim, ninc=[5, 10])
172+
self.assertEqual(vegas_list.ninc.tolist(), [5, 10])
173+
174+
# Test ninc initialization with numpy array
175+
vegas_np = Vegas(self.dim, ninc=np.array([3, 7]))
176+
self.assertEqual(vegas_np.ninc.tolist(), [3, 7])
177+
178+
# Test ninc initialization with torch tensor
179+
vegas_tensor = Vegas(self.dim, ninc=torch.tensor([4, 6]))
180+
self.assertEqual(vegas_tensor.ninc.tolist(), [4, 6])
181+
182+
# Test ninc initialization with invalid type
183+
with self.assertRaises(ValueError):
184+
Vegas(self.dim, ninc="invalid")
185+
186+
def test_ninc_shape_validation(self):
187+
# Test ninc shape validation
188+
with self.assertRaises(ValueError):
189+
Vegas(self.dim, ninc=[1, 2, 3]) # Wrong length
190+
102191
def test_add_training_data(self):
103192
# Test adding training data
104193
self.vegas.add_training_data(self.sample)
@@ -137,6 +226,16 @@ def test_forward(self):
137226
self.assertEqual(x.shape, u.shape)
138227
self.assertEqual(log_jac.shape, (u.shape[0],))
139228

229+
def test_forward_with_detJ(self):
230+
# Test forward_with_detJ transformation
231+
u = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float64)
232+
x, det_jac = self.vegas.forward_with_detJ(u)
233+
self.assertEqual(x.shape, u.shape)
234+
self.assertEqual(det_jac.shape, (u.shape[0],))
235+
236+
# Determinant should be positive
237+
self.assertTrue(torch.all(det_jac > 0))
238+
140239
def test_forward_out_of_bounds(self):
141240
# Test forward transformation with out-of-bounds u values
142241
u = torch.tensor(
@@ -220,4 +319,4 @@ def test_edge_cases(self):
220319

221320

222321
if __name__ == "__main__":
223-
unittest.main()
322+
unittest.main()

MCintegration/utils_test.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -164,17 +164,6 @@ def test_converged_criteria(self):
164164
self.assertTrue(ravg.converged(0.1, 0.1))
165165
self.assertFalse(ravg.converged(0.001, 0.001))
166166

167-
def test_multiplication_with_another_ravg(self):
168-
ravg1 = RAvg(weighted=True)
169-
ravg1.update(2.0, 0.1)
170-
ravg2 = RAvg(weighted=True)
171-
ravg2.update(3.0, 0.1)
172-
173-
result = ravg1 * ravg2
174-
self.assertAlmostEqual(result.mean, 6.0)
175-
sdev = (0.1 / 2**2 + 0.1 / 3**2) ** 0.5 * 6.0
176-
self.assertAlmostEqual(result.sdev, sdev)
177-
178167
def test_multiplication(self):
179168
ravg1 = RAvg(weighted=True)
180169
# Test multiplication by another RAvg object
@@ -216,16 +205,19 @@ def test_multiplication(self):
216205
np.allclose([r.sdev for r in result], [2.0 * ravg1.sdev, 3.0 * ravg1.sdev])
217206
)
218207

219-
def test_division_with_another_ravg(self):
220-
ravg1 = RAvg(weighted=True)
221-
ravg1.update(6.0, 0.1)
222-
ravg2 = RAvg(weighted=True)
223-
ravg2.update(3.0, 0.1)
208+
# Test multiplication with unweighted RAvg
209+
ravg_unweighted = RAvg(weighted=False)
210+
ravg_unweighted.update(2.0, 0.1)
211+
result = ravg_unweighted * 3.0
212+
self.assertAlmostEqual(result.mean, 6.0)
213+
self.assertAlmostEqual(result.sdev, ravg_unweighted.sdev * 3)
224214

225-
result = ravg1 / ravg2
226-
self.assertAlmostEqual(result.mean, 2.0)
227-
sdev = (0.1 / 6.0**2 + 0.1 / 3.0**2) ** 0.5 * 2.0
228-
self.assertAlmostEqual(result.sdev, sdev)
215+
# Test multiplication with zero variance
216+
ravg_zero_var = RAvg(weighted=True)
217+
ravg_zero_var.update(2.0, 0.0)
218+
result = ravg_zero_var * 3.0
219+
self.assertAlmostEqual(result.mean, 6.0)
220+
self.assertAlmostEqual(result.sdev, 0.0)
229221

230222
def test_division(self):
231223
ravg1 = RAvg(weighted=True)
@@ -271,6 +263,53 @@ def test_division(self):
271263
np.allclose([r.sdev for r in result], [ravg1.sdev / 2.0, ravg1.sdev / 3.0])
272264
)
273265

266+
# Test division with unweighted RAvg
267+
ravg_unweighted = RAvg(weighted=False)
268+
ravg_unweighted.update(6.0, 0.1)
269+
result = ravg_unweighted / 3.0
270+
self.assertAlmostEqual(result.mean, 2.0)
271+
self.assertAlmostEqual(result.sdev, ravg_unweighted.sdev / 3)
272+
273+
# Test division with zero variance
274+
ravg_zero_var = RAvg(weighted=True)
275+
ravg_zero_var.update(6.0, 0.0)
276+
result = ravg_zero_var / 3.0
277+
self.assertAlmostEqual(result.mean, 2.0)
278+
self.assertAlmostEqual(result.sdev, 0.0)
279+
280+
# Test division of zero by RAvg
281+
zero_ravg = RAvg(weighted=True)
282+
zero_ravg.update(0.0, 0.1)
283+
divisor_ravg = RAvg(weighted=True)
284+
divisor_ravg.update(3.0, 0.1)
285+
result = zero_ravg / divisor_ravg
286+
self.assertAlmostEqual(result.mean, 0.0)
287+
# sdev = (0.1 / 0.0**2 + 0.1 / 3.0**2) ** 0.5 * 0.0 # This would be NaN
288+
# For 0/x, the error propagation gives 0 * sqrt((0.1/0.0^2) + (0.1/3.0^2))
289+
# But since we're dividing by zero, we need to be careful
290+
# In practice, gvar handles this appropriately
291+
292+
def test_vector_operations_not_implemented(self):
293+
# Test that NotImplemented is returned for vector operations
294+
ravg = RAvg(weighted=True)
295+
ravg.update(2.0, 0.1)
296+
297+
# Test multiplication with list (should return NotImplemented)
298+
result = ravg.__mul__([1, 2, 3])
299+
self.assertEqual(result, NotImplemented)
300+
301+
# Test division with list (should return NotImplemented)
302+
result = ravg.__truediv__([1, 2, 3])
303+
self.assertEqual(result, NotImplemented)
304+
305+
# Test multiplication with numpy array (should return NotImplemented)
306+
result = ravg.__mul__(np.array([1, 2, 3]))
307+
self.assertEqual(result, NotImplemented)
308+
309+
# Test division with numpy array (should return NotImplemented)
310+
result = ravg.__truediv__(np.array([1, 2, 3]))
311+
self.assertEqual(result, NotImplemented)
312+
274313

275314
class TestUtils(unittest.TestCase):
276315
def setUp(self):
@@ -315,4 +354,4 @@ def test_get_device_cuda_inactive(self):
315354

316355

317356
if __name__ == "__main__":
318-
unittest.main()
357+
unittest.main()

0 commit comments

Comments
 (0)