Skip to content

Commit 8e0e11b

Browse files
committed
Add pythonic projection to Cypher mapper DSL
1 parent 5ec9f83 commit 8e0e11b

File tree

3 files changed

+469
-0
lines changed

3 files changed

+469
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
from collections import namedtuple
2+
from typing import Any, NamedTuple, Optional, Tuple
3+
4+
from pandas import Series
5+
6+
from ..error.illegal_attr_checker import IllegalAttrChecker
7+
from ..query_runner.query_runner import QueryRunner
8+
from ..server_version.server_version import ServerVersion
9+
from .graph_object import Graph
10+
from .graph_type_check import from_graph_type_check
11+
12+
13+
class NodeProperty(NamedTuple):
14+
name: str
15+
property_key: str
16+
default_value: Optional[Any] = None
17+
18+
19+
class NodeProjection(NamedTuple):
20+
name: str
21+
source_label: str
22+
properties: Optional[list[NodeProperty]] = None
23+
24+
25+
class RelationshipProperty(NamedTuple):
26+
name: str
27+
property_key: str
28+
default_value: Optional[Any] = None
29+
30+
31+
class RelationshipProjection(NamedTuple):
32+
name: str
33+
source_type: str
34+
properties: Optional[list[RelationshipProperty]] = None
35+
36+
37+
class MatchPart(NamedTuple):
38+
match: str = ""
39+
source_where: str = ""
40+
optional_match: str = ""
41+
optional_where: str = ""
42+
43+
def __str__(self) -> str:
44+
return "\n".join(
45+
part
46+
for part in [
47+
self.match,
48+
self.source_where,
49+
self.optional_match,
50+
self.optional_where,
51+
]
52+
if part
53+
)
54+
55+
56+
class MatchPattern(NamedTuple):
57+
label_filter: str = ""
58+
left_arrow: str = ""
59+
type_filter: str = ""
60+
right_arrow: str = ""
61+
62+
def __str__(self) -> str:
63+
return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})"
64+
65+
66+
class GraphCypherRunner(IllegalAttrChecker):
67+
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None:
68+
if server_version < ServerVersion(2, 4, 0):
69+
raise ValueError("The new Cypher projection is only supported since GDS 2.4.0.")
70+
super().__init__(query_runner, namespace, server_version)
71+
72+
def project(
73+
self,
74+
graph_name: str,
75+
*,
76+
nodes: Any = None,
77+
relationships: Any = None,
78+
where: Optional[str] = None,
79+
allow_disconnected_nodes: bool = False,
80+
inverse: bool = False,
81+
combine_labels_with: str = "OR",
82+
**config: Any,
83+
) -> Tuple[Graph, "Series[Any]"]:
84+
"""
85+
Project a graph using Cypher projection.
86+
87+
Parameters
88+
----------
89+
graph_name : str
90+
The name of the graph to project.
91+
nodes : Any
92+
The nodes to project. If not specified, all nodes are projected.
93+
relationships : Any
94+
The relationships to project. If not specified, all relationships
95+
are projected.
96+
where : Optional[str]
97+
A Cypher WHERE clause to filter the nodes and relationships to
98+
project.
99+
allow_disconnected_nodes : bool
100+
Whether to allow disconnected nodes in the projected graph.
101+
inverse : bool
102+
Whether to project inverse relationships. The projected graph will
103+
be configured as NATURAL.
104+
combine_labels_with : str
105+
Whether to combine node labels with AND or OR. The default is AND.
106+
Allowed values are 'AND' and 'OR'.
107+
**config : Any
108+
Additional configuration for the projection.
109+
"""
110+
111+
query_params = {"graph_name": graph_name}
112+
113+
data_config = {}
114+
115+
nodes = self._node_projections_spec(nodes)
116+
rels = self._rel_projections_spec(relationships)
117+
118+
match_part = MatchPart()
119+
match_pattern = MatchPattern(
120+
left_arrow="<-" if inverse else "-",
121+
right_arrow="-" if inverse else "->",
122+
)
123+
124+
if nodes:
125+
if len(nodes) == 1 or combine_labels_with == "AND":
126+
match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}")
127+
128+
projected_labels = [spec.name for spec in nodes]
129+
data_config["sourceNodeLabels"] = projected_labels
130+
data_config["targetNodeLabels"] = projected_labels
131+
132+
elif combine_labels_with == "OR":
133+
source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in nodes)
134+
target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in nodes)
135+
if allow_disconnected_nodes:
136+
match_part = match_part._replace(
137+
source_where=f"WHERE {source_labels_filter}", optional_where=f"WHERE {target_labels_filter}"
138+
)
139+
else:
140+
match_part = match_part._replace(
141+
source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})"
142+
)
143+
144+
data_config["sourceNodeLabels"] = "labels(source)"
145+
data_config["targetNodeLabels"] = "labels(target)"
146+
else:
147+
raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
148+
149+
if rels:
150+
if len(rels) == 1:
151+
rel_var = ""
152+
data_config["relationshipType"] = rels[0].source_type
153+
else:
154+
rel_var = "rel"
155+
data_config["relationshipTypes"] = "type(rel)"
156+
match_pattern = match_pattern._replace(
157+
type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
158+
)
159+
160+
source = f"(source{match_pattern.label_filter})"
161+
if allow_disconnected_nodes:
162+
match_part = match_part._replace(
163+
match=f"MATCH {source}", optional_match=f"OPTIONAL MATCH (source){match_pattern}"
164+
)
165+
else:
166+
match_part = match_part._replace(match=f"MATCH {source}{match_pattern}")
167+
168+
match_part = str(match_part)
169+
170+
args = ["$graph_name", "source", "target"]
171+
172+
if data_config:
173+
query_params["data_config"] = data_config
174+
args += ["$data_config"]
175+
176+
if config:
177+
query_params["config"] = config
178+
args += ["$config"]
179+
180+
return_part = f"RETURN {self._namespace}({', '.join(args)})"
181+
182+
query = "\n".join(part for part in [match_part, return_part] if part)
183+
184+
print(query)
185+
186+
result = self._query_runner.run_query_with_logging(
187+
query,
188+
query_params,
189+
).squeeze()
190+
191+
return Graph(graph_name, self._query_runner, self._server_version), result
192+
193+
def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
194+
if spec is None or spec is False:
195+
return []
196+
197+
if isinstance(spec, str):
198+
spec = [spec]
199+
200+
if isinstance(spec, list):
201+
return [self._node_projection_spec(node) for node in spec]
202+
203+
if isinstance(spec, dict):
204+
return [self._node_projection_spec(node, name) for name, node in spec.items()]
205+
206+
raise TypeError(f"Invalid node projection specification: {spec}")
207+
208+
def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection:
209+
if isinstance(spec, str):
210+
return NodeProjection(name=name or spec, source_label=spec)
211+
212+
raise TypeError(f"Invalid node projection specification: {spec}")
213+
214+
def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]:
215+
raise TypeError(f"Invalid node projection specification: {properties}")
216+
217+
def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]:
218+
if spec is None or spec is False:
219+
return []
220+
221+
if isinstance(spec, str):
222+
spec = [spec]
223+
224+
if isinstance(spec, list):
225+
return [self._rel_projection_spec(node) for node in spec]
226+
227+
if isinstance(spec, dict):
228+
return [self._rel_projection_spec(node, name) for name, node in spec.items()]
229+
230+
raise TypeError(f"Invalid relationship projection specification: {spec}")
231+
232+
def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> RelationshipProjection:
233+
if isinstance(spec, str):
234+
return RelationshipProjection(name=name or spec, source_type=spec)
235+
236+
raise TypeError(f"Invalid relationship projection specification: {spec}")
237+
238+
def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]:
239+
raise TypeError(f"Invalid relationship projection specification: {properties}")

graphdatascience/graph/graph_proc_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .graph_sample_runner import GraphSampleRunner
2828
from .graph_type_check import graph_type_check, graph_type_check_optional
2929
from .ogb_loader import OGBLLoader, OGBNLoader
30+
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
3031

3132
Strings = Union[str, List[str]]
3233

@@ -165,6 +166,11 @@ def project(self) -> GraphProjectRunner:
165166
self._namespace += ".project"
166167
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
167168

169+
@property
170+
def cypher(self) -> GraphCypherRunner:
171+
self._namespace += ".project"
172+
return GraphCypherRunner(self._query_runner, self._namespace, self._server_version)
173+
168174
@property
169175
def export(self) -> GraphExportRunner:
170176
self._namespace += ".export"

0 commit comments

Comments
 (0)