Skip to content

Commit d16947f

Browse files
committed
Add MutationClient
1 parent 5f1b3cd commit d16947f

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import time
2+
from typing import Any, Optional
3+
4+
from graphdatascience import QueryRunner
5+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
6+
from graphdatascience.call_parameters import CallParameters
7+
from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol
8+
from graphdatascience.query_runner.termination_flag import TerminationFlagNoop
9+
from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver
10+
11+
12+
class WriteBackClient:
13+
def __init__(self, arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner):
14+
self._arrow_client = arrow_client
15+
self._query_runner = query_runner
16+
17+
protocol_version = ProtocolVersionResolver(query_runner).resolve()
18+
self._write_protocol = WriteProtocol.select(protocol_version)
19+
20+
# TODO: Add progress logging
21+
# TODO: Support setting custom writeProperties and relationshipTypes
22+
def write(self, graph_name: str, job_id: str, concurrency: Optional[int]) -> int:
23+
arrow_config = self._arrow_configuration()
24+
25+
configuration = {}
26+
if concurrency is not None:
27+
configuration["concurrency"] = concurrency
28+
29+
write_back_params = CallParameters(
30+
graphName=graph_name,
31+
jobId=job_id,
32+
arrowConfiguration=arrow_config,
33+
configuration=configuration,
34+
)
35+
36+
start_time = time.time()
37+
38+
self._write_protocol.run_write_back(self._query_runner, write_back_params, None, TerminationFlagNoop())
39+
40+
return int((time.time() - start_time) * 1000)
41+
42+
def _arrow_configuration(self) -> dict[str, Any]:
43+
connection_info = self._arrow_client.connection_info()
44+
token = self._arrow_client.request_token()
45+
if token is None:
46+
token = "IGNORED"
47+
arrow_config = {
48+
"host": connection_info.host,
49+
"port": connection_info.port,
50+
"token": token,
51+
"encrypted": connection_info.encrypted,
52+
}
53+
54+
return arrow_config
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import unittest
2+
from unittest.mock import MagicMock
3+
4+
from graphdatascience.arrow_client.v2.api_types import MutateResult
5+
from graphdatascience.arrow_client.v2.mutation_client import MutationClient
6+
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
7+
8+
9+
class TestMutationClient(unittest.TestCase):
10+
def setUp(self) -> None:
11+
self.mock_client = MagicMock()
12+
13+
def test_mutate_node_property_success(self) -> None:
14+
job_id = "test-job-123"
15+
expected_mutation_result = MutateResult(nodePropertiesWritten=42, relationshipsWritten=1337)
16+
17+
self.mock_client.do_action_with_retry.return_value = iter(
18+
[ArrowTestResult(expected_mutation_result.dump_camel())]
19+
)
20+
21+
result = MutationClient.mutate_node_property(self.mock_client, job_id, "propertyName")
22+
23+
assert result == expected_mutation_result
24+
25+
self.mock_client.do_action_with_retry.assert_called_once_with(
26+
MutationClient.MUTATE_ENDPOINT, b'{"jobId": "test-job-123", "mutateProperty": "propertyName"}'
27+
)

0 commit comments

Comments
 (0)