Skip to content

Commit 713cc5c

Browse files
authored
Merge pull request #918 from DarthMax/v2_arrow_clients
V2 Arrow Clients
2 parents fdbbe0f + 5b7e29f commit 713cc5c

24 files changed

+868
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
4+
5+
class ArrowAuthentication(ABC):
6+
@abstractmethod
7+
def auth_pair(self) -> tuple[str, str]:
8+
"""Returns the auth pair used for authentication."""
9+
pass
10+
11+
12+
@dataclass
13+
class UsernamePasswordAuthentication(ArrowAuthentication):
14+
username: str
15+
password: str
16+
17+
def auth_pair(self) -> tuple[str, str]:
18+
return self.username, self.password
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Any
2+
3+
from pydantic import BaseModel, ConfigDict
4+
from pydantic.alias_generators import to_camel
5+
6+
7+
class ArrowBaseModel(BaseModel):
8+
model_config = ConfigDict(alias_generator=to_camel)
9+
10+
def dump_camel(self) -> dict[str, Any]:
11+
return self.model_dump(by_alias=True)
12+
13+
def dump_json(self) -> str:
14+
return self.model_dump_json(by_alias=True)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from ..query_runner.query_runner import QueryRunner
6+
from ..server_version.server_version import ServerVersion
7+
8+
9+
@dataclass(frozen=True)
10+
class ArrowInfo:
11+
listenAddress: str
12+
enabled: bool
13+
running: bool
14+
versions: list[str]
15+
16+
@staticmethod
17+
def create(query_runner: QueryRunner) -> ArrowInfo:
18+
debugYields = ["listenAddress", "enabled", "running"]
19+
if query_runner.server_version() > ServerVersion(2, 6, 0):
20+
debugYields.append("versions")
21+
22+
procResult = query_runner.call_procedure(
23+
endpoint="gds.debug.arrow", custom_error=False, yields=debugYields
24+
).iloc[0]
25+
26+
return ArrowInfo(
27+
listenAddress=procResult["listenAddress"],
28+
enabled=procResult["enabled"],
29+
running=procResult["running"],
30+
versions=procResult.get("versions", []),
31+
)
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from dataclasses import dataclass
5+
from typing import Any, Iterator, Optional
6+
7+
from pyarrow import __version__ as arrow_version
8+
from pyarrow import flight
9+
from pyarrow._flight import (
10+
Action,
11+
FlightInternalError,
12+
FlightStreamReader,
13+
FlightTimedOutError,
14+
FlightUnavailableError,
15+
Result,
16+
Ticket,
17+
)
18+
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
19+
20+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
21+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
22+
from graphdatascience.retry_utils.retry_config import RetryConfig
23+
24+
from ..retry_utils.retry_utils import before_log
25+
from ..version import __version__
26+
from .middleware.auth_middleware import AuthFactory, AuthMiddleware
27+
from .middleware.user_agent_middleware import UserAgentFactory
28+
29+
30+
class AuthenticatedArrowClient:
31+
@staticmethod
32+
def create(
33+
arrow_info: ArrowInfo,
34+
auth: Optional[ArrowAuthentication] = None,
35+
encrypted: bool = False,
36+
arrow_client_options: Optional[dict[str, Any]] = None,
37+
connection_string_override: Optional[str] = None,
38+
retry_config: Optional[RetryConfig] = None,
39+
) -> AuthenticatedArrowClient:
40+
connection_string: str
41+
if connection_string_override is not None:
42+
connection_string = connection_string_override
43+
else:
44+
connection_string = arrow_info.listenAddress
45+
46+
host, port = connection_string.split(":")
47+
48+
if retry_config is None:
49+
retry_config = RetryConfig(
50+
retry=retry_any(
51+
retry_if_exception_type(FlightTimedOutError),
52+
retry_if_exception_type(FlightUnavailableError),
53+
retry_if_exception_type(FlightInternalError),
54+
),
55+
stop=(stop_after_delay(10) | stop_after_attempt(5)),
56+
wait=wait_exponential(multiplier=1, min=1, max=10),
57+
)
58+
59+
return AuthenticatedArrowClient(
60+
host=host,
61+
retry_config=retry_config,
62+
port=int(port),
63+
auth=auth,
64+
encrypted=encrypted,
65+
arrow_client_options=arrow_client_options,
66+
)
67+
68+
def __init__(
69+
self,
70+
host: str,
71+
retry_config: RetryConfig,
72+
port: int = 8491,
73+
auth: Optional[ArrowAuthentication] = None,
74+
encrypted: bool = False,
75+
arrow_client_options: Optional[dict[str, Any]] = None,
76+
user_agent: Optional[str] = None,
77+
):
78+
"""Creates a new GdsArrowClient instance.
79+
80+
Parameters
81+
----------
82+
host: str
83+
The host address of the GDS Arrow server
84+
port: int
85+
The host port of the GDS Arrow server (default is 8491)
86+
auth: Optional[ArrowAuthentication]
87+
An implementation of ArrowAuthentication providing a pair to be used for basic authentication
88+
encrypted: bool
89+
A flag that indicates whether the connection should be encrypted (default is False)
90+
arrow_client_options: Optional[dict[str, Any]]
91+
Additional options to be passed to the Arrow Flight client.
92+
user_agent: Optional[str]
93+
The user agent string to use for the connection. (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION])
94+
retry_config: Optional[RetryConfig]
95+
The retry configuration to use for the Arrow requests send by the client.
96+
"""
97+
self._host = host
98+
self._port = port
99+
self._auth = None
100+
self._encrypted = encrypted
101+
self._arrow_client_options = arrow_client_options
102+
self._user_agent = user_agent
103+
self._retry_config = retry_config
104+
self._logger = logging.getLogger("gds_arrow_client")
105+
self._retry_config = retry_config
106+
if auth:
107+
self._auth = auth
108+
self._auth_middleware = AuthMiddleware(auth)
109+
110+
self._flight_client = self._instantiate_flight_client()
111+
112+
def connection_info(self) -> ConnectionInfo:
113+
"""
114+
Returns the host and port of the GDS Arrow server.
115+
116+
Returns
117+
-------
118+
tuple[str, int]
119+
the host and port of the GDS Arrow server
120+
"""
121+
return ConnectionInfo(self._host, self._port, self._encrypted)
122+
123+
def request_token(self) -> Optional[str]:
124+
"""
125+
Requests a token from the server and returns it.
126+
127+
Returns
128+
-------
129+
Optional[str]
130+
a token from the server and returns it.
131+
"""
132+
133+
@retry(
134+
reraise=True,
135+
before=before_log("Request token", self._logger, logging.DEBUG),
136+
retry=self._retry_config.retry,
137+
stop=self._retry_config.stop,
138+
wait=self._retry_config.wait,
139+
)
140+
def auth_with_retry() -> None:
141+
client = self._flight_client
142+
if self._auth:
143+
auth_pair = self._auth.auth_pair()
144+
client.authenticate_basic_token(auth_pair[0], auth_pair[1])
145+
146+
if self._auth:
147+
auth_with_retry()
148+
return self._auth_middleware.token()
149+
else:
150+
return "IGNORED"
151+
152+
def get_stream(self, ticket: Ticket) -> FlightStreamReader:
153+
return self._flight_client.do_get(ticket)
154+
155+
def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
156+
return self._flight_client.do_action(Action(endpoint, payload)) # type: ignore
157+
158+
def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
159+
@retry(
160+
reraise=True,
161+
before=before_log("Send action", self._logger, logging.DEBUG),
162+
retry=self._retry_config.retry,
163+
stop=self._retry_config.stop,
164+
wait=self._retry_config.wait,
165+
)
166+
def run_with_retry() -> Iterator[Result]:
167+
return self.do_action(endpoint, payload)
168+
169+
return run_with_retry()
170+
171+
def _instantiate_flight_client(self) -> flight.FlightClient:
172+
location = (
173+
flight.Location.for_grpc_tls(self._host, self._port)
174+
if self._encrypted
175+
else flight.Location.for_grpc_tcp(self._host, self._port)
176+
)
177+
client_options: dict[str, Any] = (self._arrow_client_options or {}).copy()
178+
if self._auth:
179+
user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
180+
if self._user_agent:
181+
user_agent = self._user_agent
182+
183+
client_options["middleware"] = [
184+
AuthFactory(self._auth_middleware),
185+
UserAgentFactory(useragent=user_agent),
186+
]
187+
188+
return flight.FlightClient(location, **client_options)
189+
190+
191+
@dataclass
192+
class ConnectionInfo:
193+
host: str
194+
port: int
195+
encrypted: bool
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import time
5+
from typing import Any, Optional
6+
7+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
8+
9+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
10+
11+
12+
class AuthFactory(ClientMiddlewareFactory): # type: ignore
13+
def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
14+
super().__init__(*args, **kwargs)
15+
self._middleware = middleware
16+
17+
def start_call(self, info: Any) -> AuthMiddleware:
18+
return self._middleware
19+
20+
21+
class AuthMiddleware(ClientMiddleware): # type: ignore
22+
def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
23+
super().__init__(*args, **kwargs)
24+
self._auth = auth
25+
self._token: Optional[str] = None
26+
self._token_timestamp = 0
27+
28+
def token(self) -> Optional[str]:
29+
# check whether the token is older than 10 minutes. If so, reset it.
30+
if self._token and int(time.time()) - self._token_timestamp > 600:
31+
self._token = None
32+
33+
return self._token
34+
35+
def _set_token(self, token: str) -> None:
36+
self._token = token
37+
self._token_timestamp = int(time.time())
38+
39+
def received_headers(self, headers: dict[str, Any]) -> None:
40+
auth_header = headers.get("authorization", None)
41+
if not auth_header:
42+
return
43+
44+
# the result is always a list
45+
header_value = auth_header[0]
46+
47+
if not isinstance(header_value, str):
48+
raise ValueError(f"Incompatible header value received from server: `{header_value}`")
49+
50+
auth_type, token = header_value.split(" ", 1)
51+
if auth_type == "Bearer":
52+
self._set_token(token)
53+
54+
def sending_headers(self) -> dict[str, str]:
55+
token = self.token()
56+
if token is not None:
57+
return {"authorization": "Bearer " + token}
58+
59+
auth_pair = self._auth.auth_pair()
60+
auth_token = f"{auth_pair[0]}:{auth_pair[1]}"
61+
auth_token = "Basic " + base64.b64encode(auth_token.encode("utf-8")).decode("ASCII")
62+
# There seems to be a bug, `authorization` must be lower key
63+
return {"authorization": auth_token}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
6+
7+
8+
class UserAgentFactory(ClientMiddlewareFactory): # type: ignore
9+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
10+
super().__init__(*args, **kwargs)
11+
self._middleware = UserAgentMiddleware(useragent)
12+
13+
def start_call(self, info: Any) -> ClientMiddleware:
14+
return self._middleware
15+
16+
17+
class UserAgentMiddleware(ClientMiddleware): # type: ignore
18+
def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
19+
super().__init__(*args, **kwargs)
20+
self._useragent = useragent
21+
22+
def sending_headers(self) -> dict[str, str]:
23+
return {"x-gds-user-agent": self._useragent}
24+
25+
def received_headers(self, headers: dict[str, Any]) -> None:
26+
pass
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from graphdatascience.arrow_client.arrow_base_model import ArrowBaseModel
2+
3+
4+
class JobIdConfig(ArrowBaseModel):
5+
job_id: str
6+
7+
8+
class JobStatus(ArrowBaseModel):
9+
job_id: str
10+
status: str
11+
progress: float
12+
13+
14+
class MutateResult(ArrowBaseModel):
15+
node_properties_written: int
16+
relationships_written: int
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import json
2+
from typing import Any, Iterator
3+
4+
from pyarrow._flight import Result
5+
6+
7+
def deserialize_single(input_stream: Iterator[Result]) -> dict[str, Any]:
8+
rows = deserialize(input_stream)
9+
if len(rows) != 1:
10+
raise ValueError(f"Expected exactly one result, got {len(rows)}")
11+
12+
return rows[0]
13+
14+
15+
def deserialize(input_stream: Iterator[Result]) -> list[dict[str, Any]]:
16+
def deserialize_row(row: Result): # type:ignore
17+
return json.loads(row.body.to_pybytes().decode())
18+
19+
return [deserialize_row(row) for row in list(input_stream)]

0 commit comments

Comments
 (0)