Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 62 additions & 52 deletions analysator/vlsvfile/vlsvreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2129,17 +2129,61 @@ def get_max_refinement_level(self):
self.__max_spatial_amr_level = AMR_count - 1
return self.__max_spatial_amr_level


def wrap_array(dimensions):
'''Wrapper for consolidating inputs as numpy arrays.

Putting @wrap_array(dimension) before a function will use this automatically.

When making a function that uses this, remember to return whatever variable it calculates


:param dimensions: int or list of int in same order as the arguments, 0 in the list will skip an argument, self is automatically skipped.
:param squeeze: Whether to squeeze the output if input dimension mismatched, default True
'''
#Check if integer

if type(dimensions)==type(1):
dimensions=[dimensions]

def wrap_array_inner(func):
def wrap(*args, **kwargs):
stack = True
for i,d in enumerate(dimensions):
#Offset since self is always the first argument
i=i+1
if d!=0:
arg=np.array(args[i])
if arg.ndim!=d:
#Make sure that the scalar argument is turned into np.array of dimension d
while arg.ndim<d:
arg=arg[np.newaxis]
#Make sure to return scalar if value was given as scalar
stack = False

#Update args
args = (*args[0:i],arg,*args[i+1:])


#Call the function
variable = func(*args, **kwargs)

if stack:
return variable
else:
variable=np.squeeze(variable,axis=0)
return variable

return wrap
return wrap_array_inner

@wrap_array(dimensions=1)
def get_amr_level(self,cellid):
'''Returns the AMR level of a given cell defined by its cellid

:param cellid: The cell's cellid
:returns: The cell's refinement level in the AMR
'''
stack = True
if not hasattr(cellid,"__len__"):
cellid = np.atleast_1d(cellid)
stack = False

AMR_count = np.zeros(np.array(cellid).shape, dtype=np.int64)
cellids = cellid.astype(np.int64)
iters = 0
Expand All @@ -2152,24 +2196,15 @@ def get_amr_level(self,cellid):
if(iters > self.get_max_refinement_level()+1):
logging.info("Can't have that large refinements. Something broke.")
break
return AMR_count-1

if stack:
return AMR_count - 1
else:
return (AMR_count - 1)[0]

@wrap_array(dimensions=1)
def get_cell_dx(self, cellid):
'''Returns the dx of a given cell defined by its cellid

:param cellid: The cell's cellid
:returns: The cell's size [dx, dy, dz]
'''

stack = True
if not hasattr(cellid,"__len__"):
cellid = np.atleast_1d(cellid)
stack = False

cellid = np.array(cellid, dtype=np.int64)

dxs = np.array([[self.__dx,self.__dy,self.__dz]])
Expand All @@ -2181,11 +2216,7 @@ def get_cell_dx(self, cellid):
amrs[amrs < 0] = 0

ret = dxs/2**amrs

if stack:
return ret
else:
return ret[0]
return ret

def get_cell_bbox(self, cellid):
'''Returns the bounding box of a given cell defined by its cellid
Expand Down Expand Up @@ -2872,7 +2903,7 @@ def build_duals(self, cid):
self.build_dual_from_vertices(list(vertices))



@wrap_array(dimensions=1)
def get_cell_coordinates(self, cellids):
''' Returns a given cell's coordinates as a numpy array

Expand All @@ -2884,10 +2915,6 @@ def get_cell_coordinates(self, cellids):
.. note:: The cell ids go from 1 .. max not from 0
'''

stack = True
if not hasattr(cellids,"__len__"):
cellids = np.atleast_1d(cellids)
stack = False

# Get cell lengths:
xcells = np.zeros((self.get_max_refinement_level()+1), dtype=np.int64)
Expand All @@ -2907,12 +2934,11 @@ def get_cell_coordinates(self, cellids):
(self.__zmax - self.__zmin)/(zcells[reflevels])]).T
mins = np.array([self.__xmin,self.__ymin,self.__zmin])
cellcoordinates = mins + (cellindices + 0.5)*cell_lengths

# Return the coordinates:
if stack:
return np.array(cellcoordinates)
else:
return np.array(cellcoordinates)[0,:]
return cellcoordinates

@wrap_array(dimensions=1)
def get_cell_indices(self, cellids, reflevels=None):
''' Returns a given cell's indices as a numpy array

Expand All @@ -2924,12 +2950,6 @@ def get_cell_indices(self, cellids, reflevels=None):

.. note:: The cell ids go from 1 .. max not from 0
'''

stack = True
if not hasattr(cellids,"__len__"):
cellids = np.atleast_1d(cellids)
stack = False

if reflevels is None:
reflevels = self.get_amr_level(cellids)
else:
Expand All @@ -2951,12 +2971,10 @@ def get_cell_indices(self, cellids, reflevels=None):
cellindices[mask,1] = ((cellids[mask])//(np.power(2,reflevels[mask])*self.__xcells))%(np.power(2,reflevels[mask])*self.__ycells)
cellindices[mask,2] = (cellids[mask])//(np.power(4,reflevels[mask])*self.__xcells*self.__ycells)

# Return the indices:
if stack:
return np.array(cellindices)
else:
return np.array(cellindices)[0]
return cellindices


@wrap_array(dimensions=[1,2])
def get_cell_neighbor(self, cellidss, offsetss, periodic, prune_uniques=False):
''' Returns a given cells neighbor at offset (in indices)

Expand All @@ -2968,12 +2986,8 @@ def get_cell_neighbor(self, cellidss, offsetss, periodic, prune_uniques=False):
.. note:: Returns 0 if the offset is out of bounds!

'''
stack = True
if not hasattr(cellidss,"__len__"):
cellidss = np.atleast_1d(cellidss)
offsetss = np.atleast_2d(offsetss)
stack = False



if prune_uniques:
fullargs = np.array(np.hstack((cellidss[:,np.newaxis],offsetss)))
uniqueargs, inverse_indices = np.unique(fullargs,axis=0, return_inverse=True)
Expand All @@ -2984,7 +2998,6 @@ def get_cell_neighbor(self, cellidss, offsetss, periodic, prune_uniques=False):
offsets = offsetss
inverse_indices = np.indices((len(cellids),))


reflevel = self.get_amr_level(cellids)
indices = self.get_cell_indices(cellids, reflevel)

Expand Down Expand Up @@ -3020,10 +3033,7 @@ def get_cell_neighbor(self, cellidss, offsetss, periodic, prune_uniques=False):
# warnings.warn("A neighboring cell found at a different refinement level. Behaviour is janky, and results will vary.")

# Return the neighbor cellids/cellid:
if stack:
return np.array(cellid_neighbors[inverse_indices])
else:
return np.array(cellid_neighbors)[0]
return cellid_neighbors[inverse_indices]

def get_WID(self):
# default WID=4
Expand Down