Skip to content

Commit 70d4c0a

Browse files
authored
Merge pull request #53 from neo4j/streamlit-example
Add streamlit example
2 parents e28a23e + 981e514 commit 70d4c0a

File tree

10 files changed

+148
-7
lines changed

10 files changed

+148
-7
lines changed
91.1 KB
Binary file not shown.
25.7 KB
Binary file not shown.

examples/streamlit-example.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import streamlit as st
2+
from IPython.display import HTML
3+
import streamlit.components.v1 as components
4+
from pandas import read_parquet
5+
import pathlib
6+
7+
from neo4j_viz.pandas import from_dfs
8+
from neo4j_viz import VisualizationGraph
9+
10+
# Path to this file
11+
script_path = pathlib.Path(__file__).resolve()
12+
script_dir_path = pathlib.Path(__file__).parent.resolve()
13+
14+
15+
@st.cache_data
16+
def create_visualization_graph() -> VisualizationGraph:
17+
cora_nodes_path = f"{script_dir_path}/datasets/cora/cora_nodes.parquet.gzip"
18+
cora_rels_path = f"{script_dir_path}/datasets/cora/cora_rels.parquet.gzip"
19+
20+
nodes_df = read_parquet(cora_nodes_path)
21+
rels_df = read_parquet(cora_rels_path)
22+
23+
# Drop the features column since it's not needed for visualization
24+
# Also numpy arrays are not supported by the visualization library
25+
nodes_df.drop(columns="features", inplace=True)
26+
27+
VG = from_dfs(nodes_df, rels_df)
28+
VG.color_nodes("subject")
29+
30+
return VG
31+
32+
33+
@st.cache_data
34+
def render_graph(
35+
_VG: VisualizationGraph, height: int, initial_zoom: float = 0.1
36+
) -> HTML:
37+
return VG.render(initial_zoom=initial_zoom, height=f"{height}px")
38+
39+
40+
VG = create_visualization_graph()
41+
42+
st.title("Neo4j Viz Streamlit Example")
43+
st.text(
44+
"This is an example of how to use Streamlit with the Graph "
45+
"Visualization for Python library by Neo4j."
46+
)
47+
48+
with st.sidebar:
49+
height = st.slider("Height in pixels", 100, 2000, 600, 50)
50+
show_code = st.checkbox("Show code")
51+
52+
st.header("Visualization")
53+
st.text(
54+
"A visualization of the famous Cora citation network. Each of its "
55+
"seven scientific subjects is represented by a different color."
56+
)
57+
58+
components.html(
59+
render_graph(VG, height=height).data,
60+
height=height,
61+
)
62+
63+
if show_code:
64+
st.header("Code")
65+
st.code(script_path.read_text())

python-wrapper/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dev = [
4747
"palettable==3.3.3",
4848
"pytest-mock==3.14.0",
4949
"nbconvert==7.16.5",
50+
"streamlit==1.41.1",
5051
]
5152
pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"]
5253
gds = ["graphdatascience>=1, <2"] # not compatible yet with Python 3.13

python-wrapper/src/neo4j_viz/node.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any, Optional, Union
44

5-
from pydantic import BaseModel, Field, field_serializer, field_validator
5+
from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator
66
from pydantic_extra_types.color import Color, ColorType
77

88
from .options import CaptionAlignment
@@ -14,7 +14,9 @@ class Node(BaseModel, extra="allow"):
1414
All options available in the NVL library (see https://neo4j.com/docs/nvl/current/base-library/#_nodes)
1515
"""
1616

17-
id: Union[str, int] = Field(description="Unique identifier for the node")
17+
id: Union[str, int] = Field(
18+
validation_alias=AliasChoices("id", "nodeId", "node_id"), description="Unique identifier for the node"
19+
)
1820
caption: Optional[str] = Field(None, description="The caption of the node")
1921
caption_align: Optional[CaptionAlignment] = Field(
2022
None, serialization_alias="captionAlign", description="The alignment of the caption text"

python-wrapper/src/neo4j_viz/nvl.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def __init__(self) -> None:
2323
with js_path.open("r", encoding="utf-8") as file:
2424
self.library_code = file.read()
2525

26+
def unsupported_field_type_error(self, e: TypeError, entity: str) -> Exception:
27+
if "not JSON serializable" in str(e):
28+
return ValueError(f"A field of a {entity} object is not supported: {str(e)}")
29+
return e
30+
2631
def render(
2732
self,
2833
nodes: list[Node],
@@ -31,8 +36,14 @@ def render(
3136
width: str,
3237
height: str,
3338
) -> HTML:
34-
nodes_json = json.dumps([node.to_dict() for node in nodes])
35-
rels_json = json.dumps([rel.to_dict() for rel in relationships])
39+
try:
40+
nodes_json = json.dumps([node.to_dict() for node in nodes])
41+
except TypeError as e:
42+
raise self.unsupported_field_type_error(e, "node")
43+
try:
44+
rels_json = json.dumps([rel.to_dict() for rel in relationships])
45+
except TypeError as e:
46+
raise self.unsupported_field_type_error(e, "relationship")
3647

3748
render_options_json = json.dumps(render_options.to_dict())
3849

python-wrapper/src/neo4j_viz/relationship.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Optional, Union
44
from uuid import uuid4
55

6-
from pydantic import BaseModel, Field, field_serializer, field_validator
6+
from pydantic import AliasChoices, BaseModel, Field, field_serializer, field_validator
77
from pydantic_extra_types.color import Color, ColorType
88

99
from .options import CaptionAlignment
@@ -19,9 +19,15 @@ class Relationship(BaseModel, extra="allow"):
1919
default_factory=lambda: uuid4().hex, description="Unique identifier for the relationship"
2020
)
2121
source: Union[str, int] = Field(
22-
serialization_alias="from", description="Node ID where the relationship points from"
22+
serialization_alias="from",
23+
validation_alias=AliasChoices("source", "sourceNodeId", "source_node_id", "from"),
24+
description="Node ID where the relationship points from",
25+
)
26+
target: Union[str, int] = Field(
27+
serialization_alias="to",
28+
validation_alias=AliasChoices("target", "targetNodeId", "target_node_id", "to"),
29+
description="Node ID where the relationship points to",
2330
)
24-
target: Union[str, int] = Field(serialization_alias="to", description="Node ID where the relationship points to")
2531
caption: Optional[str] = Field(None, description="The caption of the relationship")
2632
caption_align: Optional[CaptionAlignment] = Field(
2733
None, serialization_alias="captionAlign", description="The alignment of the caption text"

python-wrapper/tests/test_node.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from neo4j_viz import CaptionAlignment, Node
24

35

@@ -53,3 +55,12 @@ def test_node_with_additional_fields() -> None:
5355
"id": "1",
5456
"componentId": 2,
5557
}
58+
59+
60+
@pytest.mark.parametrize("alias", ["id", "nodeId", "node_id"])
61+
def test_id_aliases(alias: str) -> None:
62+
node = Node(**{alias: 1})
63+
64+
assert node.to_dict() == {
65+
"id": "1",
66+
}

python-wrapper/tests/test_relationship.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from neo4j_viz import CaptionAlignment
24
from neo4j_viz.relationship import Relationship
35

@@ -47,3 +49,20 @@ def test_rels_additional_fields() -> None:
4749
rel_dict = rel.to_dict()
4850
assert {"id", "from", "to", "componentId"} == set(rel_dict.keys())
4951
assert rel.componentId == 2 # type: ignore[attr-defined]
52+
53+
54+
@pytest.mark.parametrize("src_alias", ["source", "sourceNodeId", "source_node_id", "from"])
55+
@pytest.mark.parametrize("trg_alias", ["target", "targetNodeId", "target_node_id", "to"])
56+
def test_aliases(src_alias: str, trg_alias: str) -> None:
57+
rel = Relationship(
58+
**{
59+
src_alias: "1",
60+
trg_alias: "2",
61+
}
62+
)
63+
64+
rel_dict = rel.to_dict()
65+
66+
assert {"id", "from", "to"} == set(rel_dict.keys())
67+
assert rel_dict["from"] == "1"
68+
assert rel_dict["to"] == "2"

python-wrapper/tests/test_render.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,29 @@ def test_basic_render(render_option: dict[str, Any], tmp_path: Path) -> None:
6060
severe_logs = [log for log in logs if log["level"] == "SEVERE"]
6161

6262
assert not severe_logs, f"Severe logs found: {severe_logs}, all logs: {logs}"
63+
64+
65+
def test_unsupported_field_type() -> None:
66+
with pytest.raises(
67+
ValueError, match="A field of a node object is not supported: Object of type set is not JSON serializable"
68+
):
69+
nodes = [
70+
Node(id="4:d09f48a4-5fca-421d-921d-a30a896c604d:0", caption="Person", unsupported=set([1, 2, 3])),
71+
]
72+
VG = VisualizationGraph(nodes=nodes, relationships=[])
73+
VG.render()
74+
75+
with pytest.raises(
76+
ValueError,
77+
match="A field of a relationship object is not supported: Object of type set is not JSON serializable",
78+
):
79+
relationships = [
80+
Relationship(
81+
source="4:d09f48a4-5fca-421d-921d-a30a896c604d:0",
82+
target="4:d09f48a4-5fca-421d-921d-a30a896c604d:6",
83+
caption="BUYS",
84+
unsupported=set([1, 2, 3]),
85+
),
86+
]
87+
VG = VisualizationGraph(nodes=[], relationships=relationships)
88+
VG.render()

0 commit comments

Comments
 (0)