52 lines
1.7 KiB
Python
52 lines
1.7 KiB
Python
import numpy as np
|
|
|
|
from sklearn.tree._utils import PytestWeightedFenwickTree
|
|
|
|
|
|
def test_cython_weighted_fenwick_tree(global_random_seed):
|
|
"""
|
|
Test Cython's weighted Fenwick tree implementation
|
|
"""
|
|
rng = np.random.default_rng(global_random_seed)
|
|
|
|
n = 100
|
|
indices = rng.permutation(n)
|
|
y = rng.normal(size=n)
|
|
w = rng.integers(0, 4, size=n)
|
|
y_included_so_far = np.zeros_like(y)
|
|
w_included_so_far = np.zeros_like(w)
|
|
|
|
tree = PytestWeightedFenwickTree(n)
|
|
tree.py_reset(n)
|
|
|
|
for i in range(n):
|
|
idx = indices[i]
|
|
tree.py_add(idx, y[idx], w[idx])
|
|
y_included_so_far[idx] = y[idx]
|
|
w_included_so_far[idx] = w[idx]
|
|
|
|
target = rng.uniform(0, w_included_so_far.sum())
|
|
t_idx_low, t_idx, cw, cwy = tree.py_search(target)
|
|
|
|
# check the aggregates are consistent with the returned idx
|
|
assert np.isclose(cw, np.sum(w_included_so_far[:t_idx]))
|
|
assert np.isclose(
|
|
cwy, np.sum(w_included_so_far[:t_idx] * y_included_so_far[:t_idx])
|
|
)
|
|
|
|
# check if the cumulative weight is less than or equal to the target
|
|
# depending on t_idx_low and t_idx
|
|
if t_idx_low == t_idx:
|
|
assert cw < target
|
|
else:
|
|
assert cw == target
|
|
|
|
# check that if we add the next non-null weight, we are above the target:
|
|
next_weights = w_included_so_far[t_idx:][w_included_so_far[t_idx:] > 0]
|
|
if next_weights.size > 0:
|
|
assert cw + next_weights[0] > target
|
|
# and not below the target for `t_idx_low`:
|
|
next_weights = w_included_so_far[t_idx_low:][w_included_so_far[t_idx_low:] > 0]
|
|
if next_weights.size > 0:
|
|
assert cw + next_weights[0] >= target
|