Skip to content

Commit adb8dbb

Browse files
committed
Introduce JobClient
1 parent 5bd0418 commit adb8dbb

File tree

11 files changed

+271
-0
lines changed

11 files changed

+271
-0
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Any
2+
3+
from pydantic import BaseModel, ConfigDict
4+
from pydantic.alias_generators import to_camel
5+
6+
7+
class ArrowBaseModel(BaseModel):
    """Pydantic base model whose field aliases are generated with ``to_camel``.

    Both dump helpers serialise by alias, i.e. with the generated
    camelCase keys instead of the Python attribute names.
    """

    model_config = ConfigDict(alias_generator=to_camel)

    def dump_camel(self) -> dict[str, Any]:
        """Return the model as a dict keyed by field alias."""
        aliased_dict = self.model_dump(by_alias=True)
        return aliased_dict

    def dump_json(self) -> str:
        """Return the model as a JSON string keyed by field alias."""
        aliased_json = self.model_dump_json(by_alias=True)
        return aliased_json
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from graphdatascience.arrow_client.arrow_base_model import ArrowBaseModel
2+
3+
4+
class JobIdConfig(ArrowBaseModel):
    """Payload that carries nothing but the id of a job."""

    # Identifier of the job.
    job_id: str
6+
7+
8+
class JobStatus(ArrowBaseModel):
    """Status record for a single job."""

    # Identifier of the job this status belongs to.
    job_id: str
    # Status label as a free-form string.
    status: str
    # Progress value; NOTE(review): presumably a fraction in [0, 1] — confirm against the server.
    progress: float
12+
13+
14+
class MutateResult(ArrowBaseModel):
    """Write counters; presumably the outcome of a mutate operation — TODO confirm with callers."""

    # Number of node properties written.
    node_properties_written: int
    # Number of relationships written.
    relationships_written: int
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import json
2+
from typing import Any, Iterator
3+
4+
from pyarrow._flight import Result
5+
6+
7+
def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
    """Decode a result stream that must contain exactly one entry.

    Raises:
        ValueError: if the stream yields zero or more than one result.
    """
    decoded = deserialize(input_stream)
    count = len(decoded)
    if count == 1:
        return decoded[0]

    raise ValueError(f"Expected exactly one result, got {count}")
13+
14+
15+
def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
    """Decode every result in the stream from its JSON-encoded body.

    Each result's ``body`` is converted to bytes, decoded as text and parsed
    with ``json.loads``. The stream is consumed directly; no intermediate
    list of raw results is built first.

    Returns:
        One parsed JSON value per result, in stream order (empty list for an
        empty stream).
    """
    return [json.loads(row.body.to_pybytes().decode()) for row in input_stream]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import json
2+
from typing import Any
3+
4+
from pandas import ArrowDtype, DataFrame
5+
from pyarrow._flight import Ticket
6+
7+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
9+
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
10+
11+
# Action endpoint queried by JobClient.wait_for_job to read a job's status.
JOB_STATUS_ENDPOINT = "v2/jobs.status"
12+
# Action endpoint queried by JobClient.get_summary to read a job's result summary.
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
13+
14+
15+
class JobClient:
    """Static helpers for running v2 jobs through an Arrow client and reading their results."""

    @staticmethod
    def run_job_and_wait(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
        """Submit a job to ``endpoint`` and block until its status is "Done".

        Returns:
            The id of the submitted job.
        """
        job_id = JobClient.run_job(client, endpoint, config)
        JobClient.wait_for_job(client, job_id)
        return job_id

    @staticmethod
    def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, Any]) -> str:
        """Submit a job to ``endpoint`` without waiting for it to finish.

        ``config`` is sent as UTF-8 encoded JSON. The action must answer with
        exactly one result that carries a ``jobId``.

        Returns:
            The job id reported by the server.

        Raises:
            ValueError: if the action does not return exactly one result.
        """
        encoded_config = json.dumps(config).encode("utf-8")
        res = client.do_action_with_retry(endpoint, encoded_config)

        single = deserialize_single(res)
        return JobIdConfig(**single).job_id

    @staticmethod
    def wait_for_job(client: AuthenticatedArrowClient, job_id: str, poll_interval: float = 0.1) -> None:
        """Poll the job status endpoint until the job reports the status "Done".

        Args:
            client: Arrow client used to issue the status action.
            job_id: Id of the job to wait for.
            poll_interval: Seconds to pause between two status requests, so the
                server is not polled in a tight loop. Values <= 0 disable the pause.

        Raises:
            ValueError: if a status action does not return exactly one result.
        """
        # Function-scope import keeps this block self-contained.
        import time

        # The request payload never changes between polls, so build it once.
        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")

        # NOTE(review): the loop only ends on "Done"; a job that finishes in any
        # other terminal state would be polled forever — confirm which statuses
        # the server can report.
        while True:
            arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, encoded_config)
            job_status = JobStatus(**deserialize_single(arrow_res))
            if job_status.status == "Done":
                return

            if poll_interval > 0:
                time.sleep(poll_interval)

    @staticmethod
    def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
        """Fetch the result summary of a job.

        Returns:
            The single decoded result of the summary action.

        Raises:
            ValueError: if the action does not return exactly one result.
        """
        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")

        res = client.do_action_with_retry(RESULTS_SUMMARY_ENDPOINT, encoded_config)
        return deserialize_single(res)

    @staticmethod
    def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
        """Read the results of a job into a pandas DataFrame.

        Calls the "v2/results.stream" action for ``job_id``, which answers with
        the id of an export job. That id is used as the name in the ticket
        from which the whole stream is read.

        Returns:
            The streamed table converted to pandas with ``ArrowDtype`` columns.

        Raises:
            ValueError: if the stream action does not return exactly one result.
        """
        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")

        res = client.do_action_with_retry("v2/results.stream", encoded_config)
        export_job_id = JobIdConfig(**deserialize_single(res)).job_id

        payload = {
            "name": export_job_id,
            "version": 1,
        }

        ticket = Ticket(json.dumps(payload).encode("utf-8"))
        with client.get_stream(ticket) as get:
            arrow_table = get.read_all()

        return arrow_table.to_pandas(types_mapper=ArrowDtype)  # type: ignore

graphdatascience/procedure_surface/__init__.py

Whitespace-only changes.

graphdatascience/procedure_surface/arrow/__init__.py

Whitespace-only changes.

graphdatascience/tests/unit/arrow_client/V2/__init__.py

Whitespace-only changes.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Iterator
2+
3+
import pytest
4+
from pyarrow._flight import Result
5+
6+
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
7+
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
8+
9+
10+
def test_deserialize_single_success() -> None:
    """A stream with exactly one result is decoded to that result's payload."""
    single_result_stream = iter([ArrowTestResult({"key": "value"})])

    decoded = deserialize_single(single_result_stream)

    assert {"key": "value"} == decoded
15+
16+
17+
def test_deserialize_single_raises_on_empty_stream() -> None:
    """An empty stream is rejected with a ValueError that reports zero results."""
    empty_stream: Iterator[Result] = iter([])

    with pytest.raises(ValueError, match="Expected exactly one result, got 0"):
        deserialize_single(empty_stream)
21+
22+
23+
def test_deserialize_single_raises_on_multiple_results() -> None:
    """A stream with two results is rejected with a ValueError that reports the count."""
    two_results = [
        ArrowTestResult({"key1": "value1"}),
        ArrowTestResult({"key2": "value2"}),
    ]

    with pytest.raises(ValueError, match="Expected exactly one result, got 2"):
        deserialize_single(iter(two_results))
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import json
2+
import unittest
3+
from unittest.mock import MagicMock
4+
5+
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
6+
from graphdatascience.arrow_client.v2.job_client import JobClient
7+
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
8+
9+
10+
class TestJobClient(unittest.TestCase):
    """Unit tests for JobClient, run against a mocked Arrow client."""

    def setUp(self) -> None:
        self.mock_client = MagicMock()

    def test_run_job(self) -> None:
        """run_job sends the JSON-encoded config and returns the reported job id."""
        # Setup
        job_id = "test-job-123"
        endpoint = "v2/test.endpoint"
        config = {"param1": "value1", "param2": 42}

        self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult({"jobId": job_id})])

        # Execute
        result = JobClient.run_job(self.mock_client, endpoint, config)

        # Verify
        expected_config = json.dumps(config).encode("utf-8")
        self.mock_client.do_action_with_retry.assert_called_once_with(endpoint, expected_config)
        self.assertEqual(result, job_id)

    def test_run_job_and_wait(
        self,
    ) -> None:
        """run_job_and_wait submits the job first and then polls its status."""
        job_id = "test-job-456"
        endpoint = "v2/test.endpoint"
        config = {"param": "value"}

        job_id_config = JobIdConfig(jobId=job_id)

        status = JobStatus(
            jobId=job_id,
            progress=1.0,
            status="Done",
        )

        do_action_with_retry = MagicMock()
        do_action_with_retry.side_effect = [
            iter([ArrowTestResult(job_id_config.dump_camel())]),
            iter([ArrowTestResult(status.dump_camel())]),
        ]

        self.mock_client.do_action_with_retry = do_action_with_retry

        result = JobClient.run_job_and_wait(self.mock_client, endpoint, config)

        # assert_called_with only inspects the most recent call, so the
        # submission call and the total number of calls are checked explicitly.
        self.assertEqual(do_action_with_retry.call_count, 2)
        submit_call = do_action_with_retry.call_args_list[0]
        self.assertEqual(submit_call.args, (endpoint, json.dumps(config).encode("utf-8")))
        do_action_with_retry.assert_called_with("v2/jobs.status", job_id_config.dump_json().encode("utf-8"))
        self.assertEqual(result, job_id)

    def test_wait_for_job_completes_immediately(self) -> None:
        """wait_for_job issues a single status request when the job is already done."""
        job_id = "test-job-789"

        status = JobStatus(
            jobId=job_id,
            progress=1.0,
            status="Done",
        )

        self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult(status.dump_camel())])

        JobClient.wait_for_job(self.mock_client, job_id)

        self.mock_client.do_action_with_retry.assert_called_once_with(
            "v2/jobs.status", JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
        )

    def test_wait_for_job_waits_for_completion(self) -> None:
        """wait_for_job keeps polling the status endpoint until the status is "Done"."""
        job_id = "test-job-waiting"
        status_running = JobStatus(
            jobId=job_id,
            progress=0.5,
            status="RUNNING",
        )
        status_done = JobStatus(
            jobId=job_id,
            progress=1.0,
            status="Done",
        )

        do_action_with_retry = MagicMock()
        do_action_with_retry.side_effect = [
            iter([ArrowTestResult(status_running.dump_camel())]),
            iter([ArrowTestResult(status_done.dump_camel())]),
        ]

        self.mock_client.do_action_with_retry = do_action_with_retry

        JobClient.wait_for_job(self.mock_client, job_id)

        self.assertEqual(self.mock_client.do_action_with_retry.call_count, 2)
        # Every poll must hit the status endpoint with the same job id payload.
        expected_args = ("v2/jobs.status", JobIdConfig(jobId=job_id).dump_json().encode("utf-8"))
        for recorded_call in do_action_with_retry.call_args_list:
            self.assertEqual(recorded_call.args, expected_args)

    def test_get_summary(self) -> None:
        """get_summary returns the single decoded result of the summary action."""
        # Setup
        job_id = "summary-job-123"
        expected_summary = {"nodeCount": 100, "relationshipCount": 200, "requiredMemory": "1GB"}

        self.mock_client.do_action_with_retry.return_value = iter([ArrowTestResult(expected_summary)])

        result = JobClient.get_summary(self.mock_client, job_id)

        self.mock_client.do_action_with_retry.assert_called_once_with(
            "v2/results.summary", JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
        )
        self.assertEqual(result, expected_summary)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import json
2+
from typing import Any
3+
4+
from pyarrow._flight import Result
5+
6+
7+
class ArrowTestResult(Result):  # type:ignore
    """Test double for a flight ``Result`` whose body is a JSON-encoded dict."""

    def __init__(self, body: dict[str, Any]):
        # The payload is serialised once, at construction time.
        encoded = json.dumps(body).encode()
        self._body = encoded

    @property
    def body(self) -> Any:
        """Return an object exposing ``to_pybytes()`` that yields the encoded payload."""

        class MockBody:
            def __init__(self, data: bytes):
                self._data = data

            def to_pybytes(self) -> bytes:
                return self._data

        wrapped = MockBody(self._body)
        return wrapped

0 commit comments

Comments
 (0)