diff --git a/.github/workflows/readme_snippets.yml b/.github/workflows/readme_snippets.yml index 21524cb1b3..87b2db2257 100644 --- a/.github/workflows/readme_snippets.yml +++ b/.github/workflows/readme_snippets.yml @@ -34,6 +34,9 @@ jobs: python -We readme.py sed -i -e 's/CPU/GPU/g' readme.py python -We readme.py + sed -i -e 's/GPU/JAX/g' readme.py + sed -i '/pyplot\.legend()/,$d' readme.py + python -We readme.py - name: artefacts if: github.ref == 'refs/heads/main' && matrix.platform == 'ubuntu-latest' diff --git a/PySDM/backends/__init__.py b/PySDM/backends/__init__.py index 3ec1b6e492..226bf1a60e 100644 --- a/PySDM/backends/__init__.py +++ b/PySDM/backends/__init__.py @@ -11,11 +11,14 @@ from numba import cuda from . import numba as _numba +from . import jax as _jax # for pdoc CPU = None GPU = None +JAX = None Numba = _numba.Numba +Jax = _jax.Jax ThrustRTC = None @@ -93,3 +96,5 @@ def _cached_backend(formulae=None, backend_class=None, **kwargs): GPU = partial(_cached_backend, backend_class=ThrustRTC) """ returns a cached instance of the ThrustRTC backend (cache key including formulae parameters) """ + +JAX = partial(_cached_backend, backend_class=Jax) diff --git a/PySDM/backends/impl_jax/__init__.py b/PySDM/backends/impl_jax/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/PySDM/backends/impl_jax/methods/__init__.py b/PySDM/backends/impl_jax/methods/__init__.py new file mode 100644 index 0000000000..0e09071787 --- /dev/null +++ b/PySDM/backends/impl_jax/methods/__init__.py @@ -0,0 +1,7 @@ +"""method classes of the JAX backend""" + +from .collisions_methods import CollisionsMethods +from .index_methods import IndexMethods +from .moments_methods import MomentsMethods +from .pair_methods import PairMethods +from .physics_methods import PhysicsMethods diff --git a/PySDM/backends/impl_jax/methods/collisions_methods.py b/PySDM/backends/impl_jax/methods/collisions_methods.py new file mode 100644 index 0000000000..48c3a1078d --- /dev/null +++ b/PySDM/backends/impl_jax/methods/collisions_methods.py @@ -0,0 +1,249 @@ +""" +CPU implementation of backend methods for particle collisions +""" + +from functools import cached_property +import numba +import numpy as np + +from PySDM.backends.impl_common.backend_methods import BackendMethods +from PySDM.backends.impl_numba import conf +from PySDM.backends.impl_jax.storage import Storage + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def pair_indices(i, idx, is_first_in_pair, prob_like): + """given permutation array `idx` and `is_first_in_pair` flag array, + returns indices `j` and `k` of droplets within pair `i` and a `skip_pair` flag, + `j` points to the droplet that is first in pair (higher or equal multiplicity) + output is valid only if `2*i` or `2*i+1` points to a valid pair start index (within one cell) + otherwise the `skip_pair` flag is set to True and returned `j` & `k` indices are set to -1. + In addition, the `prob_like` array is checked for zeros at position `i`, in which case + the `skip_pair` is also set to `True` + """ + skip_pair = False + + if prob_like[i] == 0: + skip_pair = True + j, k = -1, -1 + else: + offset = 1 - is_first_in_pair[2 * i] + j = idx[2 * i + offset] + k = idx[2 * i + 1 + offset] + return j, k, skip_pair + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def flag_zero_multiplicity(j, k, multiplicity, healthy): + if multiplicity[k] == 0 or multiplicity[j] == 0: + healthy[0] = 0 + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def coalesce( # pylint: disable=too-many-arguments + i, j, k, cid, multiplicity, gamma, attributes, coalescence_rate +): + new_n = multiplicity[j] - gamma[i] * multiplicity[k] + if new_n > 0: + multiplicity[j] = new_n + for a in range(len(attributes)): + attributes[a, k] += gamma[i] * attributes[a, j] + else: # new_n == 0 + multiplicity[j] = multiplicity[k] // 2 + multiplicity[k] = multiplicity[k] - multiplicity[j] + for a in range(len(attributes)): + attributes[a, j] = gamma[i] * attributes[a, j] + attributes[a, k] + attributes[a, k] = attributes[a, j] + + +class CollisionsMethods(BackendMethods): + @cached_property + def _collision_coalescence_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + multiplicity, + idx, + length, + attributes, + gamma, + healthy, + cell_id, + coalescence_rate, + is_first_in_pair, + ): + for ( + i + ) in numba.prange( # pylint: disable=not-an-iterable,too-many-nested-blocks + length // 2 + ): + j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, gamma) + if skip_pair: + continue + coalesce( + i, + j, + k, + cell_id[j], + multiplicity, + gamma, + attributes, + coalescence_rate, + ) + flag_zero_multiplicity(j, k, multiplicity, healthy) + + return body + + def collision_coalescence( + self, + *, + multiplicity, + idx, + attributes, + gamma, + healthy, + cell_id, + coalescence_rate, + is_first_in_pair, + ): + self._collision_coalescence_body( + multiplicity=multiplicity.data, + idx=idx.data, + length=len(idx), + attributes=attributes.data, + gamma=gamma.data, + healthy=healthy.data, + cell_id=cell_id.data, + coalescence_rate=coalescence_rate.data, + is_first_in_pair=is_first_in_pair.indicator.data, + ) + + @cached_property + def _compute_gamma_body(self): + @numba.njit(**self.default_jit_flags) + # pylint: disable=too-many-arguments,too-many-locals + def body( + prob, + rand, + idx, + length, + multiplicity, + cell_id, + collision_rate_deficit, + collision_rate, + is_first_in_pair, + out, + ): + # TODO #1731 - shared docstring for all backends + for i in numba.prange(length // 2): # pylint: disable=not-an-iterable + out[i] = np.ceil(prob[i] - rand[i]) + j, k, skip_pair = pair_indices(i, idx, is_first_in_pair, out) + if skip_pair: + continue + prop = multiplicity[j] // multiplicity[k] + g = min(int(out[i]), prop) + out[i] = g + + return body + + def compute_gamma( + self, + *, + prob, + rand, + multiplicity, + cell_id, + collision_rate_deficit, + collision_rate, + is_first_in_pair, + out, + ): + return self._compute_gamma_body( + prob.data, + rand.data, + multiplicity.idx.data, + len(multiplicity), + multiplicity.data, + cell_id.data, + collision_rate_deficit.data, + collision_rate.data, + is_first_in_pair.indicator.data, + out.data, + ) + + @staticmethod + def make_cell_caretaker(idx_shape, idx_dtype, cell_start_len, scheme="default"): + class CellCaretaker: # pylint: disable=too-few-public-methods + def __init__(self, idx_shape, idx_dtype, cell_start_len, scheme): + if scheme == "default": + if conf.JIT_FLAGS["parallel"]: + scheme = "counting_sort_parallel" + else: + scheme = "counting_sort" + self.scheme = scheme + if scheme in ("counting_sort", "counting_sort_parallel"): + self.tmp_idx = Storage.empty(idx_shape, idx_dtype) + + def __call__(self, cell_id, cell_idx, cell_start, idx): + length = len(idx) + if self.scheme == "counting_sort": + CollisionsMethods._counting_sort_by_cell_id_and_update_cell_start( + self.tmp_idx.data, + idx.data, + cell_id.data, + cell_idx.data, + length, + cell_start.data, + ) + idx.data, self.tmp_idx.data = self.tmp_idx.data, idx.data + + return CellCaretaker(idx_shape, idx_dtype, cell_start_len, scheme) + + @cached_property + def _normalize_body(self): + @numba.njit(**{**self.default_jit_flags, **{"parallel": False}}) + # pylint: disable=too-many-arguments + def body(prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv): + n_cell = cell_start.shape[0] - 1 + for i in range(n_cell): + sd_num = cell_start[i + 1] - cell_start[i] + if sd_num < 2: + norm_factor[i] = 0 + else: + norm_factor[i] = ( + timestep / dv * sd_num * (sd_num - 1) / 2 / (sd_num // 2) + ) + for d in numba.prange(prob.shape[0]): # pylint: disable=not-an-iterable + prob[d] *= norm_factor[cell_idx[cell_id[d]]] + + return body + + # pylint: disable=too-many-arguments + def normalize(self, prob, cell_id, cell_idx, cell_start, norm_factor, timestep, dv): + return self._normalize_body( + prob.data, + cell_id.data, + cell_idx.data, + cell_start.data, + norm_factor.data, + timestep, + dv, + ) + + + @staticmethod + @numba.njit(**conf.JIT_FLAGS) + # pylint: disable=too-many-arguments + def _counting_sort_by_cell_id_and_update_cell_start( + new_idx, idx, cell_id, cell_idx, length, cell_start + ): + cell_end = cell_start + # Warning: Assuming len(cell_end) == n_cell+1 + cell_end[:] = 0 + for i in range(length): + cell_end[cell_idx[cell_id[idx[i]]]] += 1 + for i in range(1, len(cell_end)): + cell_end[i] += cell_end[i - 1] + for i in range(length - 1, -1, -1): + cell_end[cell_idx[cell_id[idx[i]]]] -= 1 + new_idx[cell_end[cell_idx[cell_id[idx[i]]]]] = idx[i] diff --git a/PySDM/backends/impl_jax/methods/index_methods.py b/PySDM/backends/impl_jax/methods/index_methods.py new file mode 100644 index 0000000000..2cbabfa458 --- /dev/null +++ b/PySDM/backends/impl_jax/methods/index_methods.py @@ -0,0 +1,26 @@ +""" +CPU implementation of shuffling and sorting backend methods +""" + +from functools import cached_property + +import numba + +from PySDM.backends.impl_common.backend_methods import BackendMethods + + +class IndexMethods(BackendMethods): + + @cached_property + def shuffle_local(self): + @numba.njit(**self.default_jit_flags) + def body(idx, u01, cell_start): + # pylint: disable=not-an-iterable + for c in numba.prange(len(cell_start) - 1): + for i in range(cell_start[c + 1] - 1, cell_start[c], -1): + j = int( + cell_start[c] + u01[i] * (cell_start[c + 1] - cell_start[c]) + ) + idx[i], idx[j] = idx[j], idx[i] + + return body diff --git a/PySDM/backends/impl_jax/methods/moments_methods.py b/PySDM/backends/impl_jax/methods/moments_methods.py new file mode 100644 index 0000000000..29cfba19ac --- /dev/null +++ b/PySDM/backends/impl_jax/methods/moments_methods.py @@ -0,0 +1,182 @@ +""" +CPU implementation of moment calculation backend methods +""" + +from functools import cached_property + +import numba + +from PySDM.backends.impl_common.backend_methods import BackendMethods +from PySDM.backends.impl_numba.atomic_operations import atomic_add + + +class MomentsMethods(BackendMethods): + @cached_property + def _moments_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + ranks, + min_x, + max_x, + x_attr, + weighting_attribute, + weighting_rank, + skip_division_by_m0, + ): + # pylint: disable=too-many-locals + moment_0[:] = 0 + moments[:, :] = 0 + for idx_i in numba.prange(length): # pylint: disable=not-an-iterable + i = idx[idx_i] + if min_x <= x_attr[i] < max_x: + atomic_add( + moment_0, + cell_id[i], + multiplicity[i] * weighting_attribute[i] ** weighting_rank, + ) + for k in range(ranks.shape[0]): + atomic_add( + moments, + (k, cell_id[i]), + ( + multiplicity[i] + * weighting_attribute[i] ** weighting_rank + * attr_data[i] ** ranks[k] + ), + ) + if not skip_division_by_m0: + for c_id in range(moment_0.shape[0]): + for k in range(ranks.shape[0]): + moments[k, c_id] = ( + moments[k, c_id] / moment_0[c_id] + if moment_0[c_id] != 0 + else 0 + ) + + return body + + def moments( + self, + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + ranks, + min_x, + max_x, + x_attr, + weighting_attribute, + weighting_rank, + skip_division_by_m0, + ): + return self._moments_body( + moment_0=moment_0.data, + moments=moments.data, + multiplicity=multiplicity.data, + attr_data=attr_data.data, + cell_id=cell_id.data, + idx=idx.data, + length=length, + ranks=ranks.data, + min_x=min_x, + max_x=max_x, + x_attr=x_attr.data, + weighting_attribute=weighting_attribute.data, + weighting_rank=weighting_rank, + skip_division_by_m0=skip_division_by_m0, + ) + + @cached_property + def _spectrum_moments_body(self): + @numba.njit(**self.default_jit_flags) + def body( + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + rank, + x_bins, + x_attr, + weighting_attribute, + weighting_rank, + ): + # pylint: disable=too-many-locals + moment_0[:, :] = 0 + moments[:, :] = 0 + for idx_i in numba.prange(length): # pylint: disable=not-an-iterable + i = idx[idx_i] + for k in range(x_bins.shape[0] - 1): + if x_bins[k] <= x_attr[i] < x_bins[k + 1]: + atomic_add( + moment_0, + (k, cell_id[i]), + multiplicity[i] * weighting_attribute[i] ** weighting_rank, + ) + atomic_add( + moments, + (k, cell_id[i]), + ( + multiplicity[i] + * weighting_attribute[i] ** weighting_rank + * attr_data[i] ** rank + ), + ) + break + for c_id in range(moment_0.shape[1]): + for k in range(x_bins.shape[0] - 1): + moments[k, c_id] = ( + moments[k, c_id] / moment_0[k, c_id] + if moment_0[k, c_id] != 0 + else 0 + ) + + return body + + def spectrum_moments( + self, + *, + moment_0, + moments, + multiplicity, + attr_data, + cell_id, + idx, + length, + rank, + x_bins, + x_attr, + weighting_attribute, + weighting_rank, + ): + assert moments.shape[0] == x_bins.shape[0] - 1 + assert moment_0.shape == moments.shape + return self._spectrum_moments_body( + moment_0=moment_0.data, + moments=moments.data, + multiplicity=multiplicity.data, + attr_data=attr_data.data, + cell_id=cell_id.data, + idx=idx.data, + length=length, + rank=rank, + x_bins=x_bins.data, + x_attr=x_attr.data, + weighting_attribute=weighting_attribute.data, + weighting_rank=weighting_rank, + ) diff --git a/PySDM/backends/impl_jax/methods/pair_methods.py b/PySDM/backends/impl_jax/methods/pair_methods.py new file mode 100644 index 0000000000..cc187804cd --- /dev/null +++ b/PySDM/backends/impl_jax/methods/pair_methods.py @@ -0,0 +1,91 @@ +""" +CPU implementation of pairwise operations backend methods +""" + +from functools import cached_property + +import numba + +from PySDM.backends.impl_common.backend_methods import BackendMethods + + +class PairMethods(BackendMethods): + + @cached_property + def _find_pairs_body(self): + @numba.njit(**self.default_jit_flags) + def body(*, cell_start, is_first_in_pair, cell_id, cell_idx, idx, length): + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + is_in_same_cell = cell_id[idx[i]] == cell_id[idx[i + 1]] + is_even_index = (i - cell_start[cell_idx[cell_id[idx[i]]]]) % 2 == 0 + is_first_in_pair[i] = is_in_same_cell and is_even_index + is_first_in_pair[length - 1] = False + + return body + + # pylint: disable=too-many-arguments + def find_pairs(self, cell_start, is_first_in_pair, cell_id, cell_idx, idx): + return self._find_pairs_body( + cell_start=cell_start.data, + is_first_in_pair=is_first_in_pair.indicator.data, + cell_id=cell_id.data, + cell_idx=cell_idx.data, + idx=idx.data, + length=len(idx), + ) + + @cached_property + def _max_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = max(data_in[idx[i]], data_in[idx[i + 1]]) + + return body + + def max_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._max_pair_body( + data_out.data, + data_in.data, + is_first_in_pair.indicator.data, + idx.data, + len(idx), + ) + + @cached_property + def _sort_within_pair_by_attr_body(self): + @numba.njit(**self.default_jit_flags) + def body(idx, length, is_first_in_pair, attr): + for i in numba.prange(length - 1): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + if attr[idx[i]] < attr[idx[i + 1]]: + idx[i], idx[i + 1] = idx[i + 1], idx[i] + + return body + + def sort_within_pair_by_attr(self, idx, is_first_in_pair, attr): + self._sort_within_pair_by_attr_body( + idx.data, len(idx), is_first_in_pair.indicator.data, attr.data + ) + + @cached_property + def _sum_pair_body(self): + @numba.njit(**self.default_jit_flags) + def body(data_out, data_in, is_first_in_pair, idx, length): + data_out[:] = 0 + for i in numba.prange(length): # pylint: disable=not-an-iterable + if is_first_in_pair[i]: + data_out[i // 2] = data_in[idx[i]] + data_in[idx[i + 1]] + + return body + + def sum_pair(self, data_out, data_in, is_first_in_pair, idx): + return self._sum_pair_body( + data_out.data, + data_in.data, + is_first_in_pair.indicator.data, + idx.data, + len(idx), + ) diff --git a/PySDM/backends/impl_jax/methods/physics_methods.py b/PySDM/backends/impl_jax/methods/physics_methods.py new file mode 100644 index 0000000000..de80965434 --- /dev/null +++ b/PySDM/backends/impl_jax/methods/physics_methods.py @@ -0,0 +1,29 @@ +""" +CPU implementation of backend methods wrapping basic physics formulae +""" + +from functools import cached_property + +import numba +from numba import prange + +from PySDM.backends.impl_common.backend_methods import BackendMethods + + +class PhysicsMethods(BackendMethods): + def __init__(self): + BackendMethods.__init__(self) + + @cached_property + def _volume_of_mass_body(self): + ff = self.formulae_flattened + + @numba.njit(**self.default_jit_flags) + def body(volume, mass): + for i in prange(volume.shape[0]): # pylint: disable=not-an-iterable + volume[i] = ff.particle_shape_and_density__mass_to_volume(mass[i]) + + return body + + def volume_of_water_mass(self, volume, mass): + self._volume_of_mass_body(volume.data, mass.data) diff --git a/PySDM/backends/impl_jax/storage.py b/PySDM/backends/impl_jax/storage.py new file mode 100644 index 0000000000..340a41ca6b --- /dev/null +++ b/PySDM/backends/impl_jax/storage.py @@ -0,0 +1,93 @@ +""" +CPU Numpy-based implementation of Storage class +""" + +import numpy as np + +from PySDM.backends.impl_common.storage_utils import ( + StorageBase, + StorageSignature, + empty, + get_data_from_ndarray, +) +from PySDM.backends.impl_jax import storage_impl as impl + + +class Storage(StorageBase): + FLOAT = np.float64 + INT = np.int64 + BOOL = np.bool_ + + def row_view(self, i): + return Storage( + StorageSignature(self.data[i], (*self.shape[1:],), self.dtype) + ) + + def at(self, index): + assert self.shape == (1,), "Cannot call at() on Storage of shape other than (1,)" + return self.data[index] + + def __imul__(self, other): + if hasattr(other, "data"): + impl.multiply(self.data, other.data) + else: + impl.multiply(self.data, other) + return self + + def __itruediv__(self, other): + if hasattr(other, "data"): + self.data[:] /= other.data[:] + else: + self.data[:] /= other + return self + + def download(self, target, reshape=False): + if reshape: + data = self.data.reshape(target.shape) + else: + data = self.data + np.copyto(target, data, casting="safe") + + @staticmethod + def _get_empty_data(shape, dtype): + if dtype in (float, Storage.FLOAT): + data = np.full(shape, np.nan, dtype=Storage.FLOAT) + dtype = Storage.FLOAT + elif dtype in (int, Storage.INT): + data = np.full(shape, -1, dtype=Storage.INT) + dtype = Storage.INT + elif dtype in (bool, Storage.BOOL): + data = np.full(shape, -1, dtype=Storage.BOOL) + dtype = Storage.BOOL + else: + raise NotImplementedError() + + return StorageSignature(data, shape, dtype) + + @staticmethod + def empty(shape, dtype): + return empty(shape, dtype, Storage) + + @staticmethod + def _get_data_from_ndarray(array): + return get_data_from_ndarray( + array=array, + storage_class=Storage, + copy_fun=lambda array_astype: array_astype.copy(), + ) + + @staticmethod + def from_ndarray(array): + result = Storage(Storage._get_data_from_ndarray(array)) + return result + + def urand(self, generator): + generator(self) + + def upload(self, data): + np.copyto(self.data, data, casting="safe") + def fill(self, other): + if isinstance(other, Storage): + self.data[:] = other.data + else: + self.data[:] = other diff --git a/PySDM/backends/impl_jax/storage_impl.py b/PySDM/backends/impl_jax/storage_impl.py new file mode 100644 index 0000000000..fe9d492706 --- /dev/null +++ b/PySDM/backends/impl_jax/storage_impl.py @@ -0,0 +1,11 @@ +""" +Numba njit-ted basic arithmetics routines for CPU backend +""" + +import numba +from PySDM.backends.impl_numba import conf + + +@numba.njit(**{**conf.JIT_FLAGS, **{"parallel": False}}) +def multiply(output, multiplier): + output *= multiplier diff --git a/PySDM/backends/jax.py b/PySDM/backends/jax.py new file mode 100644 index 0000000000..22f25dbea1 --- /dev/null +++ b/PySDM/backends/jax.py @@ -0,0 +1,46 @@ +""" +Multi-threaded CPU backend using LLVM-powered just-in-time compilation +""" + +from PySDM.backends.impl_jax import methods +from PySDM.backends.impl_numba.random import Random as ImportedRandom +from PySDM.backends.impl_jax.storage import Storage as ImportedStorage +from PySDM.formulae import Formulae + + +class Jax( + methods.CollisionsMethods, + methods.PairMethods, + methods.IndexMethods, + methods.PhysicsMethods, + methods.MomentsMethods, +): + Storage = ImportedStorage + Random = ImportedRandom + + default_croupier = "local" + + def __init__( + self, formulae=None, *, double_precision=True, override_jit_flags=None + ): + if not double_precision: + raise NotImplementedError() + + self.formulae = formulae or Formulae() + self.formulae_flattened = self.formulae.flatten + + # assert "fastmath" not in (override_jit_flags or {}) + # self.default_jit_flags = { + # **JIT_FLAGS, # here parallel=False (for out-of-backend code) + # **{"fastmath": self.formulae.fastmath, "parallel": parallel_default}, + # **(override_jit_flags or {}), + # } + self.default_jit_flags = { + "parallel": False + } + + methods.CollisionsMethods.__init__(self) + methods.PairMethods.__init__(self) + methods.IndexMethods.__init__(self) + methods.PhysicsMethods.__init__(self) + methods.MomentsMethods.__init__(self) diff --git a/PySDM/dynamics/collisions/collision.py b/PySDM/dynamics/collisions/collision.py index 18b24d8637..97af236a37 100644 --- a/PySDM/dynamics/collisions/collision.py +++ b/PySDM/dynamics/collisions/collision.py @@ -131,9 +131,9 @@ def register(self, builder): self.stats_n_substep = self.particulator.Storage.empty( self.particulator.mesh.n_cell, dtype=int ) - self.stats_n_substep[:] = 0 if self.adaptive else self.__substeps + self.stats_n_substep.fill(0 if self.adaptive else self.__substeps) self.stats_dt_min = self.particulator.Storage.empty(**empty_args_cellwise) - self.stats_dt_min[:] = np.nan + self.stats_dt_min.fill(np.nan) self.rnd_opt_coll.register(builder) self.collision_kernel.register(builder) diff --git a/PySDM/dynamics/impl/random_generator_optimizer.py b/PySDM/dynamics/impl/random_generator_optimizer.py index 4e2bb54094..793897d219 100644 --- a/PySDM/dynamics/impl/random_generator_optimizer.py +++ b/PySDM/dynamics/impl/random_generator_optimizer.py @@ -45,4 +45,7 @@ def get_random_arrays(self): self.pairs_rand.urand(self.rnd) self.rand.urand(self.rnd) self.substep += 1 - return self.pairs_rand[shift : self.particulator.n_sd + shift], self.rand + if self.optimized_random: + return self.pairs_rand[shift : self.particulator.n_sd + shift], self.rand + else: + return self.pairs_rand, self.rand diff --git a/PySDM/impl/particle_attributes.py b/PySDM/impl/particle_attributes.py index 5e486db29f..d91038ffd3 100644 --- a/PySDM/impl/particle_attributes.py +++ b/PySDM/impl/particle_attributes.py @@ -41,7 +41,7 @@ def __init__( @property def healthy(self) -> bool: - return bool(self.__healthy_memory[0]) + return bool(self.__healthy_memory.at(0)) @healthy.setter def healthy(self, value: bool): diff --git a/PySDM/impl/particle_attributes_factory.py b/PySDM/impl/particle_attributes_factory.py index f8f1b4546b..e5dad088c2 100644 --- a/PySDM/impl/particle_attributes_factory.py +++ b/PySDM/impl/particle_attributes_factory.py @@ -61,7 +61,7 @@ def attributes(particulator, req_attr, attributes): def helper(req_attr, all_attr, names, data, keys): for i, attr in enumerate(names): keys[attr] = i - req_attr[attr].set_data(data[i, :]) + req_attr[attr].set_data(data.row_view(i)) try: req_attr[attr].init(all_attr[attr]) except KeyError as err: diff --git a/PySDM/products/impl/spectrum_moment_product.py b/PySDM/products/impl/spectrum_moment_product.py index b02fe781d5..3762a9a507 100644 --- a/PySDM/products/impl/spectrum_moment_product.py +++ b/PySDM/products/impl/spectrum_moment_product.py @@ -48,6 +48,6 @@ def _recalculate_spectrum_moment( def _download_spectrum_moment_to_buffer(self, rank, bin_number): if rank == 0: # TODO #217 - self._download_to_buffer(self.moment_0[bin_number, :]) + self._download_to_buffer(self.moment_0.row_view(bin_number)) else: - self._download_to_buffer(self.moments[bin_number, :]) + self._download_to_buffer(self.moments.row_view(bin_number)) diff --git a/docs/markdown/pysdm_landing.md b/docs/markdown/pysdm_landing.md index d373d487f7..ef5fa29c28 100644 --- a/docs/markdown/pysdm_landing.md +++ b/docs/markdown/pysdm_landing.md @@ -157,7 +157,7 @@ radius_bins_edges = 10 .^ range(log10(10*si.um), log10(5e3*si.um), length=32) env = Box(dt=1 * si.s, dv=1e6 * si.m^3) builder = Builder(n_sd=n_sd, backend=CPU(), environment=env) -builder.add_dynamic(Coalescence(collision_kernel=Golovin(b=1.5e3 / si.s))) +builder.add_dynamic(Coalescence(collision_kernel=Golovin(b=1.5e3 / si.s), adaptive=false)) products = [ParticleVolumeVersusRadiusLogarithmSpectrum(radius_bins_edges=radius_bins_edges, name="dv/dlnr")] particulator = builder.build(attributes, products) ``` @@ -177,7 +177,7 @@ radius_bins_edges = logspace(log10(10 * si.um), log10(5e3 * si.um), 32); env = Box(pyargs('dt', 1 * si.s, 'dv', 1e6 * si.m ^ 3)); builder = Builder(pyargs('n_sd', int32(n_sd), 'backend', CPU(), 'environment', env)); -builder.add_dynamic(Coalescence(pyargs('collision_kernel', Golovin(1.5e3 / si.s)))); +builder.add_dynamic(Coalescence(pyargs('collision_kernel', Golovin(1.5e3 / si.s), 'adaptive', false))); products = py.list({ ParticleVolumeVersusRadiusLogarithmSpectrum(pyargs( ... 'radius_bins_edges', py.numpy.array(radius_bins_edges), ... 'name', 'dv/dlnr' ... @@ -201,7 +201,7 @@ radius_bins_edges = np.logspace(np.log10(10 * si.um), np.log10(5e3 * si.um), num env = Box(dt=1 * si.s, dv=1e6 * si.m ** 3) builder = Builder(n_sd=n_sd, backend=CPU(), environment=env) -builder.add_dynamic(Coalescence(collision_kernel=Golovin(b=1.5e3 / si.s))) +builder.add_dynamic(Coalescence(collision_kernel=Golovin(b=1.5e3 / si.s), adaptive=False)) products = [ParticleVolumeVersusRadiusLogarithmSpectrum(radius_bins_edges=radius_bins_edges, name='dv/dlnr')] particulator = builder.build(attributes, products) ```