Skip to content

Commit 8f2f05d

Browse files
committed
Implement cast support for ubyte
1 parent b4a9429 commit 8f2f05d

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,20 @@ to_quad<npy_byte>(npy_byte x, QuadBackendType backend)
184184
return result;
185185
}
186186

187+
template <>
188+
inline quad_value
189+
to_quad<npy_ubyte>(npy_ubyte x, QuadBackendType backend)
190+
{
191+
quad_value result;
192+
if (backend == BACKEND_SLEEF) {
193+
result.sleef_value = Sleef_cast_from_uint64q1(x);
194+
}
195+
else {
196+
result.longdouble_value = (long double)x;
197+
}
198+
return result;
199+
}
200+
187201
template <>
188202
inline quad_value
189203
to_quad<npy_short>(npy_short x, QuadBackendType backend)
@@ -447,6 +461,18 @@ from_quad<npy_byte>(quad_value x, QuadBackendType backend)
447461
}
448462
}
449463

464+
template <>
465+
inline npy_ubyte
466+
from_quad<npy_ubyte>(quad_value x, QuadBackendType backend)
467+
{
468+
if (backend == BACKEND_SLEEF) {
469+
return (npy_ubyte)Sleef_cast_to_int64q1(x.sleef_value);
470+
}
471+
else {
472+
return (npy_ubyte)x.longdouble_value;
473+
}
474+
}
475+
450476
template <>
451477
inline npy_short
452478
from_quad<npy_short>(quad_value x, QuadBackendType backend)
@@ -741,6 +767,7 @@ init_casts_internal(void)
741767

742768
add_cast_to<npy_bool>(&PyArray_BoolDType);
743769
add_cast_to<npy_byte>(&PyArray_ByteDType);
770+
add_cast_to<npy_ubyte>(&PyArray_ByteDType);
744771
add_cast_to<npy_short>(&PyArray_ShortDType);
745772
add_cast_to<npy_ushort>(&PyArray_UShortDType);
746773
add_cast_to<npy_int>(&PyArray_IntDType);
@@ -755,6 +782,7 @@ init_casts_internal(void)
755782

756783
add_cast_from<npy_bool>(&PyArray_BoolDType);
757784
add_cast_from<npy_byte>(&PyArray_ByteDType);
785+
add_cast_from<npy_ubyte>(&PyArray_ByteDType);
758786
add_cast_from<npy_short>(&PyArray_ShortDType);
759787
add_cast_from<npy_ushort>(&PyArray_UShortDType);
760788
add_cast_from<npy_int>(&PyArray_IntDType);

quaddtype/tests/test_quaddtype.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ def test_finfo_int_constant(name, value):
3939
assert getattr(numpy_quaddtype, name) == value
4040

4141

42+
@pytest.mark.parametrize("dtype", ["bool", "byte", "int8", "ubyte", "uint8", "short", "int16", "ushort", "uint16", "int", "int32", "uint", "uint32", "long", "ulong", "longlong", "int64", "ulonglong", "uint64", "half", "float16", "float", "float32", "double", "float64", "longdouble"])
43+
def test_astype(dtype):
44+
if dtype in ("half", "float16"):
45+
pytest.xfail(f"{dtype} astype not yet supported")
46+
47+
orig = np.array(1, dtype=dtype)
48+
quad = orig.astype(QuadPrecDType, casting="safe")
49+
back = quad.astype(dtype, casting="unsafe")
50+
51+
assert quad == 1
52+
assert back == orig
53+
54+
4255
def test_basic_equality():
4356
assert QuadPrecision("12") == QuadPrecision(
4457
"12.0") == QuadPrecision("12.00")

0 commit comments

Comments
 (0)