Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions notebook/scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,34 @@
" output [balanced RANK from 5d[highest 3 of 4d6]]\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class rand:\n",
" pass\n",
"class Var:\n",
" mytype = object()\n",
" def __init__(self, value):\n",
" self.value = value\n",
" def __eq__(self, other):\n",
" return self.value == other.value\n",
"Var(Var.mytype) == Var(Var.mytype)"
]
}
],
"metadata": {
Expand Down
5 changes: 3 additions & 2 deletions src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
__version__ = '0.3.2'
__version__ = '0.3.2.dev0'

# core classes
from .randvar import RV
from .seq import Seq
from .factory import get_seq

# core functions
from .roller import myrange
Expand All @@ -29,7 +30,7 @@

__all__ = [
'RV', 'Seq', 'anydice_casting', 'BlankRV', 'max_func_depth', 'output', 'roll', 'settings_set', 'myrange',
'roller', 'settings_reset', 'StringSeq',
'roller', 'settings_reset', 'StringSeq', 'get_seq',
'absolute_X', 'X_contains_X', 'count_X_in_X', 'explode_X', 'highest_X_of_X', 'lowest_X_of_X', 'middle_X_of_X', 'highest_of_X_and_X', 'lowest_of_X_and_X', 'maximum_of_X', 'reverse_X', 'sort_X',
'myMatmul', 'myLen', 'myInvert', 'myAnd', 'myOr'
]
1 change: 1 addition & 0 deletions src/blackrv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# TODO rename file to blankrv.py
from typing import Iterable

from . import randvar as rv
Expand Down
24 changes: 24 additions & 0 deletions src/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Union

from .typings import T_if, T_ifsr
from . import randvar
from . import seq
from . import utils
from . import blackrv
from . import string_rvs


T_ifsrt = Union[T_ifsr, str]


def get_seq(*source: T_ifsrt) -> 'seq.Seq':
# check if string in values, if so, return StringSeq
flat = tuple(utils.flatten(source))
flat_rvs = [x for x in flat if isinstance(x, randvar.RV) and not isinstance(x, blackrv.BlankRV)] # expand RVs
flat_rv_vals = [v for rv in flat_rvs for v in rv.vals]
flat_else: list[T_if] = [x for x in flat if not isinstance(x, (randvar.RV, blackrv.BlankRV))]
res = tuple(flat_else + flat_rv_vals)
if any(isinstance(x, (str, string_rvs.StringVal)) for x in res):
return string_rvs.StringSeq(res)
assert all(isinstance(x, (int, float)) for x in res), 'Seq must be made of numbers and RVs. Seq:' + str(res)
return seq.Seq(_INTERNAL_SEQ_VALUE=res)
3 changes: 2 additions & 1 deletion src/parser/parse_and_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _get_lib():
import random
import functools
from ..randvar import RV
from ..seq import Seq, get_seq
from ..seq import Seq
from ..factory import get_seq
from ..settings import settings_set
from ..decorators import anydice_casting, max_func_depth
from ..output import output
Expand Down
15 changes: 1 addition & 14 deletions src/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,8 @@
from . import utils
from . import blackrv

T_ift = Union[T_if, str]


def get_seq(*source: T_ifsr) -> 'Seq':
# check if string in values, if so, return StringSeq
flat = tuple(utils.flatten(source))
flat_rvs = [x for x in flat if isinstance(x, randvar.RV) and not isinstance(x, blackrv.BlankRV)] # expand RVs
flat_rv_vals = [v for rv in flat_rvs for v in rv.vals]
flat_else: list[T_if] = [x for x in flat if not isinstance(x, (randvar.RV, blackrv.BlankRV))]
res = tuple(flat_else + flat_rv_vals)
if any(isinstance(x, str) for x in res):
from .string_rvs import StringSeq
return StringSeq(res)
assert all(isinstance(x, (int, float)) for x in flat_else), 'Seq must be made of numbers and RVs. Seq:' + str(flat_else)
return Seq(_INTERNAL_SEQ_VALUE=res)
T_ift = Union[T_if, str]


class Seq(Iterable):
Expand Down
41 changes: 29 additions & 12 deletions src/string_rvs.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,40 @@

from typing import Union
from typing import Union, Literal
import operator as op

from .typings import T_if
from .seq import Seq
from . import seq

T_ift = Union[T_if, str, 'StringVal']


_CONST_COEF = '_UNIQUE_STRING'


class StringVal:
def __init__(self, keys: tuple[str, ...], pairs: dict[str, int]):
self.keys = keys
def __init__(self, keys: tuple[str, ...], pairs: dict[str, T_if]):
self.keys = tuple(sorted(keys))
self.data = pairs

@staticmethod
def from_const(const: T_if):
return StringVal((_CONST_COEF, ), {_CONST_COEF: const})

@staticmethod
def from_str(s: str):
return StringVal((s, ), {s: 1})

@staticmethod
def from_paris(pairs: dict[str, T_if]): # TODO rename to from_pairs
return StringVal(tuple(pairs.keys()), pairs)

def __add__(self, other):
if not isinstance(other, StringVal):
other = StringVal(('', ), {'': other})
other = StringVal.from_const(other)
newdict = self.data.copy()
for key, val in other.data.items():
newdict[key] = newdict.get(key, 0) + val
keys = tuple(sorted(newdict.keys()))
return StringVal(keys, newdict)
return StringVal.from_paris(newdict)

def __radd__(self, other):
return self.__add__(other)
Expand All @@ -29,15 +43,15 @@ def __repr__(self):
r = []
last_coeff = ''
for key in self.keys:
if key == '': # empty string represents a number
last_coeff = '+' + str(self.data[key])
if key == _CONST_COEF: # empty string represents a number
last_coeff = ' + ' + str(self.data[key])
continue
elif self.data[key] == 1: # coefficient 1 is not shown
n = key
else:
n = f'{self.data[key]}*{key}'
r.append(n)
return '+'.join(r) + last_coeff
return ' + '.join(r) + last_coeff

def __format__(self, format_spec):
return f'{repr(self):{format_spec}}'
Expand Down Expand Up @@ -84,11 +98,14 @@ def __hash__(self):
return hash((self.keys, tuple(self.data)))


class StringSeq(Seq):
class StringSeq(seq.Seq):
def __init__(self, source: tuple[T_ift, ...]):
# do not call super().__init__ here
source_lst: list[T_ift] = list(source)
for i, x in enumerate(source_lst):
if isinstance(x, str):
source_lst[i] = StringVal((x, ), {x: 1})
source_lst[i] = StringVal.from_str(x)
self._seq: tuple[T_ift, ...] = tuple(source_lst)

def __iter__(self):
return iter(self._seq)
94 changes: 94 additions & 0 deletions test/stringvar_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Iterable
import pytest

from dice_calc import settings_reset, get_seq, roll, RV
from dice_calc.string_rvs import StringVal


@pytest.fixture(autouse=True)
def settings_reset_fixture():
settings_reset()


def test_init():
a = StringVal(('a', 'b'), {'a': 1, 'b': 2})
b = StringVal(('a', 'b'), {'a': 1, 'b': 2})
c = StringVal(('b', 'a'), {'b': 2, 'a': 1})
d = StringVal(('a', 'b', 'c'), {'a': 1, 'b': 2, 'c': 3})
assert a == b
assert a == c
assert b == c
assert a != d
assert b != d
assert c != d


def test_format():
a = StringVal(('a', 'b'), {'a': 1, 'b': 2})
assert f'{a}' == 'a + 2*b'


def test_compare():
a = StringVal(('a', 'b'), {'a': 1, 'b': 2})
b = StringVal(('b', 'a'), {'b': 2, 'a': 1})
c = StringVal(('b', 'a'), {'b': 1, 'a': 1})
assert a == b
assert a <= b
assert a >= b
assert not (a < b)
assert not (a > b)
assert not (a != b)

assert a > c
assert a >= c
assert not (a < c)
assert not (a <= c)
assert not (a == c)
assert a != c


def test_hash():
a = StringVal(('a', 'b'), {'a': 1, 'b': 2})
c = StringVal(('b', 'a'), {'b': 1, 'a': 1})
d = StringVal(('a', 'b', 'c'), {'a': 1, 'b': 2, 'c': 3})
assert hash(a) != hash(c)
assert hash(a) != hash(d)
assert hash(c) != hash(d)
assert hash(a) == hash(a)
assert hash(c) == hash(c)
assert hash(d) == hash(d)


def test_get_seq():
a = get_seq('a', 'b')
b = get_seq('a', 'b')
c = get_seq(1, 2)
d = get_seq(1, 2)
e = get_seq(1, 2, 'a', 'b')
assert a == b
assert c == d
assert a != c
assert a != e
assert c != e
assert isinstance(a, Iterable), 'get_seq must return an iterable'
assert a == get_seq(a)
assert e == get_seq(e)


def test_roll():
a = get_seq(roll(1, 2), 'fire', 'water') # {1d2, "water", "fire"}
b = roll(2, a) # 2dA
vp = [
(2, 1),
(3, 2),
(4, 1),
((StringVal.from_str('fire') + StringVal.from_const(1)), 2),
((StringVal.from_str('fire') + StringVal.from_const(2)), 2),
((StringVal.from_str('water') + StringVal.from_const(1)), 2),
((StringVal.from_str('water') + StringVal.from_const(2)), 2),
((StringVal.from_str('fire') + StringVal.from_str('water')), 2),
((StringVal.from_str('water') + StringVal.from_str('water')), 1),
((StringVal.from_str('fire') + StringVal.from_str('fire')), 1),
]
# assert not sum(b.probs)
assert RV.dices_are_equal(b, RV(vals=[x[0] for x in vp], probs=[x[1] for x in vp]))
Loading