
Fix torch's convert_to_tensor not respecting dtype when input is a Variable #21452

Status: Open · wants to merge 3 commits into master
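In short: `ops.convert_to_tensor` returned a `Variable`'s underlying value as-is on the torch backend, silently ignoring an explicit `dtype` argument. A minimal repro sketch, inferred from the PR title and the new test in `core_test.py` (assumes `KERAS_BACKEND=torch`; values are illustrative):

import numpy as np
import keras

v = keras.Variable(np.array([1.0, 0.0, 1.0], dtype=np.float32))
t = keras.ops.convert_to_tensor(v, dtype="float16")
print(t.dtype)  # before the fix: torch.float32 (dtype ignored); after: torch.float16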
14 changes: 6 additions & 8 deletions keras/src/backend/openvino/core.py
@@ -595,18 +595,17 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         raise ValueError("`sparse=True` is not supported with openvino backend")
     if ragged:
         raise ValueError("`ragged=True` is not supported with openvino backend")
+    if dtype is not None:
+        dtype = standardize_dtype(dtype)
     if isinstance(x, OpenVINOKerasTensor):
         return x
     elif isinstance(x, np.ndarray):
         if dtype is not None:
-            dtype = standardize_dtype(dtype)
             ov_type = OPENVINO_DTYPES[dtype]
             return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
         return OpenVINOKerasTensor(ov_opset.constant(x).output(0))
     elif isinstance(x, (list, tuple)):
-        if dtype is not None:
-            dtype = standardize_dtype(dtype)
-        else:
+        if dtype is None:
             # try to properly deduce element type
             elem = _get_first_element(x)
             if isinstance(elem, float):
@@ -619,12 +618,11 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
             dtype = standardize_dtype(dtype)
         ov_type = OPENVINO_DTYPES[dtype]
         return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
-    if dtype is not None:
-        dtype = standardize_dtype(dtype)
     if isinstance(x, Variable):
+        x = x.value
         if dtype and dtype != x.dtype:
-            return x.value.astype(dtype)
-        return x.value
+            x = cast(x, dtype)
+        return x
     if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
         return ov.Tensor(np.asarray(x).astype(dtype))
     if dtype is None:
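Net effect on the OpenVINO backend: `dtype` is standardized once at the top instead of separately in each branch, and a `Variable` input is now cast with the backend's `cast` op instead of NumPy's `astype` (which assumed `x.value` was a NumPy array). A condensed sketch of the resulting `Variable` path, reconstructed from the two hunks above and eliding the other branches:

def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
    if dtype is not None:
        dtype = standardize_dtype(dtype)  # normalized once, up front
    ...
    if isinstance(x, Variable):
        x = x.value
        if dtype and dtype != x.dtype:
            x = cast(x, dtype)  # backend cast op, works for non-NumPy values
        return x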
8 changes: 4 additions & 4 deletions keras/src/backend/torch/core.py
@@ -192,10 +192,10 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
     if ragged:
         raise ValueError("`ragged=True` is not supported with torch backend")
     if isinstance(x, Variable):
-        # TorchDynamo has bugs supporting nn.Parameter type check.
-        # Return it directly instead of pass it to the rest of the logic in the
-        # function.
-        return x.value
+        if dtype is None:
+            return x.value
+        x = x.value
+        return x.to(to_torch_dtype(dtype))
     if is_tensor(x):
         device = get_device()
         if x.device != device:
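The deleted comment explained why the old code returned `x.value` unconditionally: TorchDynamo has trouble with `nn.Parameter` type checks, so the Variable was never passed to the rest of the function. The new branch preserves that early return when no `dtype` is requested and otherwise casts with `torch.Tensor.to`, which returns `self` unchanged when the dtype already matches. A quick illustration in plain torch of the `.to` semantics the fix relies on:

import torch

t = torch.ones(3, dtype=torch.float32)
print(t.to(torch.float32) is t)   # True: already the right dtype, no copy
print(t.to(torch.float16).dtype)  # torch.float16: returns a cast copy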
5 changes: 5 additions & 0 deletions keras/src/ops/core_test.py
@@ -1150,6 +1150,11 @@ def test_convert_to_tensor(self, x, dtype, expected_dtype):
             ops.convert_to_tensor(x, dtype=dtype), expected_dtype
         )
 
+    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
+    def test_convert_to_tensor_with_variable(self, dtype):
+        x = backend.Variable(np.array([1.0, 0.0, 1.0], dtype=np.float32))
+        self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype)
+
     @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
     def test_saturate_cast(self, dtype):
         x = np.ones((1,))
6 changes: 3 additions & 3 deletions keras/src/ops/nn_test.py
@@ -3137,7 +3137,7 @@ def test_dot_product_attention(self, dtype):
     def test_rms_normalization(self, dtypes):
         input_dtype, weight_dtype = dtypes
         inputs = knp.ones((2, 8), dtype=input_dtype)
-        scale = knp.ones((8,), dtype=weight_dtype)
+        scale = backend.Variable(knp.ones((8,), dtype=weight_dtype))
         expected_dtype = input_dtype
 
         self.assertDType(knn.rms_normalization(inputs, scale), expected_dtype)
@@ -3151,8 +3151,8 @@ def test_rms_normalization(self, dtypes):
     def test_layer_normalization(self, dtypes):
         input_dtype, weight_dtype = dtypes
         inputs = knp.ones((2, 8), dtype=input_dtype)
-        gamma = knp.ones((8,), dtype=weight_dtype)
-        beta = knp.ones((8,), dtype=weight_dtype)
+        gamma = backend.Variable(knp.ones((8,), dtype=weight_dtype))
+        beta = backend.Variable(knp.ones((8,), dtype=weight_dtype))
         expected_dtype = input_dtype
 
         self.assertDType(
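These test edits swap plain tensors for `Variable`s so that `rms_normalization` and `layer_normalization` exercise the fixed Variable-casting path for their weight arguments. A hedged sketch of the property being asserted, written against the public API rather than the test file's `knp`/`knn` aliases (assumes `keras.ops.rms_normalization` is exported, as its presence in `keras/src/ops/nn.py` suggests):

import numpy as np
import keras

inputs = keras.ops.ones((2, 8), dtype="float16")
scale = keras.Variable(np.ones((8,), dtype="float32"))  # weight held as a Variable
out = keras.ops.rms_normalization(inputs, scale)
print(out.dtype)  # expected to match the input dtype (float16), not the weight's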