Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
- name: Lint
run: |
uv run ruff check
uv run mypy .
- name: Test
run: |
uv run pytest
2 changes: 1 addition & 1 deletion humanleague/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def tabulate_counts(population: npt.NDArray, names: list[str] | tuple[str, ...]
pd.Series: A pandas Series where the index is a MultiIndex created from the shape of the input array,
and the data corresponds to the flattened values of the input array.
"""
index = pd.MultiIndex.from_tuples(list(np.ndindex(population.shape)), names=names)
index = pd.MultiIndex.from_tuples(list(np.ndindex(population.shape)), names=names) # type: ignore[arg-type]
return pd.Series(index=index, data=list(np.nditer(population)), dtype=int, name="count")


Expand Down
9 changes: 5 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "humanleague"
version = "2.4.1"
version = "2.4.2"
authors = [
{ name="Andrew Smith", email="andrew@friarswood.net" },
]
Expand All @@ -32,12 +32,13 @@ dependencies = [

[dependency-groups]
dev = [
"pybind11>=2.10.3",
"pybind11>=3.0.0",
"pytest>=8.1.4",
"mypy>=1.5.0",
"mypy-extensions>=1.0.0",
"ruff>=0.0.286",
"build>=1.2.2.post1"
"ruff>=0.12.9",
"build>=1.2.2.post1",
"typing-extensions>=4.15.0",
]

[tool.pytest.ini_options]
Expand Down
12 changes: 8 additions & 4 deletions src/Integerise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@

namespace {

int64_t checked_round(double x, double tol=1e-4) // loose tolerance ~1/4 mantissa precision
// loose tolerance ~1/4 mantissa precision
constexpr double TOL = 1e-4;

int64_t checked_round(double x)
{
if (fabs(x - round(x)) > tol)
throw std::runtime_error("Marginal or total value %% is not an integer (within tolerance %%)"s % x % tol);
if (fabs(x - round(x)) > TOL)
throw std::runtime_error("Marginal or total value %% is not an integer (within tolerance %%)"s % x % TOL);
return (int64_t)round(x);
}

Expand Down Expand Up @@ -62,7 +65,8 @@ Integeriser::Integeriser(const NDArray<double>& seed) : m_seed(seed)
{
// convert to vector (reduce 1-d special case)
std::vector<double> p = reduce(seed, 0);
int pop = sum(seed);
// casting to int rounds down, so add TOL for better consistency with checked_round
int pop = sum(seed) + TOL;
// convert to probabilities
for (auto& x: p) x /= pop;
std::vector<int> tmp = integeriseMarginalDistribution(p, pop, m_rmse);
Expand Down
1 change: 1 addition & 0 deletions src/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ py::tuple integerise1d(py::array_t<double> frac_a, int pop) {
const std::vector<int>& freq = integeriseMarginalDistribution(prob, pop, var);

py::dict stats;
stats["conv"] = true; // always converges, but including for consistency
stats["rmse"] = var;

return py::make_tuple(py::array_t<int>(freq.size(), freq.data()), stats);
Expand Down
24 changes: 23 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,14 @@ def test_integerise() -> None:
# exact
r, stats = hl.integerise(np.array([0.4, 0.3, 0.2, 0.1]), 10)
assert stats["rmse"] < 1e-15
assert stats["conv"] # 1d case with specified total will always converge, but return for consistency
assert np.array_equal(r, np.array([4, 3, 2, 1]))

# inexact
r, stats = hl.integerise(np.array([0.4, 0.3, 0.2, 0.1]), 17)
assert stats["rmse"] == pytest.approx(0.273861278752583, abs=1e-6)

# without total we still get the same stats keys (not values)
assert stats.keys() == hl.integerise(np.array([0.4, 0.3, 0.2, 0.1]))[1].keys()
assert np.array_equal(r, np.array([7, 5, 3, 2]))

# 1-d case
Expand Down Expand Up @@ -113,6 +115,26 @@ def test_integerise() -> None:
assert np.sum(result) == sum(m0)
assert stats["rmse"] < 1.05717

# 1d integerise without providing total - check total rounds up if appropriate
a = np.array([1.1, 2.9, 0.9999])
result, stats = hl.integerise(a)
assert (result == np.array([1, 3, 1])).all()
assert stats["conv"]

# 1d integerise without providing total - check total doesn't round up if inappropriate
a[2] = 1.0001
result, stats = hl.integerise(a)
assert (result == np.array([1, 3, 1])).all()
assert stats["conv"]

# outside tolerance
a[2] = 1.0002
with pytest.raises(RuntimeError):
hl.integerise(a)
a[2] = 0.9998
with pytest.raises(RuntimeError):
hl.integerise(a)


def test_IPF() -> None:
m0 = np.array([52.0, 48.0])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_tabulate_counts_no_names() -> None:
population = np.array([[5, 6], [7, 8]])
result = hl.tabulate_counts(population)

expected_index = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0), (1, 1)], names=None)
expected_index = pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 0), (1, 1)], names=None) # type: ignore[arg-type]
expected_data = [5, 6, 7, 8]
expected = pd.Series(data=expected_data, index=expected_index, name="count")

Expand Down
Loading