
Commit 3c84972

update readme
1 parent 6ef8bce commit 3c84972

4 files changed: +37 −64 lines changed

MCintegration/base.py

Lines changed: 5 additions & 11 deletions
@@ -9,10 +9,8 @@
 from MCintegration.utils import get_device
 
 # Constants for numerical stability
-# Small but safe non-zero value
-MINVAL = 10 ** (sys.float_info.min_10_exp + 50)
-MAXVAL = 10 ** (sys.float_info.max_10_exp - 50)  # Large but safe value
 EPSILON = 1e-16  # Small value to ensure numerical stability
+# EPSILON = sys.float_info.epsilon * 1e4  # Small value to ensure numerical stability
 
 
 class BaseDistribution(nn.Module):
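
For reference, a quick standalone check of what the dropped and kept constants evaluate to (assuming IEEE-754 doubles, which is what CPython's `float` uses on mainstream platforms):

```python
import sys

# The removed constants were derived from the double-precision exponent
# range: sys.float_info.min_10_exp == -307 and max_10_exp == 308.
MINVAL = 10 ** (sys.float_info.min_10_exp + 50)  # 1e-257 (a float)
MAXVAL = 10 ** (sys.float_info.max_10_exp - 50)  # 10**258 (a Python int)

# The retained EPSILON is now a fixed literal; the commented-out
# alternative would evaluate to about 2.22e-12 on IEEE-754 doubles.
EPSILON = 1e-16
print(MINVAL, float(MAXVAL), sys.float_info.epsilon * 1e4)
```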
@@ -98,10 +96,8 @@ def sample(self, batch_size=1, **kwargs):
             tuple: (uniform samples, log_det_jacobian=0)
         """
         # torch.manual_seed(0)  # test seed
-        u = torch.rand((batch_size, self.dim),
-                       device=self.device, dtype=self.dtype)
-        log_detJ = torch.zeros(
-            batch_size, device=self.device, dtype=self.dtype)
+        u = torch.rand((batch_size, self.dim), device=self.device, dtype=self.dtype)
+        log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)
         return u, log_detJ
 
 
@@ -133,16 +129,14 @@ def __init__(self, A, b, device=None, dtype=torch.float32):
         elif isinstance(A, torch.Tensor):
             self.A = A.to(dtype=self.dtype, device=self.device)
         else:
-            raise ValueError(
-                "'A' must be a list, numpy array, or torch tensor.")
+            raise ValueError("'A' must be a list, numpy array, or torch tensor.")
 
         if isinstance(b, (list, np.ndarray)):
             self.b = torch.tensor(b, dtype=self.dtype, device=self.device)
         elif isinstance(b, torch.Tensor):
             self.b = b.to(dtype=self.dtype, device=self.device)
         else:
-            raise ValueError(
-                "'b' must be a list, numpy array, or torch tensor.")
+            raise ValueError("'b' must be a list, numpy array, or torch tensor.")
 
         # Pre-compute determinant of Jacobian for efficiency
         self._detJ = torch.prod(self.A)
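
The `_detJ` being precomputed here is the Jacobian determinant of an elementwise affine map; a minimal sketch of that relationship (the map `x = A * u + b` and the surrounding class API are inferred from the parameter names, not shown in this diff):

```python
import torch

# Minimal sketch of an elementwise affine map. Only the A/b validation
# and the _detJ line appear in the diff; the transform shown here is an
# assumption for illustration.
A = torch.tensor([2.0, 0.5, 4.0])   # per-dimension scale
b = torch.tensor([0.0, 1.0, -1.0])  # per-dimension shift

u = torch.rand(5, 3)                # batch of points in the unit cube
x = u * A + b                       # x = A * u + b, applied elementwise

detJ = torch.prod(A)                # diagonal Jacobian: product of scales
print(detJ)                         # tensor(4.) = 2.0 * 0.5 * 4.0
```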

MCintegration/maps.py

Lines changed: 26 additions & 44 deletions
@@ -9,7 +9,8 @@
 from MCintegration.utils import get_device
 import sys
 
-TINY = 10 ** (sys.float_info.min_10_exp + 50)  # Small but safe non-zero value
+# TINY = 10 ** (sys.float_info.min_10_exp + 50)  # Small but safe non-zero value
+TINY = 1e-45
 
 
 class Configuration:
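
Worth noting about the new literal: 1e-45 sits just below float32's smallest positive subnormal (~1.4e-45), so it is about the smallest constant that survives a cast to float32, whereas the old 1e-257 flushes to zero there. A quick check (this is my reading of the motivation, not stated in the commit):

```python
import torch

TINY = 1e-45

# 1e-45 rounds to float32's smallest positive subnormal instead of
# flushing to zero, and stays exact in float64.
print(torch.tensor(TINY, dtype=torch.float32))   # tensor(1.4013e-45)
print(torch.tensor(TINY, dtype=torch.float64))   # tensor(1.0000e-45, dtype=torch.float64)

# The old value, 10 ** (sys.float_info.min_10_exp + 50) == 1e-257,
# underflows to exactly zero in float32:
print(torch.tensor(1e-257, dtype=torch.float32))  # tensor(0.)
```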
@@ -38,14 +39,10 @@ def __init__(self, batch_size, dim, f_dim, device=None, dtype=torch.float32):
         self.f_dim = f_dim
         self.batch_size = batch_size
         # Initialize tensors for storing samples and results
-        self.u = torch.empty(
-            (batch_size, dim), dtype=dtype, device=self.device)
-        self.x = torch.empty(
-            (batch_size, dim), dtype=dtype, device=self.device)
-        self.fx = torch.empty((batch_size, f_dim),
-                              dtype=dtype, device=self.device)
-        self.weight = torch.empty(
-            (batch_size,), dtype=dtype, device=self.device)
+        self.u = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
+        self.x = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
+        self.fx = torch.empty((batch_size, f_dim), dtype=dtype, device=self.device)
+        self.weight = torch.empty((batch_size,), dtype=dtype, device=self.device)
         self.detJ = torch.empty((batch_size,), dtype=dtype, device=self.device)
 
 
@@ -202,8 +199,7 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
                 (self.dim,), ninc, dtype=torch.int32, device=self.device
             )
         elif isinstance(ninc, (list, np.ndarray)):
-            self.ninc = torch.tensor(
-                ninc, dtype=torch.int32, device=self.device)
+            self.ninc = torch.tensor(ninc, dtype=torch.int32, device=self.device)
         elif isinstance(ninc, torch.Tensor):
             self.ninc = ninc.to(dtype=torch.int32, device=self.device)
         else:
@@ -223,8 +219,9 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
         self.sum_f = torch.zeros(
             self.dim, self.max_ninc, dtype=self.dtype, device=self.device
         )
-        self.n_f = torch.zeros(
-            self.dim, self.max_ninc, dtype=self.dtype, device=self.device
+        self.n_f = (
+            torch.zeros(self.dim, self.max_ninc, dtype=self.dtype, device=self.device)
+            + TINY
         )
         self.avg_f = torch.ones(
             (self.dim, self.max_ninc), dtype=self.dtype, device=self.device
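
Seeding `n_f` with TINY gives never-visited bins a non-zero count, so a per-bin average like `sum_f / n_f` stays finite instead of producing 0/0. A small sketch under that assumption (how `n_f` is consumed downstream is inferred from the names):

```python
import torch

TINY = 1e-45
dim, max_ninc = 2, 4

sum_f = torch.zeros(dim, max_ninc, dtype=torch.float64)
n_f = torch.zeros(dim, max_ninc, dtype=torch.float64) + TINY

# Accumulate a few fake samples into one bin, leaving the rest empty.
sum_f[0, 1] += 0.3
n_f[0, 1] += 1

avg_f = sum_f / n_f   # empty bins give 0/TINY == 0.0, not 0/0 == nan
print(avg_f)          # no NaNs; empty bins are exactly 0
```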
@@ -308,19 +305,16 @@ def adapt(self, alpha=0.5):
         """
         # Aggregate training data across distributed processes if applicable
        if torch.distributed.is_initialized():
-            torch.distributed.all_reduce(
-                self.sum_f, op=torch.distributed.ReduceOp.SUM)
-            torch.distributed.all_reduce(
-                self.n_f, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(self.sum_f, op=torch.distributed.ReduceOp.SUM)
+            torch.distributed.all_reduce(self.n_f, op=torch.distributed.ReduceOp.SUM)
 
         # Initialize a new grid tensor
         new_grid = torch.empty(
             (self.dim, self.max_ninc + 1), dtype=self.dtype, device=self.device
         )
 
         if alpha > 0:
-            tmp_f = torch.empty(
-                self.max_ninc, dtype=self.dtype, device=self.device)
+            tmp_f = torch.empty(self.max_ninc, dtype=self.dtype, device=self.device)
 
             # avg_f = torch.ones(self.inc.shape[1], dtype=self.dtype, device=self.device)
             # print(self.ninc.shape, self.dim)
@@ -338,14 +332,12 @@ def adapt(self, alpha=0.5):
 
             if alpha > 0:
                 # Smooth avg_f
-                tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]
-                            ).abs() / 8.0  # Shape: ()
+                tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0  # Shape: ()
                 tmp_f[ninc - 1] = (
                     7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]
                 ).abs() / 8.0  # Shape: ()
-                tmp_f[1: ninc - 1] = (
-                    6.0 * avg_f[1: ninc - 1] +
-                    avg_f[: ninc - 2] + avg_f[2:ninc]
+                tmp_f[1 : ninc - 1] = (
+                    6.0 * avg_f[1 : ninc - 1] + avg_f[: ninc - 2] + avg_f[2:ninc]
                 ).abs() / 8.0
 
                 # Normalize tmp_f to ensure the sum is 1
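
The smoothing stencil blends each bin average with its neighbors: `|6·aᵢ + aᵢ₋₁ + aᵢ₊₁| / 8` in the interior and `|7·a + a_neighbor| / 8` at the two edges, so each bin's outgoing weights always total 8/8 and positive mass is conserved. A standalone check of the same stencil on hypothetical values:

```python
import torch

# Hypothetical per-bin averages for one dimension.
avg_f = torch.tensor([4.0, 2.0, 1.0, 3.0])
ninc = avg_f.numel()

tmp_f = torch.empty(ninc)
tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0
tmp_f[ninc - 1] = (7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]).abs() / 8.0
tmp_f[1 : ninc - 1] = (
    6.0 * avg_f[1 : ninc - 1] + avg_f[: ninc - 2] + avg_f[2:ninc]
).abs() / 8.0

print(tmp_f)        # tensor([3.7500, 2.1250, 1.3750, 2.7500])
print(tmp_f.sum())  # 10.0: the total is conserved for positive entries
```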
@@ -393,8 +385,7 @@ def adapt(self, alpha=0.5):
                 ) / avg_f_relevant
 
                 # Calculate the new grid points using vectorized operations
-                new_grid[d, 1:ninc] = grid_left + \
-                    fractional_positions * inc_relevant
+                new_grid[d, 1:ninc] = grid_left + fractional_positions * inc_relevant
             else:
                 # If alpha == 0 or no training data, retain the existing grid
                 new_grid[d, :] = self.grid[d, :]
@@ -407,8 +398,7 @@ def adapt(self, alpha=0.5):
         self.inc.zero_()  # Reset increments to zero
         for d in range(self.dim):
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1: self.ninc[d] + 1] -
-                self.grid[d, : self.ninc[d]]
+                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
             )
 
         # Clear accumulated training data for the next adaptation cycle
@@ -432,8 +422,7 @@ def make_uniform(self):
                 device=self.device,
             )
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1: self.ninc[d] + 1] -
-                self.grid[d, : self.ninc[d]]
+                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
             )
         self.clear()
 
@@ -459,8 +448,7 @@ def forward(self, u):
 
         batch_size = u.size(0)
         # Clamp iu to [0, ninc-1] to handle out-of-bounds indices
-        min_tensor = torch.zeros(
-            (1, self.dim), dtype=iu.dtype, device=self.device)
+        min_tensor = torch.zeros((1, self.dim), dtype=iu.dtype, device=self.device)
         # Shape: (1, dim)
         max_tensor = (self.ninc - 1).unsqueeze(0).to(iu.dtype)
         iu_clamped = torch.clamp(iu, min=min_tensor, max=max_tensor)
@@ -471,8 +459,7 @@ def forward(self, u):
         grid_gather = torch.gather(grid_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(
             2
         )  # Shape: (batch_size, dim)
-        inc_gather = torch.gather(
-            inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)
+        inc_gather = torch.gather(inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)
 
         x = grid_gather + inc_gather * du_ninc
         log_detJ = (inc_gather * self.ninc).log_().sum(dim=1)
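
The gather-and-transform above is the piecewise-linear Vegas-style map: each `u` is scaled into increment units, the containing increment is looked up, and the per-dimension Jacobian is `inc * ninc`. A condensed 1D sketch of that logic (grid values are hypothetical and variable names are simplified from the batched code):

```python
import torch

# Hypothetical 1D adaptive grid with ninc = 4 increments on [0, 1].
grid = torch.tensor([0.0, 0.1, 0.3, 0.6, 1.0])
inc = grid[1:] - grid[:-1]            # increment widths, shape (4,)
ninc = 4

u = torch.tensor([0.05, 0.49, 0.9])   # points in the unit interval
u_ninc = u * ninc                     # position in "increment units"
iu = u_ninc.floor().long()            # which increment each point lands in
du = u_ninc - iu                      # fractional position inside it

x = grid[iu] + inc[iu] * du           # piecewise-linear map u -> x
log_detJ = (inc[iu] * ninc).log()     # per-dimension log Jacobian

print(x)          # tensor([0.0200, 0.2920, 0.8400])
print(log_detJ)   # log(0.4), log(0.8), log(1.6)
```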
@@ -484,17 +471,15 @@ def forward(self, u):
         # For each sample and dimension, set x to grid[d, ninc[d]]
         # and log_detJ += log(inc[d, ninc[d]-1] * ninc[d])
         boundary_grid = (
-            self.grid[torch.arange(
-                self.dim, device=self.device), self.ninc]
+            self.grid[torch.arange(self.dim, device=self.device), self.ninc]
             .unsqueeze(0)
             .expand(batch_size, -1)
         )
         # x = torch.where(out_of_bounds, boundary_grid, x)
         x[out_of_bounds] = boundary_grid[out_of_bounds]
 
         boundary_inc = (
-            self.inc[torch.arange(
-                self.dim, device=self.device), self.ninc - 1]
+            self.inc[torch.arange(self.dim, device=self.device), self.ninc - 1]
             .unsqueeze(0)
             .expand(batch_size, -1)
         )
@@ -522,8 +507,7 @@ def inverse(self, x):
 
         # Initialize output tensors
         u = torch.empty_like(x)
-        log_detJ = torch.zeros(
-            batch_size, device=self.device, dtype=self.dtype)
+        log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)
 
         # Loop over each dimension to perform inverse mapping
         for d in range(dim):
@@ -537,8 +521,7 @@ def inverse(self, x):
             # Perform searchsorted to find indices where x should be inserted to maintain order
             # torch.searchsorted returns indices in [0, max_ninc +1]
             iu = (
-                torch.searchsorted(
-                    grid_d, x[:, d].contiguous(), right=True) - 1
+                torch.searchsorted(grid_d, x[:, d].contiguous(), right=True) - 1
             )  # Shape: (batch_size,)
 
             # Clamp indices to [0, ninc_d - 1] to ensure they are within valid range
@@ -551,8 +534,7 @@ def inverse(self, x):
             inc_gather = inc_d[iu_clamped]  # Shape: (batch_size,)
 
             # Compute du: fractional part within the increment
-            du = (x[:, d] - grid_gather) / \
-                (inc_gather + TINY)  # Shape: (batch_size,)
+            du = (x[:, d] - grid_gather) / (inc_gather + TINY)  # Shape: (batch_size,)
 
             # Compute u for dimension d
             u[:, d] = (du + iu_clamped) / ninc_d  # Shape: (batch_size,)
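
`inverse` is the mirror image: `searchsorted` locates the increment containing each `x`, and the TINY guard keeps the division finite even if an increment has zero width. A 1D round-trip check against the forward sketch above (same hypothetical grid):

```python
import torch

TINY = 1e-45
grid = torch.tensor([0.0, 0.1, 0.3, 0.6, 1.0])
inc = grid[1:] - grid[:-1]
ninc = 4

x = torch.tensor([0.0200, 0.2920, 0.8400])

# searchsorted(..., right=True) - 1 gives the increment containing x.
iu = torch.searchsorted(grid, x.contiguous(), right=True) - 1
iu = iu.clamp(0, ninc - 1)

du = (x - grid[iu]) / (inc[iu] + TINY)  # fractional position in the increment
u = (du + iu) / ninc

print(u)  # tensor([0.0500, 0.4900, 0.9000]): recovers the forward inputs
```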

MCintegration/utils.py

Lines changed: 4 additions & 6 deletions
@@ -11,8 +11,8 @@
 
 # Constants for numerical stability
 # Small but safe non-zero value
-MINVAL = 10 ** (sys.float_info.min_10_exp + 50)
-MAXVAL = 10 ** (sys.float_info.max_10_exp - 50)  # Large but safe value
+# MINVAL = 10 ** (sys.float_info.min_10_exp + 50)
+MINVAL = 1e-45
 _VECTOR_TYPES = [np.ndarray, list]
 
 
@@ -83,8 +83,7 @@ def add(self, res):
             self._wlist.append(1 / (res.var if res.var > MINVAL else MINVAL))
             var = 1.0 / np.sum(self._wlist)
             sdev = np.sqrt(var)
-            mean = np.sum(
-                [w * m for w, m in zip(self._wlist, self._mlist)]) * var
+            mean = np.sum([w * m for w, m in zip(self._wlist, self._mlist)]) * var
             super(RAvg, self).__init__(*gvar.gvar(mean, sdev).internaldata)
         else:
             # Simple average
@@ -93,8 +92,7 @@ def add(self, res):
             self._count += 1
             mean = self._sum / self._count
             var = self._varsum / self._count**2
-            super(RAvg, self).__init__(
-                *gvar.gvar(mean, np.sqrt(var)).internaldata)
+            super(RAvg, self).__init__(*gvar.gvar(mean, np.sqrt(var)).internaldata)
 
     def extend(self, ravg):
         """

README.md

Lines changed: 2 additions & 3 deletions
@@ -178,9 +178,8 @@ def singular_func(x, f):
     Integrand with singularity at x=0.
     The integral ∫₀¹ log(x)/√x dx = -4 (analytical result)
     """
-    f[:, 0] = torch.where(x[:, 0] < 1e-14,
-                          0.0,
-                          torch.log(x[:, 0]) / torch.sqrt(x[:, 0]))
+    x_safe = torch.clamp(x[:, 0], min=1e-32)
+    f[:, 0] = torch.log(x_safe) / torch.sqrt(x_safe)
     return f[:, 0]
 
 # Integration parameters
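
The switch from `torch.where` to `torch.clamp` matters because `where` evaluates both branches anyway, and the old guard silently zeroed the integrand everywhere below the 1e-14 cutoff; clamping instead keeps a finite, non-zero value arbitrarily close to the singularity. A quick standalone comparison (sample values chosen for illustration; this reading of the motivation is mine, not stated in the commit):

```python
import torch

x = torch.tensor([0.0, 1e-40, 0.25])

# where() evaluates both branches; the masked lane still computes
# log(0)/sqrt(0), and the region below the cutoff is zeroed out.
f_where = torch.where(x < 1e-14, torch.tensor(0.0), torch.log(x) / torch.sqrt(x))

# clamp() never forms inf/nan and keeps the integrand non-zero near 0.
x_safe = torch.clamp(x, min=1e-32)
f_clamp = torch.log(x_safe) / torch.sqrt(x_safe)

print(f_where)  # tensor([ 0.0000,  0.0000, -2.7726])
print(f_clamp)  # tensor([-7.3683e+17, -7.3683e+17, -2.7726e+00])
```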
