Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5494,10 +5494,10 @@ def aten_linear_backward(

@torch_op("aten::linspace", trace_only=True)
def aten_linspace(
start: TFloat,
end: TFloat,
start: TensorType,
end: TensorType,
steps: int,
dtype: int = FLOAT.dtype,
dtype: int = -1,
layout: str = "",
device: str = "",
pin_memory: bool = False,
Expand All @@ -5507,26 +5507,45 @@ def aten_linspace(
if dtype == -1 or dtype is None:
dtype = FLOAT.dtype

# Reference: https://github.com/pytorch/pytorch/blob/b35ca2cb941b5ba90858322810ca85c31e4541fd/torch/_refs/__init__.py#L4896
if steps == 0:
return aten_full(op.Constant(value_ints=[0]), 0.0, dtype=dtype)
if steps == 1:
return aten_full(op.Constant(value_ints=[steps]), start, dtype=dtype)

rg = aten_arange_start(0, steps, dtype=dtype)
start = op.Cast(start, to=dtype)
end = op.Cast(end, to=dtype)
steps_float = op.Cast(steps, to=dtype)
one = op.Cast(1.0, to=dtype)
two = op.Cast(2.0, to=dtype)
steps_minus_1 = op.Cast(steps - 1, to=dtype)
step = op.Div(op.Sub(end, start), steps_minus_1)
return op.Where(
rg < op.Div(steps_float, two),
start + step * rg,
end - step * (steps_float - one - rg),
# For integer output dtypes, cast start/end to the target dtype first
# This matches PyTorch's behavior where fractional start/end values
# are truncated before computing the linspace
dtype = ir.DataType(dtype)
if dtype.is_integer():
# Use double precision for computation to match PyTorch's internal precision
compute_dtype = ir.DataType.DOUBLE
# Cast to integer dtype first (truncation), then to compute dtype
start_int = op.Cast(start, to=dtype) # Truncate to int32/int64
end_int = op.Cast(end, to=dtype)
start_f = op.Cast(start_int, to=compute_dtype) # Then to double
end_f = op.Cast(end_int, to=compute_dtype)
else:
compute_dtype = dtype
start_f = op.Cast(start, to=compute_dtype)
end_f = op.Cast(end, to=compute_dtype)

rg = aten_arange_start(0, steps, dtype=compute_dtype)
steps_f = op.Cast(steps, to=compute_dtype)
one = op.Constant(value=ir.tensor(1, dtype=compute_dtype))
two = op.Constant(value=ir.tensor(2, dtype=compute_dtype))
steps_minus_1 = op.Sub(steps_f, one)
step = op.Div(op.Sub(end_f, start_f), steps_minus_1)

# Two-sided computation for numerical stability at endpoints
# Use forward computation for first half, backward for second half
lin_vals = op.Where(
rg < op.Div(steps_f, two),
op.Add(start_f, op.Mul(step, rg)),
op.Sub(end_f, op.Mul(step, op.Sub(steps_minus_1, rg))),
)

return op.Cast(lin_vals, to=dtype)


@torch_op("aten::log", trace_only=True)
def aten_log(self: TFloat) -> TFloat:
Expand Down
8 changes: 0 additions & 8 deletions tests/function_libs/torch_lib/ops_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,14 +779,6 @@ def _where_input_wrangler(
"linspace",
core_ops.aten_linspace,
tolerance={torch.float16: (2e-2, 2e-3)},
)
.xfail(
dtypes=(torch.int64, torch.int32),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
)
.skip(
matcher=lambda sample: sample.kwargs.get("dtype") in (torch.int64, torch.int32),
reason="fixme: Results do not match with PyTorch. https://github.com/microsoft/onnxscript/issues/854",
),
TorchLibOpInfo("log", core_ops.aten_log),
TorchLibOpInfo("le", core_ops.aten_le),
Expand Down
Loading