Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from strawberry.types import Info

from datajunction_server.errors import DJNodeNotFound
from datajunction_server.sql.dag import get_downstream_nodes, get_upstream_nodes
from datajunction_server.api.graphql.scalars.node import NodeName
from datajunction_server.api.graphql.scalars.sql import CubeDefinition
from datajunction_server.api.graphql.utils import dedupe_append, extract_fields
Expand Down Expand Up @@ -187,3 +188,36 @@ async def resolve_metrics_and_dimensions(

metrics = list(OrderedDict.fromkeys(metrics))
return metrics, dimensions


async def resolve_node_upstreams(info: Info, root: DBNode) -> List[DBNode]:
"""
Resolves the upstream nodes for a given node. This function extracts the requested
fields from the query and only loads those to optimize the database query.
"""
fields = extract_fields(info)
options = load_node_options(fields)
return await get_upstream_nodes(
session=info.context["session"],
node_name=root.name,
node_output_options=options,
)


async def resolve_node_downstreams(
info: Info,
root: DBNode,
depth: int | None = -1,
) -> List[DBNode]:
"""
Resolves the downstream nodes for a given node. This function extracts the requested
fields from the query and only loads those to optimize the database query.
"""
fields = extract_fields(info)
options = load_node_options(fields)
return await get_downstream_nodes(
session=info.context["session"],
node_name=root.name,
depth=depth, # type: ignore
node_output_options=options,
)
86 changes: 59 additions & 27 deletions datajunction-server/datajunction_server/api/graphql/scalars/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,65 @@ async def dimension_node(self, info: Info) -> "Node":
)


@strawberry.type
class Node:
"""
A DJ node
"""

id: BigInt
name: str
type: NodeType # type: ignore
current_version: str
created_at: datetime.datetime
deactivated_at: Optional[datetime.datetime]

current: "NodeRevision"
revisions: List["NodeRevision"]

tags: List["TagBase"]
created_by: User

@strawberry.field
def edited_by(self, root: "DBNode") -> List[str]:
"""
The users who edited this node
"""
return root.edited_by

@strawberry.field
async def upstreams(
self,
root: "DBNodeRevision",
info: Info,
depth: int | None = 1,
) -> list["Node"]:
"""
The upstream nodes of this node
"""
from datajunction_server.api.graphql.resolvers.nodes import (
resolve_node_upstreams,
)

return await resolve_node_upstreams(info, root) # type: ignore

@strawberry.field
async def downstreams(
self,
root: "DBNodeRevision",
info: Info,
depth: int | None = 1,
) -> list["Node"]:
"""
The downstream nodes of this node
"""
from datajunction_server.api.graphql.resolvers.nodes import (
resolve_node_downstreams,
)

return await resolve_node_downstreams(info, root, depth) # type: ignore


@strawberry.type
class NodeRevision:
"""
Expand Down Expand Up @@ -289,30 +348,3 @@ class TagBase:
description: str | None
display_name: str | None
tag_metadata: JSON | None = strawberry.field(default_factory=dict)


@strawberry.type
class Node:
"""
A DJ node
"""

id: BigInt
name: str
type: NodeType # type: ignore
current_version: str
created_at: datetime.datetime
deactivated_at: Optional[datetime.datetime]

current: NodeRevision
revisions: List[NodeRevision]

tags: List[TagBase]
created_by: User

@strawberry.field
def edited_by(self, root: "DBNode") -> List[str]:
"""
The users who edited this node
"""
return root.edited_by
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ type Node {
tags: [TagBase!]!
createdBy: User!
editedBy: [String!]!
upstreams(depth: Int = 1): [Node!]!
downstreams(depth: Int = 1): [Node!]!
}

type NodeConnection {
Expand Down
4 changes: 3 additions & 1 deletion datajunction-server/datajunction_server/sql/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ async def get_downstream_nodes(
include_deactivated: bool = True,
include_cubes: bool = True,
depth: int = -1,
node_output_options: Optional[List] = None,
) -> List[Node]:
"""
Gets all downstream children of the given node, filterable by node type.
Expand Down Expand Up @@ -147,6 +148,7 @@ async def get_upstream_nodes(
node_name: str,
node_type: NodeType = None,
include_deactivated: bool = True,
node_output_options: Optional[List] = None,
) -> List[Node]:
"""
Gets all upstreams of the given node, filterable by node type.
Expand Down Expand Up @@ -216,7 +218,7 @@ async def get_upstream_nodes(
(Node.current_version == NodeRevision.version)
& (Node.id == NodeRevision.node_id),
)
.options(*_node_output_options())
.options(*(node_output_options or _node_output_options()))
)

results = (await session.execute(statement)).unique().scalars().all()
Expand Down
Loading