diff --git a/notebook/scratch.ipynb b/notebook/scratch.ipynb index 10006e2..02ebec2 100644 --- a/notebook/scratch.ipynb +++ b/notebook/scratch.ipynb @@ -626,6 +626,34 @@ " output [balanced RANK from 5d[highest 3 of 4d6]]\n", "}" ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "class rand:\n", + " pass\n", + "class Var:\n", + " mytype = object()\n", + " def __init__(self, value):\n", + " self.value = value\n", + " def __eq__(self, other):\n", + " return self.value == other.value\n", + "Var(Var.mytype) == Var(Var.mytype)" + ] } ], "metadata": { diff --git a/src/__init__.py b/src/__init__.py index e66ccfe..81fc17f 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,8 +1,9 @@ -__version__ = '0.3.2' +__version__ = '0.3.2.dev0' # core classes from .randvar import RV from .seq import Seq +from .factory import get_seq # core functions from .roller import myrange @@ -29,7 +30,7 @@ __all__ = [ 'RV', 'Seq', 'anydice_casting', 'BlankRV', 'max_func_depth', 'output', 'roll', 'settings_set', 'myrange', - 'roller', 'settings_reset', 'StringSeq', + 'roller', 'settings_reset', 'StringSeq', 'get_seq', 'absolute_X', 'X_contains_X', 'count_X_in_X', 'explode_X', 'highest_X_of_X', 'lowest_X_of_X', 'middle_X_of_X', 'highest_of_X_and_X', 'lowest_of_X_and_X', 'maximum_of_X', 'reverse_X', 'sort_X', 'myMatmul', 'myLen', 'myInvert', 'myAnd', 'myOr' ] diff --git a/src/blackrv.py b/src/blackrv.py index 44f71eb..0736538 100644 --- a/src/blackrv.py +++ b/src/blackrv.py @@ -1,3 +1,4 @@ +# TODO rename file to blankrv.py from typing import Iterable from . import randvar as rv diff --git a/src/factory.py b/src/factory.py new file mode 100644 index 0000000..3806ec6 --- /dev/null +++ b/src/factory.py @@ -0,0 +1,24 @@ +from typing import Union + +from .typings import T_if, T_ifsr +from . import randvar +from . import seq +from . import utils +from . import blackrv +from . import string_rvs + + +T_ifsrt = Union[T_ifsr, str] + + +def get_seq(*source: T_ifsrt) -> 'seq.Seq': + # check if string in values, if so, return StringSeq + flat = tuple(utils.flatten(source)) + flat_rvs = [x for x in flat if isinstance(x, randvar.RV) and not isinstance(x, blackrv.BlankRV)] # expand RVs + flat_rv_vals = [v for rv in flat_rvs for v in rv.vals] + flat_else: list[T_if] = [x for x in flat if not isinstance(x, (randvar.RV, blackrv.BlankRV))] + res = tuple(flat_else + flat_rv_vals) + if any(isinstance(x, (str, string_rvs.StringVal)) for x in res): + return string_rvs.StringSeq(res) + assert all(isinstance(x, (int, float)) for x in res), 'Seq must be made of numbers and RVs. Seq:' + str(res) + return seq.Seq(_INTERNAL_SEQ_VALUE=res) diff --git a/src/parser/parse_and_exec.py b/src/parser/parse_and_exec.py index 1c0601e..045f13a 100644 --- a/src/parser/parse_and_exec.py +++ b/src/parser/parse_and_exec.py @@ -46,7 +46,8 @@ def _get_lib(): import random import functools from ..randvar import RV - from ..seq import Seq, get_seq + from ..seq import Seq + from ..factory import get_seq from ..settings import settings_set from ..decorators import anydice_casting, max_func_depth from ..output import output diff --git a/src/seq.py b/src/seq.py index 6b8f7c8..a463416 100644 --- a/src/seq.py +++ b/src/seq.py @@ -7,21 +7,8 @@ from . import utils from . import blackrv -T_ift = Union[T_if, str] - -def get_seq(*source: T_ifsr) -> 'Seq': - # check if string in values, if so, return StringSeq - flat = tuple(utils.flatten(source)) - flat_rvs = [x for x in flat if isinstance(x, randvar.RV) and not isinstance(x, blackrv.BlankRV)] # expand RVs - flat_rv_vals = [v for rv in flat_rvs for v in rv.vals] - flat_else: list[T_if] = [x for x in flat if not isinstance(x, (randvar.RV, blackrv.BlankRV))] - res = tuple(flat_else + flat_rv_vals) - if any(isinstance(x, str) for x in res): - from .string_rvs import StringSeq - return StringSeq(res) - assert all(isinstance(x, (int, float)) for x in flat_else), 'Seq must be made of numbers and RVs. Seq:' + str(flat_else) - return Seq(_INTERNAL_SEQ_VALUE=res) +T_ift = Union[T_if, str] class Seq(Iterable): diff --git a/src/string_rvs.py b/src/string_rvs.py index 444f985..48097b1 100644 --- a/src/string_rvs.py +++ b/src/string_rvs.py @@ -1,26 +1,40 @@ -from typing import Union +from typing import Union, Literal import operator as op from .typings import T_if -from .seq import Seq +from . import seq T_ift = Union[T_if, str, 'StringVal'] +_CONST_COEF = '_UNIQUE_STRING' + + class StringVal: - def __init__(self, keys: tuple[str, ...], pairs: dict[str, int]): - self.keys = keys + def __init__(self, keys: tuple[str, ...], pairs: dict[str, T_if]): + self.keys = tuple(sorted(keys)) self.data = pairs + @staticmethod + def from_const(const: T_if): + return StringVal((_CONST_COEF, ), {_CONST_COEF: const}) + + @staticmethod + def from_str(s: str): + return StringVal((s, ), {s: 1}) + + @staticmethod + def from_paris(pairs: dict[str, T_if]): # TODO rename to from_pairs + return StringVal(tuple(pairs.keys()), pairs) + def __add__(self, other): if not isinstance(other, StringVal): - other = StringVal(('', ), {'': other}) + other = StringVal.from_const(other) newdict = self.data.copy() for key, val in other.data.items(): newdict[key] = newdict.get(key, 0) + val - keys = tuple(sorted(newdict.keys())) - return StringVal(keys, newdict) + return StringVal.from_paris(newdict) def __radd__(self, other): return self.__add__(other) @@ -29,15 +43,15 @@ def __repr__(self): r = [] last_coeff = '' for key in self.keys: - if key == '': # empty string represents a number - last_coeff = '+' + str(self.data[key]) + if key == _CONST_COEF: # empty string represents a number + last_coeff = ' + ' + str(self.data[key]) continue elif self.data[key] == 1: # coefficient 1 is not shown n = key else: n = f'{self.data[key]}*{key}' r.append(n) - return '+'.join(r) + last_coeff + return ' + '.join(r) + last_coeff def __format__(self, format_spec): return f'{repr(self):{format_spec}}' @@ -84,11 +98,14 @@ def __hash__(self): return hash((self.keys, tuple(self.data))) -class StringSeq(Seq): +class StringSeq(seq.Seq): def __init__(self, source: tuple[T_ift, ...]): # do not call super().__init__ here source_lst: list[T_ift] = list(source) for i, x in enumerate(source_lst): if isinstance(x, str): - source_lst[i] = StringVal((x, ), {x: 1}) + source_lst[i] = StringVal.from_str(x) self._seq: tuple[T_ift, ...] = tuple(source_lst) + + def __iter__(self): + return iter(self._seq) diff --git a/test/stringvar_test.py b/test/stringvar_test.py new file mode 100644 index 0000000..b7e824c --- /dev/null +++ b/test/stringvar_test.py @@ -0,0 +1,94 @@ +from typing import Iterable +import pytest + +from dice_calc import settings_reset, get_seq, roll, RV +from dice_calc.string_rvs import StringVal + + +@pytest.fixture(autouse=True) +def settings_reset_fixture(): + settings_reset() + + +def test_init(): + a = StringVal(('a', 'b'), {'a': 1, 'b': 2}) + b = StringVal(('a', 'b'), {'a': 1, 'b': 2}) + c = StringVal(('b', 'a'), {'b': 2, 'a': 1}) + d = StringVal(('a', 'b', 'c'), {'a': 1, 'b': 2, 'c': 3}) + assert a == b + assert a == c + assert b == c + assert a != d + assert b != d + assert c != d + + +def test_format(): + a = StringVal(('a', 'b'), {'a': 1, 'b': 2}) + assert f'{a}' == 'a + 2*b' + + +def test_compare(): + a = StringVal(('a', 'b'), {'a': 1, 'b': 2}) + b = StringVal(('b', 'a'), {'b': 2, 'a': 1}) + c = StringVal(('b', 'a'), {'b': 1, 'a': 1}) + assert a == b + assert a <= b + assert a >= b + assert not (a < b) + assert not (a > b) + assert not (a != b) + + assert a > c + assert a >= c + assert not (a < c) + assert not (a <= c) + assert not (a == c) + assert a != c + + +def test_hash(): + a = StringVal(('a', 'b'), {'a': 1, 'b': 2}) + c = StringVal(('b', 'a'), {'b': 1, 'a': 1}) + d = StringVal(('a', 'b', 'c'), {'a': 1, 'b': 2, 'c': 3}) + assert hash(a) != hash(c) + assert hash(a) != hash(d) + assert hash(c) != hash(d) + assert hash(a) == hash(a) + assert hash(c) == hash(c) + assert hash(d) == hash(d) + + +def test_get_seq(): + a = get_seq('a', 'b') + b = get_seq('a', 'b') + c = get_seq(1, 2) + d = get_seq(1, 2) + e = get_seq(1, 2, 'a', 'b') + assert a == b + assert c == d + assert a != c + assert a != e + assert c != e + assert isinstance(a, Iterable), 'get_seq must return an iterable' + assert a == get_seq(a) + assert e == get_seq(e) + + +def test_roll(): + a = get_seq(roll(1, 2), 'fire', 'water') # {1d2, "water", "fire"} + b = roll(2, a) # 2dA + vp = [ + (2, 1), + (3, 2), + (4, 1), + ((StringVal.from_str('fire') + StringVal.from_const(1)), 2), + ((StringVal.from_str('fire') + StringVal.from_const(2)), 2), + ((StringVal.from_str('water') + StringVal.from_const(1)), 2), + ((StringVal.from_str('water') + StringVal.from_const(2)), 2), + ((StringVal.from_str('fire') + StringVal.from_str('water')), 2), + ((StringVal.from_str('water') + StringVal.from_str('water')), 1), + ((StringVal.from_str('fire') + StringVal.from_str('fire')), 1), + ] + # assert not sum(b.probs) + assert RV.dices_are_equal(b, RV(vals=[x[0] for x in vp], probs=[x[1] for x in vp]))