Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions notebook/notes.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ to run tests:
`python -m pytest`
to generate test coverage:
`pytest --cov --cov-report=html:coverage_re`
to test specific file/test:
`python -m pytest .\test\parse_exec_test.py -k test_running`
to show cli outputs: add `-o log_cli=true`
to add new tests:
add anydice code in `.\test\autoouts\fetch_in.py`
`python .\test\autoouts\fetch.py --fetch`
Expand Down
2 changes: 1 addition & 1 deletion src/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '0.3.2.dev2'
__version__ = '0.3.3'

# core classes
from .randvar import RV
Expand Down
2 changes: 1 addition & 1 deletion src/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, TYPE_CHECKING
from typing import Union

from .typings import T_if, T_ifsr, MetaRV
from . import utils
Expand Down
2 changes: 1 addition & 1 deletion src/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def output(rv: Union[T_isr, None], named=None, show_pdf=True, blocks_width=None, print_=True, print_fn=None, cdf_cut=0):
if isinstance(rv, MetaSeq) and len(rv) == 0: # empty sequence plotted as empty
rv = blackrv.BlankRV()
if isinstance(rv, int) or isinstance(rv, bool):
if isinstance(rv, (int, float, bool)):
rv = randvar.RV.from_seq([rv])
elif isinstance(rv, Iterable):
rv = randvar.RV.from_seq(rv)
Expand Down
7 changes: 6 additions & 1 deletion src/parser/example_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

trials = [
r'''
A: {2*1..3}
output(1dA)
A: {(3.0*2.0)/1..(4.0*3.0)/1}
output(1dA)

A: {1d4, "water", "fire"}
output(2dA)
'''
Expand All @@ -30,7 +35,7 @@ def main(trials=trials):
if to_parse is None or to_parse.strip() == '':
logger.debug('Empty string')
continue
lexer, yaccer = parse_and_exec.build_lex_yacc()
lexer, yaccer = parse_and_exec.build_lex_yacc(debug=False)
parse_and_exec.do_lex(to_parse, lexer)
if lexer.LEX_ILLEGAL_CHARS:
logger.debug('Lex Illegal characters found: ' + str(lexer.LEX_ILLEGAL_CHARS))
Expand Down
22 changes: 16 additions & 6 deletions src/parser/myparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
tokens = ['PLUS', 'MINUS', 'TIMES', 'DIVIDE', 'POWER',
'COLON', 'LESS', 'GREATER', 'EQUALS', 'NOTEQUALS', 'AT',
'HASH', 'OR', 'AND', 'EXCLAMATION',
'DOT', 'COMMA',
'DOT', 'DOUBLE_DOT', 'COMMA',
'LPAREN', 'RPAREN', 'LBRACE', 'RBRACE', 'LBRACKET', 'RBRACKET',
'LOWERNAME', 'UPPERNAME', 'NUMBER',
'D_OP',
Expand Down Expand Up @@ -55,6 +55,7 @@
t_AND = r'&'
t_EXCLAMATION = r'!'

t_DOUBLE_DOT = r'\.\.'
t_DOT = r'\.'
t_COMMA = r','

Expand Down Expand Up @@ -165,6 +166,7 @@ class NodeType(Enum):
HASH = 'hash'
GROUP = 'group'
NUMBER = 'number'
NUMBER_DECIMAL = 'number_decimal'
VAR = 'var'
SEQ = 'seq'
RANGE = 'range'
Expand Down Expand Up @@ -418,6 +420,7 @@ def p_funcname_def_param(p):
('right', 'HASH_OP'), # 'HASH' (unary #) operator precedence
('right', 'EXCLAMATION'), # Unary NOT operator (!) precedence
('right', 'UMINUS', 'UPLUS'), # Unary minus and plus have the highest precedence
('left', 'DOT'), # Decimal point
)


Expand Down Expand Up @@ -499,6 +502,13 @@ def p_term_number(p):
p[0] = Node(NodeType.NUMBER, p[1])


def p_term_number_decimal(p):
'''
term : NUMBER DOT NUMBER
'''
p[0] = Node(NodeType.NUMBER_DECIMAL, p[1], p[3])


def p_term_name(p):
'''
term : var_name
Expand Down Expand Up @@ -541,9 +551,9 @@ def p_element(p):

def p_range(p):
'''
range : expression DOT DOT expression
range : expression DOUBLE_DOT expression
'''
p[0] = Node(NodeType.RANGE, p[1], p[4])
p[0] = Node(NodeType.RANGE, p[1], p[3])


def p_str_element(p):
Expand Down Expand Up @@ -615,7 +625,7 @@ def p_error(p):

# BUILD

def build_lex_yacc():
lexer = lex()
yaccer = yacc()
def build_lex_yacc(debug=False):
lexer = lex(debug=debug)
yaccer = yacc(debug=debug)
return lexer, yaccer
4 changes: 2 additions & 2 deletions src/parser/parse_and_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
logger = logging.getLogger(__name__)


def build_lex_yacc():
lexer, yaccer = myparser.build_lex_yacc()
def build_lex_yacc(debug=False):
lexer, yaccer = myparser.build_lex_yacc(debug=debug)
lexer.LEX_ILLEGAL_CHARS = []
lexer.YACC_ILLEGALs = []
return lexer, yaccer
Expand Down
6 changes: 5 additions & 1 deletion src/parser/python_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _indent_str(self, s: str):

def resolve_node(self, node: Union['Node', 'str']) -> str: # noqa: C901
assert node is not None, 'Got None'
assert not isinstance(node, str), f'resolver error, not sure what to do with a string: {node}. All strings should be a Node ("string", str|strvar...)'
assert not isinstance(node, str), f'resolver error, not sure what to do with a string [{node}]. All strings should be a Node ("string", str|strvar...)'

if node.type == NodeType.MULTILINE_CODE:
return '\n'.join([self.resolve_node(x) for x in node]) if len(node) > 0 else 'pass'
Expand All @@ -120,6 +120,10 @@ def resolve_node(self, node: Union['Node', 'str']) -> str: # noqa: C901
elif node.type == NodeType.NUMBER: # number in an expression
assert isinstance(node.val, str), f'Expected str of a number, got {node.val} type: {type(node.val)}'
return str(node.val)
elif node.type == NodeType.NUMBER_DECIMAL: # number in an expression
val1, val2 = node
assert isinstance(val1, str) and isinstance(val2, str), f'Expected str of a number, got {val1} and {val2}'
return f'{val1}.{val2}'
elif node.type == NodeType.VAR: # variable inside an expression
assert isinstance(node.val, str), f'Expected str of a variable, got {node.val}'
if self._COMPILER_FLAG_NON_LOCAL_SCOPE:
Expand Down
7 changes: 4 additions & 3 deletions src/randvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def filter(self, obj: T_ifsr):
return RV.from_const(0)
vals, probs = zip(*vp)
assert all(isinstance(p, int) for p in probs), 'should not happen'
probs_int: tuple[int] = probs # type: ignore
probs_int: tuple[int] = probs # type: ignore assertion above
return RV(vals, probs_int)

def get_vals_probs(self, cdf_cut: float = 0):
Expand Down Expand Up @@ -213,9 +213,10 @@ def __rmatmul__(self, other: T_is):
other = factory.get_seq([other])
assert all(isinstance(i, int) for i in other._seq), 'indices must be integers'
if len(other) == 1: # only one index, return the value at that index
k: int = other._seq[0] # type: ignore
k: T_if = other._seq[0]
assert isinstance(k, int), 'unsupported operand type(s) for @: float and RV'
return self._source_die._get_kth_order_statistic(self._source_roll, k)
return _sum_at(self, other) # type: ignore
return _sum_at(self, other) # type: ignore anydice_casting

def _get_kth_order_statistic(self, draws: int, k: int):
'''Get the k-th smallest value of n draws: k@RV where RV is n rolls of a die'''
Expand Down
2 changes: 1 addition & 1 deletion src/roller.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ def myrange(left, right):
raise TypeError(f'A sequence range must begin with a number, while you provided "{left}".')
if isinstance(right, RV):
raise TypeError(f'A sequence range must begin with a number, while you provided "{right}".')
return range(left, right + 1)
return range(int(left), int(right) + 1)
6 changes: 3 additions & 3 deletions src/seq.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Callable, Union
from typing import Iterable, Callable, Optional, Union
from itertools import zip_longest
import operator

Expand All @@ -11,11 +11,11 @@


class Seq(Iterable, MetaSeq):
def __init__(self, *source: T_ifsr, _INTERNAL_SEQ_VALUE=None):
def __init__(self, *source: T_ifsr, _INTERNAL_SEQ_VALUE: Optional[tuple[T_if, ...]] = None):
self._sum = None
self._one_indexed = 1
if _INTERNAL_SEQ_VALUE is not None: # used for internal optimization only
self._seq: tuple[T_if, ...] = _INTERNAL_SEQ_VALUE # type: ignore
self._seq: tuple[T_if, ...] = _INTERNAL_SEQ_VALUE
return
flat = tuple(utils.flatten(source))
flat_rvs = [x for x in flat if isinstance(x, MetaRV) and not isinstance(x, blackrv.BlankRV)] # expand RVs
Expand Down
41 changes: 41 additions & 0 deletions test/parse_exec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dice_calc import RV, settings_set
from dice_calc.parser import parse_and_exec
from dice_calc.seq import Seq

import logging

Expand Down Expand Up @@ -522,3 +523,43 @@ def check_res(x):
i += 1
pipeline(unbound_var_code[0], version=2, global_vars={'output': lambda x: check_res(x)})
assert i == len(unbound_var_code[1])




lst = [
(r'''
A: 1.5
B: A*2
C: {A, B}
D: 2dC
output(A)
output(B)
output(C)
output(D)
''', [1.5, 3.0, Seq(1.5, 3.0), RV([3.0, 4.5, 6.0], [1, 2, 1])]
),
(r'''
A: {3.5*2.0..4.5*3.0}
output(1dA)
''', [RV([7, 8, 9, 10, 11, 12, 13], [1, 1, 1, 1, 1, 1, 1])])
]
@pytest.mark.parametrize("code,res", lst)
def test_floats(code, res):
i = 0
def check_res(x):
nonlocal i
check(x, res[i])
i += 1
pipeline(code, version=1, global_vars={'output': lambda x: check_res(x)})
assert i == len(res)

@pytest.mark.parametrize("code,res", lst)
def test_floatsv2(code, res):
i = 0
def check_res(x):
nonlocal i
check(x, res[i])
i += 1
pipeline(code, version=2, global_vars={'output': lambda x: check_res(x)})
assert i == len(res)
Loading