Skip to content

Commit 1a9ecbb

Browse files
committed
Add WriteBackClient
1 parent d16947f commit 1a9ecbb

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

graphdatascience/arrow_client/v2/write_back_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ def _arrow_configuration(self) -> dict[str, Any]:
5151
"encrypted": connection_info.encrypted,
5252
}
5353

54-
return arrow_config
54+
return arrow_config
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from typing import Optional
2+
from unittest.mock import Mock
3+
4+
import pytest
5+
from pandas import DataFrame
6+
7+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient
9+
from graphdatascience.tests.unit.conftest import DEFAULT_SERVER_VERSION, CollectingQueryRunner
10+
11+
12+
@pytest.fixture
13+
def mock_arrow_client() -> AuthenticatedArrowClient:
14+
client = Mock(spec=AuthenticatedArrowClient)
15+
client.connection_info.return_value = Mock(host="localhost", port=8080, encrypted=False)
16+
client.request_token.return_value = "test_token"
17+
return client
18+
19+
20+
@pytest.fixture
21+
def write_back_client(mock_arrow_client: AuthenticatedArrowClient) -> WriteBackClient:
22+
query_runner = CollectingQueryRunner(
23+
DEFAULT_SERVER_VERSION,
24+
{
25+
"protocol.version": DataFrame([{"version": "v3"}]),
26+
},
27+
)
28+
return WriteBackClient(mock_arrow_client, query_runner)
29+
30+
31+
def test_write_back_client_initialization(write_back_client: WriteBackClient) -> None:
32+
assert isinstance(write_back_client, WriteBackClient)
33+
34+
35+
def test_arrow_configuration(write_back_client: WriteBackClient, mock_arrow_client: AuthenticatedArrowClient) -> None:
36+
expected_config = {
37+
"host": "localhost",
38+
"port": 8080,
39+
"token": "test_token",
40+
"encrypted": False,
41+
}
42+
43+
config = write_back_client._arrow_configuration()
44+
assert config == expected_config
45+
46+
47+
def test_write_calls_run_write_back(write_back_client: WriteBackClient) -> None:
48+
graph_name = "test_graph"
49+
job_id = "123"
50+
concurrency: Optional[int] = 4
51+
52+
write_back_client._write_protocol.run_write_back = Mock() # type: ignore
53+
54+
duration = write_back_client.write(graph_name, job_id, concurrency)
55+
56+
write_back_client._write_protocol.run_write_back.assert_called_once()
57+
assert duration >= 0

0 commit comments

Comments
 (0)