diff --git a/notebook/notes.txt b/notebook/notes.txt index 12ec535..ffa56f8 100644 --- a/notebook/notes.txt +++ b/notebook/notes.txt @@ -11,6 +11,9 @@ to run tests: `python -m pytest` to generate test coverage: `pytest --cov --cov-report=html:coverage_re` + to test specific file/test: + `python -m pytest .\test\parse_exec_test.py -k test_running` + to show cli outputs: add `-o log_cli=true` to add new tests: add anydice code in `.\test\autoouts\fetch_in.py` `python .\test\autoouts\fetch.py --fetch` diff --git a/src/__init__.py b/src/__init__.py index 67c2f26..4899065 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,4 +1,4 @@ -__version__ = '0.3.2.dev2' +__version__ = '0.3.3' # core classes from .randvar import RV diff --git a/src/factory.py b/src/factory.py index d9b3003..0902e35 100644 --- a/src/factory.py +++ b/src/factory.py @@ -1,4 +1,4 @@ -from typing import Union, TYPE_CHECKING +from typing import Union from .typings import T_if, T_ifsr, MetaRV from . import utils diff --git a/src/output.py b/src/output.py index 6226130..68f3730 100644 --- a/src/output.py +++ b/src/output.py @@ -9,7 +9,7 @@ def output(rv: Union[T_isr, None], named=None, show_pdf=True, blocks_width=None, print_=True, print_fn=None, cdf_cut=0): if isinstance(rv, MetaSeq) and len(rv) == 0: # empty sequence plotted as empty rv = blackrv.BlankRV() - if isinstance(rv, int) or isinstance(rv, bool): + if isinstance(rv, (int, float, bool)): rv = randvar.RV.from_seq([rv]) elif isinstance(rv, Iterable): rv = randvar.RV.from_seq(rv) diff --git a/src/parser/example_parse.py b/src/parser/example_parse.py index ff9de7f..22a5ae4 100644 --- a/src/parser/example_parse.py +++ b/src/parser/example_parse.py @@ -6,6 +6,11 @@ trials = [ r''' +A: {2*1..3} +output(1dA) +A: {(3.0*2.0)/1..(4.0*3.0)/1} +output(1dA) + A: {1d4, "water", "fire"} output(2dA) ''' @@ -30,7 +35,7 @@ def main(trials=trials): if to_parse is None or to_parse.strip() == '': logger.debug('Empty string') continue - lexer, yaccer = parse_and_exec.build_lex_yacc() + lexer, yaccer = parse_and_exec.build_lex_yacc(debug=False) parse_and_exec.do_lex(to_parse, lexer) if lexer.LEX_ILLEGAL_CHARS: logger.debug('Lex Illegal characters found: ' + str(lexer.LEX_ILLEGAL_CHARS)) diff --git a/src/parser/myparser.py b/src/parser/myparser.py index 43f6f74..626f7c3 100644 --- a/src/parser/myparser.py +++ b/src/parser/myparser.py @@ -24,7 +24,7 @@ tokens = ['PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'POWER', 'COLON', 'LESS', 'GREATER', 'EQUALS', 'NOTEQUALS', 'AT', 'HASH', 'OR', 'AND', 'EXCLAMATION', - 'DOT', 'COMMA', + 'DOT', 'DOUBLE_DOT', 'COMMA', 'LPAREN', 'RPAREN', 'LBRACE', 'RBRACE', 'LBRACKET', 'RBRACKET', 'LOWERNAME', 'UPPERNAME', 'NUMBER', 'D_OP', @@ -55,6 +55,7 @@ t_AND = r'&' t_EXCLAMATION = r'!' +t_DOUBLE_DOT = r'\.\.' t_DOT = r'\.' t_COMMA = r',' @@ -165,6 +166,7 @@ class NodeType(Enum): HASH = 'hash' GROUP = 'group' NUMBER = 'number' + NUMBER_DECIMAL = 'number_decimal' VAR = 'var' SEQ = 'seq' RANGE = 'range' @@ -418,6 +420,7 @@ def p_funcname_def_param(p): ('right', 'HASH_OP'), # 'HASH' (unary #) operator precedence ('right', 'EXCLAMATION'), # Unary NOT operator (!) precedence ('right', 'UMINUS', 'UPLUS'), # Unary minus and plus have the highest precedence + ('left', 'DOT'), # Decimal point ) @@ -499,6 +502,13 @@ def p_term_number(p): p[0] = Node(NodeType.NUMBER, p[1]) +def p_term_number_decimal(p): + ''' + term : NUMBER DOT NUMBER + ''' + p[0] = Node(NodeType.NUMBER_DECIMAL, p[1], p[3]) + + def p_term_name(p): ''' term : var_name @@ -541,9 +551,9 @@ def p_element(p): def p_range(p): ''' - range : expression DOT DOT expression + range : expression DOUBLE_DOT expression ''' - p[0] = Node(NodeType.RANGE, p[1], p[4]) + p[0] = Node(NodeType.RANGE, p[1], p[3]) def p_str_element(p): @@ -615,7 +625,7 @@ def p_error(p): # BUILD -def build_lex_yacc(): - lexer = lex() - yaccer = yacc() +def build_lex_yacc(debug=False): + lexer = lex(debug=debug) + yaccer = yacc(debug=debug) return lexer, yaccer diff --git a/src/parser/parse_and_exec.py b/src/parser/parse_and_exec.py index 22c68ba..d6c70e0 100644 --- a/src/parser/parse_and_exec.py +++ b/src/parser/parse_and_exec.py @@ -7,8 +7,8 @@ logger = logging.getLogger(__name__) -def build_lex_yacc(): - lexer, yaccer = myparser.build_lex_yacc() +def build_lex_yacc(debug=False): + lexer, yaccer = myparser.build_lex_yacc(debug=debug) lexer.LEX_ILLEGAL_CHARS = [] lexer.YACC_ILLEGALs = [] return lexer, yaccer diff --git a/src/parser/python_resolver.py b/src/parser/python_resolver.py index 38b9c5f..88b0bfe 100644 --- a/src/parser/python_resolver.py +++ b/src/parser/python_resolver.py @@ -99,7 +99,7 @@ def _indent_str(self, s: str): def resolve_node(self, node: Union['Node', 'str']) -> str: # noqa: C901 assert node is not None, 'Got None' - assert not isinstance(node, str), f'resolver error, not sure what to do with a string: {node}. All strings should be a Node ("string", str|strvar...)' + assert not isinstance(node, str), f'resolver error, not sure what to do with a string [{node}]. All strings should be a Node ("string", str|strvar...)' if node.type == NodeType.MULTILINE_CODE: return '\n'.join([self.resolve_node(x) for x in node]) if len(node) > 0 else 'pass' @@ -120,6 +120,10 @@ def resolve_node(self, node: Union['Node', 'str']) -> str: # noqa: C901 elif node.type == NodeType.NUMBER: # number in an expression assert isinstance(node.val, str), f'Expected str of a number, got {node.val} type: {type(node.val)}' return str(node.val) + elif node.type == NodeType.NUMBER_DECIMAL: # number in an expression + val1, val2 = node + assert isinstance(val1, str) and isinstance(val2, str), f'Expected str of a number, got {val1} and {val2}' + return f'{val1}.{val2}' elif node.type == NodeType.VAR: # variable inside an expression assert isinstance(node.val, str), f'Expected str of a variable, got {node.val}' if self._COMPILER_FLAG_NON_LOCAL_SCOPE: diff --git a/src/randvar.py b/src/randvar.py index db3d7df..5f9f78b 100644 --- a/src/randvar.py +++ b/src/randvar.py @@ -121,7 +121,7 @@ def filter(self, obj: T_ifsr): return RV.from_const(0) vals, probs = zip(*vp) assert all(isinstance(p, int) for p in probs), 'should not happen' - probs_int: tuple[int] = probs # type: ignore + probs_int: tuple[int] = probs # type: ignore assertion above return RV(vals, probs_int) def get_vals_probs(self, cdf_cut: float = 0): @@ -213,9 +213,10 @@ def __rmatmul__(self, other: T_is): other = factory.get_seq([other]) assert all(isinstance(i, int) for i in other._seq), 'indices must be integers' if len(other) == 1: # only one index, return the value at that index - k: int = other._seq[0] # type: ignore + k: T_if = other._seq[0] + assert isinstance(k, int), 'unsupported operand type(s) for @: float and RV' return self._source_die._get_kth_order_statistic(self._source_roll, k) - return _sum_at(self, other) # type: ignore + return _sum_at(self, other) # type: ignore anydice_casting def _get_kth_order_statistic(self, draws: int, k: int): '''Get the k-th smallest value of n draws: k@RV where RV is n rolls of a die''' diff --git a/src/roller.py b/src/roller.py index 5828008..8864663 100644 --- a/src/roller.py +++ b/src/roller.py @@ -119,4 +119,4 @@ def myrange(left, right): raise TypeError(f'A sequence range must begin with a number, while you provided "{left}".') if isinstance(right, RV): raise TypeError(f'A sequence range must begin with a number, while you provided "{right}".') - return range(left, right + 1) + return range(int(left), int(right) + 1) diff --git a/src/seq.py b/src/seq.py index ab3c1e9..6cecf79 100644 --- a/src/seq.py +++ b/src/seq.py @@ -1,4 +1,4 @@ -from typing import Iterable, Callable, Union +from typing import Iterable, Callable, Optional, Union from itertools import zip_longest import operator @@ -11,11 +11,11 @@ class Seq(Iterable, MetaSeq): - def __init__(self, *source: T_ifsr, _INTERNAL_SEQ_VALUE=None): + def __init__(self, *source: T_ifsr, _INTERNAL_SEQ_VALUE: Optional[tuple[T_if, ...]] = None): self._sum = None self._one_indexed = 1 if _INTERNAL_SEQ_VALUE is not None: # used for internal optimization only - self._seq: tuple[T_if, ...] = _INTERNAL_SEQ_VALUE # type: ignore + self._seq: tuple[T_if, ...] = _INTERNAL_SEQ_VALUE return flat = tuple(utils.flatten(source)) flat_rvs = [x for x in flat if isinstance(x, MetaRV) and not isinstance(x, blackrv.BlankRV)] # expand RVs diff --git a/test/parse_exec_test.py b/test/parse_exec_test.py index 6966aca..c114a5c 100644 --- a/test/parse_exec_test.py +++ b/test/parse_exec_test.py @@ -2,6 +2,7 @@ from dice_calc import RV, settings_set from dice_calc.parser import parse_and_exec +from dice_calc.seq import Seq import logging @@ -522,3 +523,43 @@ def check_res(x): i += 1 pipeline(unbound_var_code[0], version=2, global_vars={'output': lambda x: check_res(x)}) assert i == len(unbound_var_code[1]) + + + + +lst = [ +(r''' +A: 1.5 +B: A*2 +C: {A, B} +D: 2dC +output(A) +output(B) +output(C) +output(D) +''', [1.5, 3.0, Seq(1.5, 3.0), RV([3.0, 4.5, 6.0], [1, 2, 1])] +), +(r''' +A: {3.5*2.0..4.5*3.0} +output(1dA) +''', [RV([7, 8, 9, 10, 11, 12, 13], [1, 1, 1, 1, 1, 1, 1])]) +] +@pytest.mark.parametrize("code,res", lst) +def test_floats(code, res): + i = 0 + def check_res(x): + nonlocal i + check(x, res[i]) + i += 1 + pipeline(code, version=1, global_vars={'output': lambda x: check_res(x)}) + assert i == len(res) + +@pytest.mark.parametrize("code,res", lst) +def test_floatsv2(code, res): + i = 0 + def check_res(x): + nonlocal i + check(x, res[i]) + i += 1 + pipeline(code, version=2, global_vars={'output': lambda x: check_res(x)}) + assert i == len(res)