Skip to content

Commit e9b8c06

Browse files
authored
Add merge_datasets to DataClient (#1008)
1 parent d7db9b4 commit e9b8c06

File tree

3 files changed

+52
-8
lines changed

3 files changed

+52
-8
lines changed

src/viam/app/data_client.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
ListDatasetsByOrganizationIDRequest,
9292
ListDatasetsByOrganizationIDResponse,
9393
RenameDatasetRequest,
94+
MergeDatasetsRequest,
95+
MergeDatasetsResponse,
9496
)
9597
from viam.proto.app.datasync import (
9698
DataCaptureUploadMetadata,
@@ -1289,6 +1291,33 @@ async def create_dataset(self, name: str, organization_id: str) -> str:
12891291
response: CreateDatasetResponse = await self._dataset_client.CreateDataset(request, metadata=self._metadata)
12901292
return response.id
12911293

1294+
1295+
async def merge_datasets(self, name: str, organization_id: str, dataset_ids: List[str]) -> str:
    """Create a new dataset by merging several existing datasets.

    ::

        dataset_id = await data_client.merge_datasets(
            name="<DATASET-NAME>",
            organization_id="<YOUR-ORG-ID>",
            dataset_ids=["<YOUR-DATASET-ID-1>", "<YOUR-DATASET-ID-2>"]
        )
        print(dataset_id)

    Args:
        name (str): The name to give the newly created (merged) dataset.
        organization_id (str): The ID of the organization in which the dataset is created.
            To find your organization ID, visit the organization settings page.
        dataset_ids (List[str]): The IDs of the existing datasets to merge.

    Returns:
        str: The dataset ID of the newly created dataset.

    For more information, see `Data Client API <https://docs.viam.com/dev/reference/apis/data-client/#mergedatasets>`_.
    """
    merge_request = MergeDatasetsRequest(
        name=name,
        organization_id=organization_id,
        dataset_ids=dataset_ids,
    )
    merge_response: MergeDatasetsResponse = await self._dataset_client.MergeDatasets(
        merge_request,
        metadata=self._metadata,
    )
    return merge_response.dataset_id
1320+
12921321
async def list_dataset_by_ids(self, ids: List[str]) -> Sequence[Dataset]:
12931322
"""Get a list of datasets using their IDs.
12941323

tests/mocks/services.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,6 @@
316316
PointCloudObject,
317317
Pose,
318318
PoseInFrame,
319-
ResourceName,
320319
Transform,
321320
)
322321
from viam.proto.provisioning import (
@@ -616,11 +615,11 @@ def __init__(
616615
async def Move(self, stream: Stream[MoveRequest, MoveResponse]) -> None:
    """Handle a mocked ``Move`` RPC.

    Records the request's constraints, extra parameters, and the remaining
    stream deadline on the mock, then replies with the canned success flag
    stored in ``self.move_responses`` under the request's component name.

    Args:
        stream: Stream the ``MoveRequest`` is read from and the
            ``MoveResponse`` is written to.
    """
    request = await stream.recv_message()
    assert request is not None
    # ``component_name`` is treated as a plain string and used directly as the
    # lookup key (no ``.name`` attribute access).
    name: str = request.component_name
    self.constraints = request.constraints
    self.extra = struct_to_dict(request.extra)
    # ``None`` when the caller did not set a deadline on the stream.
    self.timeout = stream.deadline.time_remaining() if stream.deadline else None
    success = self.move_responses[name]
    response = MoveResponse(success=success)
    await stream.send_message(response)
626625

@@ -655,10 +654,10 @@ async def MoveOnGlobe(self, stream: Stream[MoveOnGlobeRequest, MoveOnGlobeRespon
655654
async def GetPose(self, stream: Stream[GetPoseRequest, GetPoseResponse]) -> None:
    """Handle a mocked ``GetPose`` RPC.

    Records the request's extra parameters and the remaining stream deadline
    on the mock, then replies with the canned pose stored in
    ``self.get_pose_responses`` under the request's component name.

    Args:
        stream: Stream the ``GetPoseRequest`` is read from and the
            ``GetPoseResponse`` is written to.
    """
    request = await stream.recv_message()
    assert request is not None
    # ``component_name`` is treated as a plain string and used directly as the
    # lookup key (no ``.name`` attribute access).
    name: str = request.component_name
    self.extra = struct_to_dict(request.extra)
    # ``None`` when the caller did not set a deadline on the stream.
    self.timeout = stream.deadline.time_remaining() if stream.deadline else None
    pose = self.get_pose_responses[name]
    response = GetPoseResponse(pose=pose)
    await stream.send_message(response)
664663

@@ -1071,9 +1070,10 @@ async def ExportTabularData(self, stream: Stream[ExportTabularDataRequest, Expor
10711070

10721071

10731072
class MockDataset(DatasetServiceBase):
1074-
def __init__(self, create_response: str, datasets_response: Sequence[Dataset], merged_response: Optional[str] = None):
    """Store the canned values this mock dataset service works with.

    Args:
        create_response: Value stored as ``self.create_response``.
        datasets_response: Datasets stored as ``self.datasets_response``.
        merged_response: Optional value stored as ``self.merged_response``.
            Defaults to ``None`` so callers that pass only the first two
            arguments keep working.
    """
    self.create_response = create_response
    self.datasets_response = datasets_response
    self.merged_response = merged_response
10771077

10781078
async def CreateDataset(self, stream: Stream[CreateDatasetRequest, CreateDatasetResponse]) -> None:
10791079
request = await stream.recv_message()
@@ -1105,7 +1105,11 @@ async def ListDatasetsByOrganizationID(
11051105
async def MergeDatasets(self, stream: Stream[MergeDatasetsRequest, MergeDatasetsResponse]) -> None:
    """Handle a mocked ``MergeDatasets`` RPC.

    Records the request's name, organization ID, and dataset IDs on the mock
    so tests can inspect them, then sends exactly one response whose
    ``dataset_id`` is the concatenation of the requested dataset IDs.

    Args:
        stream: Stream the ``MergeDatasetsRequest`` is read from and the
            ``MergeDatasetsResponse`` is written to.
    """
    request = await stream.recv_message()
    assert request is not None
    self.name = request.name
    self.org_id = request.organization_id
    self.dataset_ids = request.dataset_ids
    # NOTE(review): this overwrites any ``merged_response`` supplied to
    # ``__init__``; the reply is always derived from the request's IDs.
    self.merged_response = "".join(self.dataset_ids)
    await stream.send_message(MergeDatasetsResponse(dataset_id=self.merged_response))
11091113

11101114
async def RenameDataset(self, stream: Stream[RenameDatasetRequest, RenameDatasetResponse]) -> None:
11111115
request = await stream.recv_message()

tests/test_dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from .mocks.services import MockDataset
99

1010
CREATED_ID = "VIAM_DATASET_0"
11+
MERGED_NAME = "VIAM_DATASET_MERGED"
1112
ID = "VIAM_DATASET_1"
13+
ID2 = "VIAM_DATASET_2"
14+
MERGED_ID = f'{ID}{ID2}'
1215
NAME = "dataset"
1316
ORG_ID = "org_id"
1417
SECONDS = 978310861
@@ -22,7 +25,7 @@
2225

2326
@pytest.fixture(scope="function")
def service() -> MockDataset:
    """Build a fresh ``MockDataset`` for each test function.

    The mock is constructed with the module's ``CREATED_ID``, ``DATASETS``,
    and ``MERGED_ID`` constants.
    """
    return MockDataset(CREATED_ID, DATASETS, MERGED_ID)
2629

2730

2831
class TestClient:
@@ -34,6 +37,14 @@ async def test_create_dataset(self, service: MockDataset):
3437
assert service.org_id == ORG_ID
3538
assert id == CREATED_ID
3639

40+
async def test_merge_datasets(self, service: MockDataset):
    """Check that ``merge_datasets`` forwards its arguments and returns the merged ID."""
    async with ChannelFor([service]) as channel:
        client = DataClient(channel, DATA_SERVICE_METADATA)
        # Named ``merged_id`` rather than ``id`` to avoid shadowing the builtin.
        merged_id = await client.merge_datasets(MERGED_NAME, ORG_ID, [ID, ID2])
        assert service.name == MERGED_NAME
        assert service.org_id == ORG_ID
        # The source dataset IDs must reach the service unchanged and in order.
        assert list(service.dataset_ids) == [ID, ID2]
        assert merged_id == MERGED_ID
47+
3748
async def test_delete_dataset(self, service: MockDataset):
3849
async with ChannelFor([service]) as channel:
3950
client = DataClient(channel, DATA_SERVICE_METADATA)

0 commit comments

Comments
 (0)