From 3b7e39731e9cf201ed7173e8176ac3ad4f7b2316 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Wed, 2 Jul 2025 11:51:42 -0700 Subject: [PATCH 1/4] Add tests for tensordict --- python/ray/tests/test_gpu_objects_gloo.py | 74 ++++++++++++++++++++++- 1 file changed, 72 insertions(+), 2 deletions(-) diff --git a/python/ray/tests/test_gpu_objects_gloo.py b/python/ray/tests/test_gpu_objects_gloo.py index 495f2e82cd4f..87ff82076841 100644 --- a/python/ray/tests/test_gpu_objects_gloo.py +++ b/python/ray/tests/test_gpu_objects_gloo.py @@ -2,6 +2,7 @@ import random import torch import pytest +from tensordict import TensorDict import ray from ray.experimental.collective import create_collective_group from ray._private.custom_types import TensorTransportEnum @@ -15,7 +16,10 @@ def echo(self, data): def double(self, data): if isinstance(data, list): - return [d * 2 for d in data] + return [self.double(d) for d in data] + if isinstance(data, TensorDict): + ret = data.apply(lambda x: x * 2) + return ret return data * 2 def get_gpu_object(self, obj_id: str): @@ -102,8 +106,14 @@ def test_multiple_tensors(ray_start_regular): tensor1 = torch.randn((1,)) tensor2 = torch.randn((2,)) + td1 = TensorDict( + {"action1": torch.randn((2,)), "reward1": torch.randn((2,))}, batch_size=[2] + ) + td2 = TensorDict( + {"action2": torch.randn((2,)), "reward2": torch.randn((2,))}, batch_size=[2] + ) cpu_data = random.randint(0, 100) - data = [tensor1, tensor2, cpu_data] + data = [tensor1, tensor2, cpu_data, td1, td2] sender, receiver = actors[0], actors[1] ref = sender.echo.remote(data) @@ -113,6 +123,10 @@ def test_multiple_tensors(ray_start_regular): assert result[0] == pytest.approx(tensor1 * 2) assert result[1] == pytest.approx(tensor2 * 2) assert result[2] == cpu_data * 2 + assert result[3]["action1"] == pytest.approx(td1["action1"] * 2) + assert result[3]["reward1"] == pytest.approx(td1["reward1"] * 2) + assert result[4]["action2"] == pytest.approx(td2["action2"] * 2) + assert result[4]["reward2"] == pytest.approx(td2["reward2"] * 2) def test_trigger_out_of_band_tensor_transfer(ray_start_regular): @@ -159,5 +173,61 @@ def echo(self, data): return data +def test_tensordict_transfer(ray_start_regular): + world_size = 2 + actors = [GPUTestActor.remote() for _ in range(world_size)] + create_collective_group(actors, backend="torch_gloo") + + td = TensorDict( + {"action": torch.randn((2,)), "reward": torch.randn((2,))}, batch_size=[2] + ) + sender, receiver = actors[0], actors[1] + ref = sender.echo.remote(td) + result = receiver.double.remote(ref) + td_result = ray.get(result) + + assert td_result["action"] == pytest.approx(td["action"] * 2) + assert td_result["reward"] == pytest.approx(td["reward"] * 2) + + +def test_nested_tensordict(ray_start_regular): + world_size = 2 + actors = [GPUTestActor.remote() for _ in range(world_size)] + create_collective_group(actors, backend="torch_gloo") + + inner_td = TensorDict( + {"action": torch.randn((2,)), "reward": torch.randn((2,))}, batch_size=[2] + ) + outer_td = TensorDict( + {"inner_td": inner_td, "test": torch.randn((2,))}, batch_size=[2] + ) + sender = actors[0] + receiver = actors[1] + gpu_ref = sender.echo.remote(outer_td) + ret_val_src = ray.get(receiver.double.remote(gpu_ref)) + assert ret_val_src is not None + assert torch.equal(ret_val_src["inner_td"]["action"], inner_td["action"] * 2) + assert torch.equal(ret_val_src["inner_td"]["reward"], inner_td["reward"] * 2) + assert torch.equal(ret_val_src["test"], outer_td["test"] * 2) + + +def test_tensor_extracted_from_tensordict_in_gpu_object_store(ray_start_regular): + actor = GPUTestActor.remote() + create_collective_group([actor], backend="torch_gloo") + + td = TensorDict( + {"action": torch.randn((2,)), "reward": torch.randn((2,))}, batch_size=[2] + ).to("cpu") + gpu_ref = actor.echo.remote(td) + + # Since the tensor is extracted from the tensordict, the `ret_val_src` will be a list of tensors + # instead of a tensordict. + ret_val_src = ray.get(actor.get_gpu_object.remote(gpu_ref.hex())) + assert ret_val_src is not None + assert len(ret_val_src) == 2 + assert torch.equal(ret_val_src[0], td["action"]) + assert torch.equal(ret_val_src[1], td["reward"]) + + if __name__ == "__main__": sys.exit(pytest.main(["-sv", __file__])) From a164316652b2cb79739a5810b6cc9baae0ce0ee7 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Wed, 2 Jul 2025 14:27:08 -0700 Subject: [PATCH 2/4] refine --- python/ray/tests/test_gpu_objects_gloo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/ray/tests/test_gpu_objects_gloo.py b/python/ray/tests/test_gpu_objects_gloo.py index 87ff82076841..53c04b71e926 100644 --- a/python/ray/tests/test_gpu_objects_gloo.py +++ b/python/ray/tests/test_gpu_objects_gloo.py @@ -18,8 +18,7 @@ def double(self, data): if isinstance(data, list): return [self.double(d) for d in data] if isinstance(data, TensorDict): - ret = data.apply(lambda x: x * 2) - return ret + return data.apply(lambda x: x * 2) return data * 2 def get_gpu_object(self, obj_id: str): From 832d4afb71715e334d78366108ebe4a8510d5141 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 3 Jul 2025 10:28:14 -0700 Subject: [PATCH 3/4] add dependency --- python/requirements/test-requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/requirements/test-requirements.txt b/python/requirements/test-requirements.txt index 0ef9462c2773..5832afe8680b 100644 --- a/python/requirements/test-requirements.txt +++ b/python/requirements/test-requirements.txt @@ -109,6 +109,9 @@ backoff==1.10 threadpoolctl==3.1.0 numexpr==2.8.4 +# For test_gpu_objects_gloo.py +tensordict==0.8.3 + # For `serve run --reload` CLI. watchfiles==0.19.0 From e4cf5f7bd8b9c3f221900e5033435430cb0b49f0 Mon Sep 17 00:00:00 2001 From: Qiaolin-Yu Date: Thu, 3 Jul 2025 11:26:49 -0700 Subject: [PATCH 4/4] fix --- python/requirements_compiled.txt | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/requirements_compiled.txt b/python/requirements_compiled.txt index a37c963ef590..0f293fcee0ad 100644 --- a/python/requirements_compiled.txt +++ b/python/requirements_compiled.txt @@ -305,6 +305,7 @@ cloudpickle==2.2.0 # mlflow-skinny # pymars # statsforecast + # tensordict # tensorflow-probability cma==3.2.2 # via nevergrad @@ -788,6 +789,7 @@ importlib-metadata==6.11.0 # myst-nb # opentelemetry-api # pytest-virtualenv + # tensordict importlib-resources==5.13.0 # via # etils @@ -1264,6 +1266,7 @@ numpy==1.26.4 # supersuit # tensorboard # tensorboardx + # tensordict # tensorflow # tensorflow-datasets # tensorflow-probability @@ -1359,7 +1362,9 @@ opt-einsum==3.3.0 optuna==4.1.0 # via -r python/requirements/ml/tune-requirements.txt orjson==3.9.10 - # via gradio + # via + # gradio + # tensordict ormsgpack==1.7.0 # via -r python/requirements/ml/rllib-requirements.txt packaging==23.0 @@ -1407,6 +1412,7 @@ packaging==23.0 # sphinx # statsmodels # tensorboardx + # tensordict # tensorflow # torchmetrics # transformers @@ -2183,6 +2189,8 @@ tensorboardx==2.6.2.2 # -r python/requirements.txt # -r python/requirements/test-requirements.txt # pytorch-lightning +tensordict==0.8.3 + # via -r python/requirements/test-requirements.txt tensorflow==2.15.1 ; python_version < "3.12" and (sys_platform != "darwin" or platform_machine != "arm64") # via -r python/requirements/ml/dl-cpu-requirements.txt tensorflow-datasets==4.9.3 ; python_version < "3.12" @@ -2258,6 +2266,7 @@ torch==2.3.0 # pyro-ppl # pytorch-lightning # pytorch-ranger + # tensordict # timm # torch-optimizer # torchmetrics