Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions magma/bits.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,8 @@ def __int__(self):

@debug_wire
def wire(self, other, debug_info):
if isinstance(other, (IntegerTypes, BitVector)):
N = (other.bit_length()
if isinstance(other, IntegerTypes)
else len(other))
if isinstance(other, IntegerTypes):
N = other.bit_length()
if N > len(self):
raise ValueError(
f"Cannot convert integer {other} "
Expand Down
39 changes: 39 additions & 0 deletions magma/coerce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import wraps

import hwtypes as ht

from magma.protocol_type import MagmaProtocol
from magma.debug import debug_info


def python_to_magma_coerce(value):
if isinstance(value, debug_info):
# Short circuit tuple converion
return value

# Circular import
from magma.conversions import tuple_, sint, uint, bits, bit
if isinstance(value, tuple):
return tuple_(value)
if isinstance(value, ht.SIntVector):
return sint(value, len(value))
if isinstance(value, ht.UIntVector):
return uint(value, len(value))
if isinstance(value, ht.BitVector):
return bits(value, len(value))
if isinstance(value, (bool, ht.Bit)):
return bit(value)

if isinstance(value, MagmaProtocol):
return value._get_magma_value_()

return value


def python_to_magma_coerce_wrapper(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
args = (python_to_magma_coerce(a) for a in args)
kwargs = {k: python_to_magma_coerce(v) for k, v in kwargs.items()}
return fn(*args, **kwargs)
return wrapper
7 changes: 4 additions & 3 deletions magma/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
__all__ += ['as_bits', 'from_bits']

def can_convert_to_bit(value):
return isinstance(magma_value(value), (Digital, Array, Tuple, IntegerTypes))
return isinstance(magma_value(value), (Digital, Array, Tuple, IntegerTypes,
ht.Bit))


def can_convert_to_bit_type(value):
Expand Down Expand Up @@ -61,9 +62,9 @@ def convertbit(value, totype):
"bit can only be used on arrays and tuples of bits"
f"; not {type(value)}")

assert isinstance(value, (IntegerTypes, Digital))
assert isinstance(value, (IntegerTypes, Digital, ht.Bit))

if isinstance(value, IntegerTypes):
if isinstance(value, (IntegerTypes, ht.Bit)):
# Just return VCC or GND here, otherwise we lose VCC/GND singleton
# invariant
return totype(1) if value else totype(0)
Expand Down
2 changes: 1 addition & 1 deletion magma/digital.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def wire(self, o, debug_info):
i = self
o = magma_value(o)
# promote integer types to LOW/HIGH
if isinstance(o, (IntegerTypes, bool, ht.Bit)):
if isinstance(o, IntegerTypes):
o = HIGH if o else LOW

if not isinstance(o, Digital):
Expand Down
20 changes: 7 additions & 13 deletions magma/primitives/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from magma.bits import Bits, UInt, SInt
from magma.bitutils import clog2, seq2int
from magma.circuit import coreir_port_mapping
from magma.coerce import python_to_magma_coerce
from magma.generator import Generator2
from magma.interface import IO
from magma.protocol_type import MagmaProtocol, magma_type
Expand Down Expand Up @@ -86,22 +87,14 @@ def _infer_mux_type(args):
"""
T = None
for arg in args:
if isinstance(arg, (Type, MagmaProtocol)):
next_T = type(arg).qualify(Direction.Undirected)
elif isinstance(arg, UIntVector):
next_T = UInt[len(arg)]
elif isinstance(arg, SIntVector):
next_T = SInt[len(arg)]
elif isinstance(arg, BitVector):
next_T = Bits[len(arg)]
elif isinstance(arg, (ht.Bit, bool)):
next_T = Bit
elif isinstance(arg, tuple):
next_T = type(tuple_(arg))
elif isinstance(arg, int):
if isinstance(arg, int):
# Cannot infer type without width, use wiring implicit coercion to
# handle (or raise type error there)
continue
if not isinstance(arg, (Type, MagmaProtocol)):
raise TypeError(f"Found unsupport argument {arg} of type"
f" {type(arg)}")
next_T = type(arg).qualify(Direction.Undirected)

if T is not None:
if issubclass(T, next_T):
Expand Down Expand Up @@ -143,6 +136,7 @@ def mux(I: list, S, **kwargs):
S = seq2int(S.bits())
if isinstance(S, int):
return I[S]
I = tuple(python_to_magma_coerce(i) for i in I)
T, I = _infer_mux_type(I)
inst = Mux(len(I), T)(**kwargs)
if len(I) == 2 and isinstance(S, Bits[1]):
Expand Down
13 changes: 2 additions & 11 deletions magma/wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from .wire_container import Wire # TODO(rsetaluri): only here for b.c.
from .debug import debug_wire
from .logging import root_logger
from .protocol_type import magma_value

from magma.coerce import python_to_magma_coerce_wrapper
from magma.wire_container import WiringLog


Expand All @@ -15,18 +15,9 @@
_CONSTANTS = (IntegerTypes, BitVector, Bit)


@python_to_magma_coerce_wrapper
@debug_wire
def wire(o, i, debug_info=None):
o = magma_value(o)
i = magma_value(i)

# Circular import
from .conversions import tuple_
if isinstance(o, tuple):
o = tuple_(o)
if isinstance(i, tuple):
i = tuple_(i)

# Wire(o, Circuit).
if hasattr(i, 'interface'):
i.wire(o, debug_info)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_errors/test_mux_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class Foo(m.Circuit):
with pytest.raises(TypeError) as e:
m.mux([1, 2], io.S)
assert str(e.value) == f"""\
Could not infer mux type from [1, 2]
Could not infer mux type from (1, 2)
Need at least one magma value, BitVector, bool or tuple\
"""

Expand Down
2 changes: 1 addition & 1 deletion tests/test_primitives/test_mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ class test_mux_array_select_bits_1(m.Circuit):
def test_mux_intv(ht_T, m_T):
class Main(m.Circuit):
O = m.mux([ht_T[4](1), m_T[4](2)], m.Bit())
assert isinstance(O, m_T)
assert isinstance(O, m_T), type(O)


@pytest.mark.parametrize("ht_T", [ht.UIntVector, ht.SIntVector])
Expand Down