Source code for test_general.test_calculation

from typing import Dict, Sequence

import hypothesis.strategies as st
import numpy as np
import pytest
from hypothesis import given
from hypothesis.extra.numpy import arrays

from stream import unpacked
from stream.calculation import Calculation, _concat
from stream.composition import Calculation_factory

# Calculation types with a single named variable, shared by the tests below.
# Each Calculation_factory call receives a function taking one positional and
# one keyword-only argument, the one-element list [False], and a dict that
# names a single variable and maps it to 0.
# NOTE(review): the list and the dict are presumably the per-entry flags and
# the name -> location ``variables`` mapping of the calculation; confirm
# against stream.composition.Calculation_factory.
Addition = Calculation_factory(lambda y, *, x: y + x, [False], dict(y=0))
Multiplication = Calculation_factory(lambda x, *, y: x * y, [False], dict(x=0))
Division = Calculation_factory(lambda z, *, x: z / x, [False], dict(z=0))

# One named instance of each type; ``add`` and ``multiply`` are used by the
# default save/load tests in this module.
add = Addition(name="Add")
multiply = Multiplication(name="Multiply")
divide = Division(name="Divide")


@given(st.lists(st.floats(allow_nan=False)))
def test_unpack_correctly_unpacks_data(lst):
    """Check that ``unpacked`` passes index-keyed dicts on as plain values.

    Two keyword inputs are built from the same list: one maps each index to
    the raw float, the other maps each index to that float wrapped in a
    NumPy array. Both values received by the wrapped function must be
    all-close to ``np.array(lst)``.
    """

    def give_me_values(*, some_input, more_input):
        return some_input, more_input

    scalars_by_index = dict(enumerate(lst))
    arrays_by_index = {index: np.array(value) for index, value in enumerate(lst)}
    # noinspection PyTypeChecker
    kwargs = {"some_input": scalars_by_index, "more_input": arrays_by_index}

    output, more_output = unpacked(give_me_values)(**kwargs)

    expected = np.array(lst)
    assert np.allclose(output, expected)
    assert np.allclose(more_output, expected)
def _give_me_values(*, some_input, more_input):
    """Return both keyword-only arguments unchanged as a 2-tuple.

    Used as the function wrapped by ``unpacked`` in the tests, so that what
    the wrapper actually passed in can be inspected.
    """
    received = (some_input, more_input)
    return received
@given(st.lists(st.floats(allow_nan=False)))
def test_unpack_correctly_excludes_parameters(lst):
    """Check that parameters listed in ``exclude`` reach the function untouched.

    The same index-keyed dict is passed for both parameters. The excluded
    one (``more_input``) must come back equal to the dict that went in, and
    the non-excluded one must equal the values of that dict.
    """
    input_dict = {index: value for index, value in enumerate(lst)}
    wrapped = unpacked(_give_me_values, exclude=["more_input"])

    output, more_output = wrapped(some_input=input_dict, more_input=input_dict)

    unpacked_values = np.atleast_1d(output)
    excluded_values = list(more_output.values())
    assert np.array_equal(unpacked_values, excluded_values)
    assert more_output == input_dict
def test_unpack_exclusion_errors_on_missing_variable_name():
    """Check that excluding a name the function lacks raises ``KeyError``.

    The error message must mention the offending name.
    """
    missing = "missing_variable_name"
    input_dict = dict(zip(range(5), range(1, 6)))
    kwargs = {"some_input": input_dict, "more_input": input_dict}

    # Wrapping and calling both happen inside the context: this test does not
    # care at which of the two steps the error surfaces.
    with pytest.raises(KeyError, match=missing):
        unpacked(_give_me_values, exclude=[missing])(**kwargs)
# Dictionaries with arbitrary integer keys whose values are either a 1-D float
# array of length 1..10 or a bare float; NaN is excluded in both cases.
dictvals = st.dictionaries(
    keys=st.integers(),
    values=st.one_of(
        arrays(
            dtype=float,
            shape=st.integers(min_value=1, max_value=10),
            elements=st.floats(allow_nan=False),
        ),
        st.floats(allow_nan=False),
    ),
)
@given(dictvals)
def test_concat_is_at_most_1d(d):
    """Check that ``_concat`` never returns more than one dimension."""
    concatenated = _concat(d)
    assert np.ndim(concatenated) <= 1
# Lists (possibly empty) of 1-D float arrays, each of length 1..10, without NaN.
list_arrays = st.lists(
    arrays(
        dtype=float,
        shape=st.integers(min_value=1, max_value=10),
        elements=st.floats(allow_nan=False),
    )
)
@given(list_arrays)
def test_concat_of_dictionaried_arrays_is_the_same_as_their_numpy_concat(lst):
    """Check ``_concat`` against ``np.concatenate`` on position-keyed arrays.

    The arrays are keyed by their position in the list. For an empty list,
    ``np.concatenate`` cannot be used as the reference, so the result is
    only required to have length zero.
    """
    d = dict(enumerate(lst))

    if not lst:
        assert len(_concat(d)) == 0
        return

    assert np.allclose(_concat(d), np.concatenate(lst))
@given(st.floats(allow_nan=False))
def test_default_save_has_correct_output_for_one_structure(val):
    """Check that the default ``save`` maps the single variable name to the value.

    ``add`` declares the variable ``y`` and ``multiply`` declares ``x``; the
    other keyword passed to ``save`` is given as ``None``.
    """
    saved_by_add = add.save([val], x=None)
    assert saved_by_add == {"y": val}

    saved_by_multiply = multiply.save([val], y=None)
    assert saved_by_multiply == {"x": val}
def _vardict(arr: np.ndarray) -> Dict[str, slice]:
    """Partition ``range(len(arr))`` into named, contiguous slices.

    Every True entry of ``arr`` at a position greater than 0 starts a new
    slice; position 0 always starts the first one. The slices are named with
    consecutive letters ``a``..``z``, ``A``..``Z``.

    Parameters
    ----------
    arr:
        One-dimensional boolean array marking where a new slice begins.

    Returns
    -------
    dict
        Letter -> ``slice`` mapping. The slices are non-empty, in ascending
        order and together cover ``0 .. len(arr)`` (for a non-empty ``arr``
        that yields at most 52 slices; ``zip`` silently drops any beyond
        that).
    """
    alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    # Position 0 is always the start of the first slice, so it must not also
    # be treated as an interior boundary. Otherwise it would appear as the
    # first *end* as well and every (start, end) pair would collapse to an
    # empty slice whenever arr[0] is True.
    boundaries = sorted(set(np.argwhere(arr).flatten()) - {0})
    starts = [0, *boundaries]
    ends = [*boundaries, len(arr)]
    return dict(zip(alphabet, map(slice, starts, ends)))


# One shared length per generated example, so the arrays drawn below from
# these strategies all have matching sizes.
arrlengths = st.shared(st.integers(1, 20), key="length")
boolarrs = arrays(bool, arrlengths, elements=st.booleans())
valarrs = arrays(float, arrlengths, elements=st.floats(0.0, 10.0, allow_nan=False))
# Variable layouts derived from a boolean array of the shared length.
vardicts = boolarrs.map(_vardict)
# A function that ignores its keyword inputs and returns zeros shaped like ``y``.
emptyfunc = st.just(lambda y, **_: np.zeros(y.shape))
# Calculation types built from the pieces above, and instances of those types
# created without arguments.
calctypes = st.builds(Calculation_factory, emptyfunc, boolarrs, vardicts)
calcs = calctypes.map(lambda x: x())
@given(calcs, valarrs)
def test_default_save_is_compatible_with_calc_variables(calc: Calculation, arr: np.ndarray):
    """Check that each saved entry equals the part of the vector it names.

    For every ``name -> location`` pair in ``calc.variables``, indexing the
    input array with the location must be all-close to the saved state's
    entry for that name.
    """
    state = calc.save(arr)
    for name, location in calc.variables.items():
        expected = arr[location]
        assert np.allclose(expected, state[name])
@given(st.floats(allow_nan=False))
def test_default_load_for_one_structure(val):
    """Check that the default ``load`` turns a one-variable state into ``[val]``."""
    loaded_by_add = add.load({"y": val})
    assert loaded_by_add == [val]

    loaded_by_multiply = multiply.load({"x": val})
    assert loaded_by_multiply == [val]
@given(calcs, valarrs)
def test_default_load_is_inverse_of_default_save(calc: Calculation, arr: np.ndarray):
    """Check that saving a vector and loading the result gives the vector back."""
    saved_state = calc.save(arr)
    round_tripped = calc.load(saved_state)
    assert np.allclose(round_tripped, arr)