Wcc Endpoints + integration tests #922

Merged 9 commits on Jul 24, 2025
20 changes: 10 additions & 10 deletions graphdatascience/arrow_client/v2/job_client.py
@@ -45,19 +45,19 @@ def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:
         return deserialize_single(res)

     @staticmethod
-    def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
-        encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
+    def stream_results(client: AuthenticatedArrowClient, graph_name: str, job_id: str) -> DataFrame:
+        payload = {
+            "graphName": graph_name,
+            "jobId": job_id,
+        }

-        res = client.do_action_with_retry("v2/results.stream", encoded_config)
+        res = client.do_action_with_retry("v2/results.stream", json.dumps(payload).encode("utf-8"))
         export_job_id = JobIdConfig(**deserialize_single(res)).job_id

-        payload = {
-            "name": export_job_id,
-            "version": 1,
-        }
+        stream_payload = {"version": "v2", "name": export_job_id, "body": {}}

-        ticket = Ticket(json.dumps(payload).encode("utf-8"))
-        with client.get_stream(ticket) as get:
-            arrow_table = get.read_all()
+        ticket = Ticket(json.dumps(stream_payload).encode("utf-8"))
+
+        get = client.get_stream(ticket)
+        arrow_table = get.read_all()
         return arrow_table.to_pandas(types_mapper=ArrowDtype)  # type: ignore
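For orientation, a minimal usage sketch of the updated two-step streaming flow, assuming `client` is an already authenticated `AuthenticatedArrowClient` and a WCC job has finished (the graph and job names here are placeholders):

    from graphdatascience.arrow_client.v2.job_client import JobClient

    # Placeholder names for illustration only.
    df = JobClient.stream_results(client, graph_name="my_graph", job_id="my-wcc-job")
    print(df.head())

The caller now supplies the graph name alongside the job id: the first action (v2/results.stream) registers an export job, and the rows are then fetched over a Flight ticket that names that export job.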
6 changes: 3 additions & 3 deletions graphdatascience/procedure_surface/api/wcc_endpoints.py
@@ -246,7 +246,7 @@ def write(
     @abstractmethod
     def estimate(
         self,
-        graph_name: Optional[str] = None,
+        G: Optional[Graph] = None,
         projection_config: Optional[dict[str, Any]] = None,
     ) -> EstimationResult:
         """
@@ -259,8 +259,8 @@ def estimate(

         Parameters
         ----------
-        graph_name : Optional[str], optional
-            Name of the graph to be used in the estimation
+        G : Optional[Graph], optional
+            The graph to be used in the estimation
         projection_config : Optional[dict[str, Any]], optional
             Configuration dictionary for the projection.
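The new signature takes a Graph handle rather than a bare name. A sketch of the two supported call shapes, assuming `endpoints` is any concrete WccEndpoints implementation and `G` an already projected Graph (the projection keys below are illustrative, not prescribed by this PR):

    # Estimate against an existing in-memory graph.
    result = endpoints.estimate(G=G)

    # Or estimate against a projection configuration instead of a graph.
    result = endpoints.estimate(projection_config={"nodeProjection": "*", "relationshipProjection": "*"})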
191 changes: 191 additions & 0 deletions graphdatascience/procedure_surface/arrow/wcc_arrow_endpoints.py
@@ -0,0 +1,191 @@
import json
from typing import Any, List, Optional

from pandas import DataFrame

from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
from ...arrow_client.v2.data_mapper_utils import deserialize_single
from ...arrow_client.v2.job_client import JobClient
from ...arrow_client.v2.mutation_client import MutationClient
from ...arrow_client.v2.write_back_client import WriteBackClient
from ...graph.graph_object import Graph
from ..api.estimation_result import EstimationResult
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
from ..utils.config_converter import ConfigConverter

WCC_ENDPOINT = "v2/community.wcc"


class WccArrowEndpoints(WccEndpoints):
    def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[WriteBackClient] = None):
        self._arrow_client = arrow_client
        self._write_back_client = write_back_client

    def mutate(
        self,
        G: Graph,
        mutate_property: str,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccMutateResult:
        config = ConfigConverter.convert_to_gds_config(
            graph_name=G.name(),
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)

        mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        computation_result["nodePropertiesWritten"] = mutate_result.node_properties_written
        computation_result["mutateMillis"] = 0

        return WccMutateResult(**computation_result)

    def stats(
        self,
        G: Graph,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccStatsResult:
        config = ConfigConverter.convert_to_gds_config(
            graph_name=G.name(),
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        return WccStatsResult(**computation_result)

    def stream(
        self,
        G: Graph,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> DataFrame:
        config = ConfigConverter.convert_to_gds_config(
            graph_name=G.name(),
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        return JobClient.stream_results(self._arrow_client, G.name(), job_id)

    def write(
        self,
        G: Graph,
        write_property: str,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        write_concurrency: Optional[int] = None,
    ) -> WccWriteResult:
        config = ConfigConverter.convert_to_gds_config(
            graph_name=G.name(),
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
        )

        job_id = JobClient.run_job_and_wait(self._arrow_client, WCC_ENDPOINT, config)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        if self._write_back_client is None:
            raise Exception("Write back client is not initialized")

        write_millis = self._write_back_client.write(
            G.name(), job_id, write_concurrency if write_concurrency is not None else concurrency
        )

        computation_result["writeMillis"] = write_millis

        return WccWriteResult(**computation_result)

    def estimate(
        self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
    ) -> EstimationResult:
        if G is not None:
            payload = {"graphName": G.name()}
        elif projection_config is not None:
            payload = projection_config
        else:
            raise ValueError("Either G or projection_config must be provided.")

        res = self._arrow_client.do_action_with_retry("v2/community.wcc.estimate", json.dumps(payload).encode("utf-8"))

        return EstimationResult(**deserialize_single(res))
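A short end-to-end sketch of this new Arrow-backed surface, assuming `arrow_client` is an AuthenticatedArrowClient and `G` an already projected Graph handle (the property names are placeholders):

    endpoints = WccArrowEndpoints(arrow_client)

    stats = endpoints.stats(G)                              # summary only
    components = endpoints.stream(G, min_component_size=2)  # DataFrame of results
    mutate_result = endpoints.mutate(G, mutate_property="componentId")

    print(components.head())

Note that write() additionally requires a WriteBackClient to be passed at construction time; otherwise it raises.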
@@ -1,3 +1,4 @@
+from collections import OrderedDict
 from typing import Any, List, Optional, Union

 from pandas import DataFrame
@@ -177,18 +178,20 @@ def write(
         return WccWriteResult(**result.to_dict())

     def estimate(
-        self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
+        self, G: Optional[Graph] = None, projection_config: Optional[dict[str, Any]] = None
     ) -> EstimationResult:
-        config: Union[str, dict[str, Any]] = {}
+        config: Union[dict[str, Any]] = OrderedDict()

-        if graph_name is not None:
-            config = graph_name
+        if G is not None:
+            config["graphNameOrConfiguration"] = G.name()
         elif projection_config is not None:
-            config = projection_config
+            config["graphNameOrConfiguration"] = projection_config
         else:
             raise ValueError("Either graph_name or projection_config must be provided.")

-        params = CallParameters(config=config)
+        config["algoConfig"] = {}
+
+        params = CallParameters(**config)

         result = self._query_runner.call_procedure(endpoint="gds.wcc.stats.estimate", params=params).squeeze()

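The switch to an OrderedDict is what makes the positional call work: CallParameters(**config) sees graphNameOrConfiguration first and algoConfig second, matching the procedure's two-argument signature. A sketch of the resulting shape (the rendered query text is an assumption; it is produced by the query runner rather than spelled out in this diff):

    config = OrderedDict()
    config["graphNameOrConfiguration"] = "my_graph"  # or a projection config dict
    config["algoConfig"] = {}
    # roughly: CALL gds.wcc.stats.estimate($graphNameOrConfiguration, $algoConfig)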
1 change: 1 addition & 0 deletions graphdatascience/tests/conftest.py
@@ -16,3 +16,4 @@ def pytest_addoption(parser: Any) -> None:
     parser.addoption(
         "--include-cloud-architecture", action="store_true", help="include tests requiring a cloud architecture setup"
     )
+    parser.addoption("--include-integration-v2", action="store_true", help="include integration tests for v2")
10 changes: 10 additions & 0 deletions graphdatascience/tests/integrationV2/conftest.py
@@ -0,0 +1,10 @@
from typing import Any

import pytest


def pytest_collection_modifyitems(config: Any, items: Any) -> None:
    if not config.getoption("--include-integration-v2"):
        skip_v2 = pytest.mark.skip(reason="need --include-integration-v2 option to run")
        for item in items:
            item.add_marker(skip_v2)
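Together with the --include-integration-v2 flag registered in graphdatascience/tests/conftest.py above, these tests stay skipped by default; a typical local invocation would look something like this (the image name is a placeholder, see the session fixture below):

    GDS_SESSION_IMAGE=<your-session-image> pytest graphdatascience/tests/integrationV2 --include-integration-v2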
@@ -0,0 +1,66 @@
import os
import tempfile
from typing import Generator

import pytest
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs

from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
from graphdatascience.arrow_client.arrow_info import ArrowInfo
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient


@pytest.fixture(scope="session")
def password_file() -> Generator[str, None, None]:
    """Create a temporary password file and yield the directory containing it."""
    temp_dir = tempfile.mkdtemp()
    temp_file_path = os.path.join(temp_dir, "password")

    with open(temp_file_path, "w") as f:
        f.write("password")

    yield temp_dir

    # Clean up the file and directory
    os.unlink(temp_file_path)
    os.rmdir(temp_dir)


@pytest.fixture(scope="session")
def session_container(password_file: str) -> Generator[DockerContainer, None, None]:
    session_image = os.getenv("GDS_SESSION_IMAGE")

    if session_image is None:
        raise ValueError("GDS_SESSION_IMAGE environment variable is not set")

    session_container = (
        DockerContainer(
            image=session_image,
        )
        .with_env("ALLOW_LIST", "DEFAULT")
        .with_env("DNS_NAME", "gds-session")
        .with_env("PAGE_CACHE_SIZE", "100M")
        .with_exposed_ports(8491)
        .with_network_aliases(["gds-session"])
        .with_volume_mapping(password_file, "/passwords")
    )

    with session_container:
        wait_for_logs(session_container, "Running GDS tasks: 0")
        yield session_container
        stdout, stderr = session_container.get_logs()
        print(stdout)


@pytest.fixture
def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
    """Create an authenticated Arrow client connected to the session container."""
    host = session_container.get_container_host_ip()
    port = session_container.get_exposed_port(8491)

    return AuthenticatedArrowClient.create(
        arrow_info=ArrowInfo(f"{host}:{port}", True, True, ["v1", "v2"]),
        auth=UsernamePasswordAuthentication("neo4j", "password"),
        encrypted=False,
    )
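Finally, a minimal sketch of a test built on the arrow_client fixture; the estimate call and its projection keys are illustrative assumptions, and actual tests in this PR may project a graph first:

    from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
    from graphdatascience.procedure_surface.arrow.wcc_arrow_endpoints import WccArrowEndpoints


    def test_wcc_estimate(arrow_client: AuthenticatedArrowClient) -> None:
        endpoints = WccArrowEndpoints(arrow_client)
        # Fictitious-graph estimation; the config keys are assumptions for illustration.
        result = endpoints.estimate(projection_config={"nodeCount": 100, "relationshipCount": 200})
        assert result is not None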