 from MCintegration.utils import get_device
 import sys

-TINY = 10 ** (sys.float_info.min_10_exp + 50)  # Small but safe non-zero value
+# TINY = 10 ** (sys.float_info.min_10_exp + 50)  # Small but safe non-zero value
+TINY = 1e-45


 class Configuration:
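Worth spelling out why this constant changed: `sys.float_info.min_10_exp` describes float64 (it is -307 on IEEE-754 platforms), so the old value was roughly 1e-257, which silently underflows to exactly 0.0 once cast to float32 and stops guarding anything. The new 1e-45 rounds to the smallest positive float32 subnormal (about 1.4e-45), so it stays non-zero in both precisions. A quick sanity check (illustrative, not from this repo):

```python
import sys
import torch

old_tiny = 10 ** (sys.float_info.min_10_exp + 50)  # ~1e-257 with IEEE-754 doubles
new_tiny = 1e-45

# The old guard vanishes in single precision...
assert torch.tensor(old_tiny, dtype=torch.float32).item() == 0.0
# ...while 1e-45 survives as a float32 subnormal (~1.4e-45)
assert torch.tensor(new_tiny, dtype=torch.float32).item() > 0.0
```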
@@ -38,14 +39,10 @@ def __init__(self, batch_size, dim, f_dim, device=None, dtype=torch.float32):
         self.f_dim = f_dim
         self.batch_size = batch_size
         # Initialize tensors for storing samples and results
-        self.u = torch.empty(
-            (batch_size, dim), dtype=dtype, device=self.device)
-        self.x = torch.empty(
-            (batch_size, dim), dtype=dtype, device=self.device)
-        self.fx = torch.empty((batch_size, f_dim),
-                              dtype=dtype, device=self.device)
-        self.weight = torch.empty(
-            (batch_size,), dtype=dtype, device=self.device)
+        self.u = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
+        self.x = torch.empty((batch_size, dim), dtype=dtype, device=self.device)
+        self.fx = torch.empty((batch_size, f_dim), dtype=dtype, device=self.device)
+        self.weight = torch.empty((batch_size,), dtype=dtype, device=self.device)
         self.detJ = torch.empty((batch_size,), dtype=dtype, device=self.device)

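The four reflowed `torch.empty` calls are behavior-neutral; per the hunk header, the buffers hold one batch of samples each. A hypothetical usage sketch (names taken from the signature above):

```python
cfg = Configuration(batch_size=1024, dim=2, f_dim=1)
# cfg.u      : (1024, 2) points in the unit hypercube
# cfg.x      : (1024, 2) mapped integration points
# cfg.fx     : (1024, 1) integrand values
# cfg.weight : (1024,)   per-sample weights
# cfg.detJ   : (1024,)   Jacobian determinants of the map
```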
@@ -202,8 +199,7 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
                 (self.dim,), ninc, dtype=torch.int32, device=self.device
             )
         elif isinstance(ninc, (list, np.ndarray)):
-            self.ninc = torch.tensor(
-                ninc, dtype=torch.int32, device=self.device)
+            self.ninc = torch.tensor(ninc, dtype=torch.int32, device=self.device)
         elif isinstance(ninc, torch.Tensor):
             self.ninc = ninc.to(dtype=torch.int32, device=self.device)
         else:
@@ -223,8 +219,9 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32):
         self.sum_f = torch.zeros(
             self.dim, self.max_ninc, dtype=self.dtype, device=self.device
         )
-        self.n_f = torch.zeros(
-            self.dim, self.max_ninc, dtype=self.dtype, device=self.device
+        self.n_f = (
+            torch.zeros(self.dim, self.max_ninc, dtype=self.dtype, device=self.device)
+            + TINY
         )
         self.avg_f = torch.ones(
             (self.dim, self.max_ninc), dtype=self.dtype, device=self.device
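This is the other substantive half of the commit: `n_f` now starts at TINY instead of exactly zero, so bins that collect no samples during training no longer yield 0/0 = NaN when the per-bin average is formed (presumably something like `sum_f / n_f` inside `adapt`). A minimal sketch of the failure mode, with made-up numbers:

```python
import torch

TINY = 1e-45
sum_f = torch.tensor([0.0, 2.0, 0.0])  # accumulated |f| per bin
n_f = torch.tensor([0.0, 4.0, 0.0])    # samples landing in each bin

print(sum_f / n_f)           # tensor([nan, 0.5000, nan])  - empty bins poison the grid
print(sum_f / (n_f + TINY))  # tensor([0., 0.5000, 0.])    - empty bins just stay empty
```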
@@ -308,19 +305,16 @@ def adapt(self, alpha=0.5):
308305 """
309306 # Aggregate training data across distributed processes if applicable
310307 if torch .distributed .is_initialized ():
311- torch .distributed .all_reduce (
312- self .sum_f , op = torch .distributed .ReduceOp .SUM )
313- torch .distributed .all_reduce (
314- self .n_f , op = torch .distributed .ReduceOp .SUM )
308+ torch .distributed .all_reduce (self .sum_f , op = torch .distributed .ReduceOp .SUM )
309+ torch .distributed .all_reduce (self .n_f , op = torch .distributed .ReduceOp .SUM )
315310
316311 # Initialize a new grid tensor
317312 new_grid = torch .empty (
318313 (self .dim , self .max_ninc + 1 ), dtype = self .dtype , device = self .device
319314 )
320315
321316 if alpha > 0 :
322- tmp_f = torch .empty (
323- self .max_ninc , dtype = self .dtype , device = self .device )
317+ tmp_f = torch .empty (self .max_ninc , dtype = self .dtype , device = self .device )
324318
325319 # avg_f = torch.ones(self.inc.shape[1], dtype=self.dtype, device=self.device)
326320 # print(self.ninc.shape, self.dim)
@@ -338,14 +332,12 @@ def adapt(self, alpha=0.5):

             if alpha > 0:
                 # Smooth avg_f
-                tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]
-                            ).abs() / 8.0  # Shape: ()
+                tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0  # Shape: ()
                 tmp_f[ninc - 1] = (
                     7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]
                 ).abs() / 8.0  # Shape: ()
-                tmp_f[1 : ninc - 1] = (
-                    6.0 * avg_f[1 : ninc - 1] +
-                    avg_f[: ninc - 2] + avg_f[2:ninc]
+                tmp_f[1 : ninc - 1] = (
+                    6.0 * avg_f[1 : ninc - 1] + avg_f[: ninc - 2] + avg_f[2:ninc]
                 ).abs() / 8.0

                 # Normalize tmp_f to ensure the sum is 1
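The reflowed block is the usual VEGAS three-point smoothing: interior bins become (f[i-1] + 6·f[i] + f[i+1])/8 and the two edge bins (7·f + neighbor)/8, so each output is a weighted average whose weights sum to one and the total mass is preserved. A standalone equivalent for checking:

```python
import torch

avg_f = torch.tensor([1.0, 2.0, 4.0, 2.0, 1.0])

tmp_f = torch.empty_like(avg_f)
tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0
tmp_f[-1] = (7.0 * avg_f[-1] + avg_f[-2]).abs() / 8.0
tmp_f[1:-1] = (6.0 * avg_f[1:-1] + avg_f[:-2] + avg_f[2:]).abs() / 8.0

print(tmp_f)        # tensor([1.1250, 2.1250, 3.5000, 2.1250, 1.1250])
print(tmp_f.sum())  # tensor(10.) - same total as avg_f
```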
@@ -393,8 +385,7 @@ def adapt(self, alpha=0.5):
                 ) / avg_f_relevant

                 # Calculate the new grid points using vectorized operations
-                new_grid[d, 1:ninc] = grid_left + \
-                    fractional_positions * inc_relevant
+                new_grid[d, 1:ninc] = grid_left + fractional_positions * inc_relevant
             else:
                 # If alpha == 0 or no training data, retain the existing grid
                 new_grid[d, :] = self.grid[d, :]
@@ -407,8 +398,7 @@ def adapt(self, alpha=0.5):
         self.inc.zero_()  # Reset increments to zero
         for d in range(self.dim):
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] -
-                self.grid[d, : self.ninc[d]]
+                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
             )

         # Clear accumulated training data for the next adaptation cycle
@@ -432,8 +422,7 @@ def make_uniform(self):
                 device=self.device,
             )
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] -
-                self.grid[d, : self.ninc[d]]
+                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
             )
         self.clear()

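In both `adapt` and `make_uniform`, the reflowed assignment is just a forward difference over the first ninc[d] + 1 grid knots. For a uniform grid (as `make_uniform` appears to build, likely via `torch.linspace` given the surrounding context lines) it reduces to equal bin widths:

```python
import torch

ninc_d = 4
grid_d = torch.linspace(0.0, 1.0, ninc_d + 1)  # uniform knots in [0, 1]

inc_d = grid_d[1 : ninc_d + 1] - grid_d[:ninc_d]  # what the diffed line computes
assert torch.allclose(inc_d, torch.diff(grid_d))
print(inc_d)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```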
@@ -459,8 +448,7 @@ def forward(self, u):

         batch_size = u.size(0)
         # Clamp iu to [0, ninc-1] to handle out-of-bounds indices
-        min_tensor = torch.zeros(
-            (1, self.dim), dtype=iu.dtype, device=self.device)
+        min_tensor = torch.zeros((1, self.dim), dtype=iu.dtype, device=self.device)
         # Shape: (1, dim)
         max_tensor = (self.ninc - 1).unsqueeze(0).to(iu.dtype)
         iu_clamped = torch.clamp(iu, min=min_tensor, max=max_tensor)
@@ -471,8 +459,7 @@ def forward(self, u):
         grid_gather = torch.gather(grid_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(
             2
         )  # Shape: (batch_size, dim)
-        inc_gather = torch.gather(
-            inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)
+        inc_gather = torch.gather(inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)

         x = grid_gather + inc_gather * du_ninc
         log_detJ = (inc_gather * self.ninc).log_().sum(dim=1)
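For context, `forward` is the standard piecewise-linear VEGAS map: u is scaled by ninc, split into a bin index and a fractional part, and mapped through the bin's left edge and width; the per-dimension log-Jacobian is log(inc·ninc). A 1-D standalone sketch (variable names are mine, not the repo's):

```python
import torch

ninc = 4
grid = torch.tensor([0.0, 0.1, 0.3, 0.6, 1.0])  # ninc + 1 knots
inc = grid[1:] - grid[:-1]                      # bin widths

u = torch.tensor([0.05, 0.5, 0.9])
u_ninc = u * ninc
iu = u_ninc.floor().long().clamp(0, ninc - 1)   # bin index per sample
du = u_ninc - iu                                # fractional position inside the bin

x = grid[iu] + inc[iu] * du
log_detJ = (inc[iu] * ninc).log()
print(x)  # tensor([0.0200, 0.3000, 0.8400])
```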
@@ -484,17 +471,15 @@ def forward(self, u):
             # For each sample and dimension, set x to grid[d, ninc[d]]
             # and log_detJ += log(inc[d, ninc[d]-1] * ninc[d])
             boundary_grid = (
-                self.grid[torch.arange(
-                    self.dim, device=self.device), self.ninc]
+                self.grid[torch.arange(self.dim, device=self.device), self.ninc]
                 .unsqueeze(0)
                 .expand(batch_size, -1)
             )
             # x = torch.where(out_of_bounds, boundary_grid, x)
             x[out_of_bounds] = boundary_grid[out_of_bounds]

             boundary_inc = (
-                self.inc[torch.arange(
-                    self.dim, device=self.device), self.ninc - 1]
+                self.inc[torch.arange(self.dim, device=self.device), self.ninc - 1]
                 .unsqueeze(0)
                 .expand(batch_size, -1)
             )
@@ -522,8 +507,7 @@ def inverse(self, x):

         # Initialize output tensors
         u = torch.empty_like(x)
-        log_detJ = torch.zeros(
-            batch_size, device=self.device, dtype=self.dtype)
+        log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)

         # Loop over each dimension to perform inverse mapping
         for d in range(dim):
@@ -537,8 +521,7 @@ def inverse(self, x):
             # Perform searchsorted to find indices where x should be inserted to maintain order
             # torch.searchsorted returns indices in [0, max_ninc + 1]
             iu = (
-                torch.searchsorted(
-                    grid_d, x[:, d].contiguous(), right=True) - 1
+                torch.searchsorted(grid_d, x[:, d].contiguous(), right=True) - 1
             )  # Shape: (batch_size,)

             # Clamp indices to [0, ninc_d - 1] to ensure they are within valid range
@@ -551,8 +534,7 @@ def inverse(self, x):
             inc_gather = inc_d[iu_clamped]  # Shape: (batch_size,)

             # Compute du: fractional part within the increment
-            du = (x[:, d] - grid_gather) / \
-                (inc_gather + TINY)  # Shape: (batch_size,)
+            du = (x[:, d] - grid_gather) / (inc_gather + TINY)  # Shape: (batch_size,)

             # Compute u for dimension d
             u[:, d] = (du + iu_clamped) / ninc_d  # Shape: (batch_size,)
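And `inverse` undoes that map bin by bin: `searchsorted` locates the bin containing each x, du recovers the fractional offset inside it (the + TINY now meaningfully guards zero-width bins in float32, per the constant change above), and u = (du + iu)/ninc. Continuing the 1-D sketch from the `forward` note:

```python
TINY = 1e-45

iu_back = (torch.searchsorted(grid, x.contiguous(), right=True) - 1).clamp(0, ninc - 1)
du_back = (x - grid[iu_back]) / (inc[iu_back] + TINY)
u_back = (du_back + iu_back) / ninc

assert torch.allclose(u_back, u)  # round-trips up to float rounding
```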