Skip to content

Commit 4e49714

Browse files
committed
Implement sign ufunc extend unary op tests
1 parent 5914811 commit 4e49714

File tree

3 files changed

+45
-24
lines changed

3 files changed

+45
-24
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ quad_positive(Sleef_quad *op)
1717
return *op;
1818
}
1919

20+
static inline Sleef_quad
21+
quad_sign(Sleef_quad *op)
22+
{
23+
int32_t sign = Sleef_icmpq1(*op, Sleef_cast_from_doubleq1(0.0));
24+
// sign(x=NaN) = x; otherwise sign(x) in { -1.0; 0.0; +1.0 }
25+
return Sleef_iunordq1(*op, *op) ? *op : Sleef_cast_from_int64q1(sign);
26+
}
27+
2028
static inline Sleef_quad
2129
quad_absolute(Sleef_quad *op)
2230
{
@@ -152,6 +160,16 @@ ld_absolute(long double *op)
152160
return fabsl(*op);
153161
}
154162

163+
static inline long double
164+
ld_sign(long double *op)
165+
{
166+
if (x < 0.0) return -1.0;
167+
if (x == 0.0) return 0.0;
168+
if (x > 0.0) return 1.0;
169+
// sign(x=NaN) = x
170+
return x;
171+
}
172+
155173
static inline long double
156174
ld_rint(long double *op)
157175
{

quaddtype/numpy_quaddtype/src/umath.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ init_quad_unary_ops(PyObject *numpy)
211211
if (create_quad_unary_ufunc<quad_absolute, ld_absolute>(numpy, "absolute") < 0) {
212212
return -1;
213213
}
214+
if (create_quad_unary_ufunc<quad_sign, ld_sign>(numpy, "sign") < 0) {
215+
return -1;
216+
}
214217
if (create_quad_unary_ufunc<quad_rint, ld_rint>(numpy, "rint") < 0) {
215218
return -1;
216219
}

quaddtype/tests/test_quaddtype.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_binary_ops(op, other):
3232
quad_result = op_func(quad_a, quad_b)
3333
float_result = op_func(float_a, float_b)
3434

35+
# FIXME: @juntyr: replace with array_equal once isnan is supported
3536
with np.errstate(invalid="ignore"):
3637
assert (
3738
(np.float64(quad_result) == float_result) or
@@ -106,31 +107,30 @@ def test_array_aminmax(op, a, b):
106107
assert np.all((quad_res == float_res) | ((quad_res != quad_res) & (float_res != float_res)))
107108

108109

109-
@pytest.mark.parametrize("op, val, expected", [
110-
("neg", "3.0", "-3.0"),
111-
("neg", "-3.0", "3.0"),
112-
("pos", "3.0", "3.0"),
113-
("pos", "-3.0", "-3.0"),
114-
("abs", "3.0", "3.0"),
115-
("abs", "-3.0", "3.0"),
116-
("neg", "12.5", "-12.5"),
117-
("pos", "100.0", "100.0"),
118-
("abs", "-25.5", "25.5"),
119-
])
120-
def test_unary_ops(op, val, expected):
110+
@pytest.mark.parametrize("op,nop", [("neg", "negative"), ("pos", "positive"), ("abs", "absolute"), (None, "sign")])
111+
@pytest.mark.parametrize("val", ["3.0", "-3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
112+
def test_unary_ops(op, nop, val):
113+
op_func = None if op is None else getattr(operator, op)
114+
nop_func = getattr(np, nop)
115+
121116
quad_val = QuadPrecision(val)
122-
expected_val = QuadPrecision(expected)
123-
124-
if op == "neg":
125-
result = -quad_val
126-
elif op == "pos":
127-
result = +quad_val
128-
elif op == "abs":
129-
result = abs(quad_val)
130-
else:
131-
raise ValueError(f"Unsupported operation: {op}")
132-
133-
assert result == expected_val, f"{op}({val}) should be {expected}, but got {result}"
117+
float_val = float(val)
118+
119+
for op_func in [op_func, nop_func]:
120+
if op_func is None:
121+
continue
122+
123+
quad_result = op_func(quad_val)
124+
float_result = op_func(float_val)
125+
126+
# FIXME: @juntyr: replace with array_equal once isnan is supported
127+
# FIXME: @juntyr: also check the signbit once that is supported
128+
with np.errstate(invalid="ignore"):
129+
assert (
130+
(np.float64(quad_result) == float_result) or
131+
((float_result != float_result) and (quad_result != quad_result))
132+
), f"{op}({val}) should be {float_result}, but got {quad_result}"
133+
134134

135135
def test_inf():
136136
assert QuadPrecision("inf") > QuadPrecision("1e1000")

0 commit comments

Comments
 (0)