Skip to content

Commit 5bd0418

Browse files
committed
Introduce AuthenticatedArrowClient
1 parent cf01c46 commit 5bd0418

File tree

9 files changed

+455
-0
lines changed

9 files changed

+455
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
3+
4+
5+
class ArrowAuthentication(ABC):
    """Interface for objects that supply credentials as a pair of strings."""

    @abstractmethod
    def auth_pair(self) -> tuple[str, str]:
        """Return the two-element credential pair used for authentication.

        Concrete subclasses must override this method; the base class cannot
        be instantiated without an implementation.
        """
        ...
10+
11+
12+
@dataclass
class UsernamePasswordAuthentication(ArrowAuthentication):
    """Authentication backed by a plain username and password.

    The password is excluded from the dataclass-generated ``__repr__`` so that
    printing or logging an instance does not expose the secret. Construction
    and equality are unchanged: both fields remain required and both take part
    in comparisons.
    """

    username: str
    # repr=False keeps the secret out of logs, tracebacks and notebook output.
    password: str = field(repr=False)

    def auth_pair(self) -> tuple[str, str]:
        """Return ``(username, password)``."""
        return self.username, self.password
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
from ..query_runner.query_runner import QueryRunner
6+
from ..server_version.server_version import ServerVersion
7+
8+
9+
@dataclass(frozen=True)
class ArrowInfo:
    """Immutable record of the values yielded by the `gds.debug.arrow` procedure."""

    listenAddress: str
    enabled: bool
    running: bool
    versions: list[str]

    @staticmethod
    def create(query_runner: QueryRunner) -> ArrowInfo:
        """Query `gds.debug.arrow` through ``query_runner`` and wrap the first result row.

        The `versions` column is only requested when the server version is
        greater than 2.6.0; otherwise ``versions`` falls back to an empty list.
        """
        yield_fields = ["listenAddress", "enabled", "running"]
        supports_versions = query_runner.server_version() > ServerVersion(2, 6, 0)
        if supports_versions:
            yield_fields.append("versions")

        result = query_runner.call_procedure(
            endpoint="gds.debug.arrow",
            custom_error=False,
            yields=yield_fields,
        )
        # Only the first row of the procedure result is used.
        row = result.iloc[0]

        return ArrowInfo(
            row["listenAddress"],
            row["enabled"],
            row["running"],
            row.get("versions", []),
        )
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
from dataclasses import dataclass
5+
from typing import Any, Iterator, Optional
6+
7+
from pyarrow import __version__ as arrow_version
8+
from pyarrow import flight
9+
from pyarrow._flight import (
10+
Action,
11+
FlightInternalError,
12+
FlightStreamReader,
13+
FlightTimedOutError,
14+
FlightUnavailableError,
15+
Result,
16+
Ticket,
17+
)
18+
from tenacity import retry, retry_any, retry_if_exception_type, stop_after_attempt, stop_after_delay, wait_exponential
19+
20+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
21+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
22+
from graphdatascience.retry_utils.retry_config import RetryConfig
23+
24+
from ..retry_utils.retry_utils import before_log
25+
from ..version import __version__
26+
from .middleware.auth_middleware import AuthFactory, AuthMiddleware
27+
from .middleware.user_agent_middleware import UserAgentFactory
28+
29+
30+
class AuthenticatedArrowClient:
    """Arrow Flight client for a GDS Arrow server with optional authentication.

    Wraps a ``pyarrow.flight.FlightClient``. When an ``ArrowAuthentication`` is
    supplied, auth and user-agent middleware are installed on the underlying
    client. Token requests and ``do_action_with_retry`` are retried according
    to the client's ``RetryConfig``.
    """

    @staticmethod
    def _default_retry_config() -> RetryConfig:
        """Build the retry policy used when the caller does not provide one.

        Retries on Flight timeout, unavailable and internal errors, stopping
        after 10 seconds or 5 attempts, with exponential wait between 1 and 10.
        """
        return RetryConfig(
            retry=retry_any(
                retry_if_exception_type(FlightTimedOutError),
                retry_if_exception_type(FlightUnavailableError),
                retry_if_exception_type(FlightInternalError),
            ),
            stop=(stop_after_delay(10) | stop_after_attempt(5)),
            wait=wait_exponential(multiplier=1, min=1, max=10),
        )

    @staticmethod
    def create(
        arrow_info: ArrowInfo,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        arrow_client_options: Optional[dict[str, Any]] = None,
        connection_string_override: Optional[str] = None,
        retry_config: Optional[RetryConfig] = None,
    ) -> AuthenticatedArrowClient:
        """Create a client from the Arrow info reported by the server.

        Parameters
        ----------
        arrow_info: ArrowInfo
            Server-reported Arrow details; its ``listenAddress`` is used unless overridden.
        auth: Optional[ArrowAuthentication]
            Credentials provider; when omitted no auth middleware is installed.
        encrypted: bool
            Whether to connect via TLS (default is False)
        arrow_client_options: Optional[dict[str, Any]]
            Additional options to be passed to the Arrow Flight client.
        connection_string_override: Optional[str]
            A `host:port` string that replaces ``arrow_info.listenAddress`` when given.
        retry_config: Optional[RetryConfig]
            The retry configuration to use; a default policy is used when omitted.

        Raises
        ------
        ValueError
            If the connection string is not of the form `host:port`.
        """
        connection_string = (
            connection_string_override if connection_string_override is not None else arrow_info.listenAddress
        )

        # Unpacking enforces exactly one `:` separator in the connection string.
        host, port = connection_string.split(":")

        if retry_config is None:
            retry_config = AuthenticatedArrowClient._default_retry_config()

        return AuthenticatedArrowClient(
            host=host,
            retry_config=retry_config,
            port=int(port),
            auth=auth,
            encrypted=encrypted,
            arrow_client_options=arrow_client_options,
        )

    def __init__(
        self,
        host: str,
        retry_config: RetryConfig,
        port: int = 8491,
        auth: Optional[ArrowAuthentication] = None,
        encrypted: bool = False,
        arrow_client_options: Optional[dict[str, Any]] = None,
        user_agent: Optional[str] = None,
    ):
        """Creates a new AuthenticatedArrowClient instance.

        Parameters
        ----------
        host: str
            The host address of the GDS Arrow server
        retry_config: RetryConfig
            The retry configuration to use for the Arrow requests sent by the client.
            If None is passed, the default retry policy is used.
        port: int
            The host port of the GDS Arrow server (default is 8491)
        auth: Optional[ArrowAuthentication]
            An implementation of ArrowAuthentication providing a pair to be used for basic authentication
        encrypted: bool
            A flag that indicates whether the connection should be encrypted (default is False)
        arrow_client_options: Optional[dict[str, Any]]
            Additional options to be passed to the Arrow Flight client.
        user_agent: Optional[str]
            The user agent string to use for the connection.
            (default is `neo4j-graphdatascience-v[VERSION] pyarrow-v[PYARROW_VERSION]`).
            The user agent middleware is only installed when `auth` is provided.
        """
        self._host = host
        self._port = port
        self._auth: Optional[ArrowAuthentication] = None
        self._encrypted = encrypted
        self._arrow_client_options = arrow_client_options
        self._user_agent = user_agent
        self._logger = logging.getLogger("gds_arrow_client")
        # Honour the caller-supplied policy; only fall back to the default when none is given.
        self._retry_config = retry_config if retry_config is not None else self._default_retry_config()

        if auth:
            self._auth = auth
            self._auth_middleware = AuthMiddleware(auth)

        self._flight_client = self._instantiate_flight_client()

    def connection_info(self) -> ConnectionInfo:
        """
        Returns the connection details of the GDS Arrow server.

        Returns
        -------
        ConnectionInfo
            the host, port and encryption flag this client was configured with
        """
        return ConnectionInfo(self._host, self._port, self._encrypted)

    def request_token(self) -> Optional[str]:
        """
        Requests a token from the server and returns it.

        The authentication call is retried according to the client's retry configuration.

        Returns
        -------
        Optional[str]
            the token currently held by the auth middleware, or the literal
            string "IGNORED" when the client was created without authentication.
        """

        @retry(
            reraise=True,
            before=before_log("Request token", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def auth_with_retry() -> None:
            client = self._flight_client
            if self._auth:
                auth_pair = self._auth.auth_pair()
                client.authenticate_basic_token(auth_pair[0], auth_pair[1])

        if self._auth:
            auth_with_retry()
            return self._auth_middleware.token()
        else:
            return "IGNORED"

    def get_stream(self, ticket: Ticket) -> FlightStreamReader:
        """Issue a Flight ``do_get`` for ``ticket`` and return the resulting stream reader (no retry)."""
        return self._flight_client.do_get(ticket)

    def do_action(self, endpoint: str, payload: bytes) -> Iterator[Result]:
        """Send a single Flight action named ``endpoint`` with ``payload`` (no retry)."""
        return self._flight_client.do_action(Action(endpoint, payload))  # type: ignore

    def do_action_with_retry(self, endpoint: str, payload: bytes) -> Iterator[Result]:
        """Send a Flight action, retrying according to the client's retry configuration.

        The last exception is re-raised unchanged once the retry policy gives up.
        """

        @retry(
            reraise=True,
            before=before_log("Send action", self._logger, logging.DEBUG),
            retry=self._retry_config.retry,
            stop=self._retry_config.stop,
            wait=self._retry_config.wait,
        )
        def run_with_retry() -> Iterator[Result]:
            return self.do_action(endpoint, payload)

        return run_with_retry()

    def _instantiate_flight_client(self) -> flight.FlightClient:
        """Create the underlying Flight client for the configured host, port and encryption mode."""
        location = (
            flight.Location.for_grpc_tls(self._host, self._port)
            if self._encrypted
            else flight.Location.for_grpc_tcp(self._host, self._port)
        )
        # Copy so the caller's options dict is never mutated by the middleware entry below.
        client_options: dict[str, Any] = (self._arrow_client_options or {}).copy()
        if self._auth:
            user_agent = f"neo4j-graphdatascience-v{__version__} pyarrow-v{arrow_version}"
            if self._user_agent:
                user_agent = self._user_agent

            client_options["middleware"] = [
                AuthFactory(self._auth_middleware),
                UserAgentFactory(useragent=user_agent),
            ]

        return flight.FlightClient(location, **client_options)
198+
199+
200+
@dataclass
class ConnectionInfo:
    """Host, port and encryption flag describing an Arrow server connection."""

    # Host address of the server.
    host: str
    # Port of the server.
    port: int
    # Whether the connection is encrypted.
    encrypted: bool
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
import time
5+
from typing import Any, Optional
6+
7+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
8+
9+
from graphdatascience.arrow_client.arrow_authentication import ArrowAuthentication
10+
11+
12+
class AuthFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that hands out one shared ``AuthMiddleware`` for every call."""

    def __init__(self, middleware: AuthMiddleware, *args: Any, **kwargs: Any) -> None:
        """Store ``middleware``; any extra arguments are forwarded to the base factory."""
        super().__init__(*args, **kwargs)
        self._middleware = middleware

    def start_call(self, info: Any) -> AuthMiddleware:
        """Return the stored middleware instance; ``info`` is not inspected."""
        shared_middleware = self._middleware
        return shared_middleware
19+
20+
21+
class AuthMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that attaches an `authorization` header to outgoing calls.

    A bearer token taken from response headers is cached and sent on later
    calls; when no cached token is available the basic credentials from the
    wrapped ``ArrowAuthentication`` are sent instead.
    """

    def __init__(self, auth: ArrowAuthentication, *args: Any, **kwargs: Any) -> None:
        """Remember ``auth`` and start without a cached token."""
        super().__init__(*args, **kwargs)
        self._auth = auth
        self._token: Optional[str] = None
        self._token_timestamp = 0

    def token(self) -> Optional[str]:
        """Return the cached token, dropping it first if it is older than 600 seconds."""
        if self._token:
            age_seconds = int(time.time()) - self._token_timestamp
            # Tokens older than 10 minutes are discarded.
            if age_seconds > 600:
                self._token = None

        return self._token

    def _set_token(self, token: str) -> None:
        """Cache ``token`` together with the current time in whole seconds."""
        self._token = token
        self._token_timestamp = int(time.time())

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Pick up a bearer token from the `authorization` response header, if present.

        Raises
        ------
        ValueError
            If the first `authorization` header value is not a string.
        """
        values = headers.get("authorization")
        if not values:
            return

        # Header values arrive as a list; only the first entry is considered.
        first_value = values[0]
        if not isinstance(first_value, str):
            raise ValueError(f"Incompatible header value received from server: `{first_value}`")

        scheme, credential = first_value.split(" ", 1)
        if scheme != "Bearer":
            return
        self._set_token(credential)

    def sending_headers(self) -> dict[str, str]:
        """Build the `authorization` header: bearer if a token is cached, otherwise basic."""
        cached_token = self.token()
        if cached_token is not None:
            return {"authorization": "Bearer " + cached_token}

        pair = self._auth.auth_pair()
        raw_credentials = f"{pair[0]}:{pair[1]}".encode("utf-8")
        encoded_credentials = base64.b64encode(raw_credentials).decode("ASCII")
        # There seems to be a bug, `authorization` must be lower case
        return {"authorization": "Basic " + encoded_credentials}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from pyarrow._flight import ClientMiddleware, ClientMiddlewareFactory
6+
7+
8+
class UserAgentFactory(ClientMiddlewareFactory):  # type: ignore
    """Middleware factory that reuses a single ``UserAgentMiddleware`` for every call."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Create the shared middleware for ``useragent``; extra arguments go to the base factory."""
        super().__init__(*args, **kwargs)
        shared_middleware = UserAgentMiddleware(useragent)
        self._middleware = shared_middleware

    def start_call(self, info: Any) -> ClientMiddleware:
        """Return the shared middleware instance; ``info`` is not inspected."""
        shared_middleware = self._middleware
        return shared_middleware
15+
16+
17+
class UserAgentMiddleware(ClientMiddleware):  # type: ignore
    """Client middleware that adds an `x-gds-user-agent` header to outgoing calls."""

    def __init__(self, useragent: str, *args: Any, **kwargs: Any) -> None:
        """Store the user agent string; extra arguments go to the base middleware."""
        super().__init__(*args, **kwargs)
        self._useragent = useragent

    def sending_headers(self) -> dict[str, str]:
        """Return the header mapping carrying the stored user agent string."""
        outgoing_headers = {"x-gds-user-agent": self._useragent}
        return outgoing_headers

    def received_headers(self, headers: dict[str, Any]) -> None:
        """Ignore response headers."""
        return None

graphdatascience/tests/unit/arrow_client/__init__.py

Whitespace-only changes.

graphdatascience/tests/unit/arrow_client/middleware/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)