diff --git a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py index faee76188..ef148adec 100644 --- a/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py +++ b/datajunction-server/datajunction_server/api/graphql/resolvers/nodes.py @@ -10,6 +10,7 @@ from strawberry.types import Info from datajunction_server.errors import DJNodeNotFound +from datajunction_server.sql.dag import get_downstream_nodes, get_upstream_nodes from datajunction_server.api.graphql.scalars.node import NodeName from datajunction_server.api.graphql.scalars.sql import CubeDefinition from datajunction_server.api.graphql.utils import dedupe_append, extract_fields @@ -187,3 +188,36 @@ async def resolve_metrics_and_dimensions( metrics = list(OrderedDict.fromkeys(metrics)) return metrics, dimensions + + +async def resolve_node_upstreams(info: Info, root: DBNode) -> List[DBNode]: + """ + Resolves the upstream nodes for a given node. This function extracts the requested + fields from the query and only loads those to optimize the database query. + """ + fields = extract_fields(info) + options = load_node_options(fields) + return await get_upstream_nodes( + session=info.context["session"], + node_name=root.name, + node_output_options=options, + ) + + +async def resolve_node_downstreams( + info: Info, + root: DBNode, + depth: int | None = -1, +) -> List[DBNode]: + """ + Resolves the downstream nodes for a given node. This function extracts the requested + fields from the query and only loads those to optimize the database query. + """ + fields = extract_fields(info) + options = load_node_options(fields) + return await get_downstream_nodes( + session=info.context["session"], + node_name=root.name, + depth=depth, # type: ignore + node_output_options=options, + ) diff --git a/datajunction-server/datajunction_server/api/graphql/scalars/node.py b/datajunction-server/datajunction_server/api/graphql/scalars/node.py index b566fadee..8b6af3b8b 100644 --- a/datajunction-server/datajunction_server/api/graphql/scalars/node.py +++ b/datajunction-server/datajunction_server/api/graphql/scalars/node.py @@ -99,6 +99,65 @@ async def dimension_node(self, info: Info) -> "Node": ) +@strawberry.type +class Node: + """ + A DJ node + """ + + id: BigInt + name: str + type: NodeType # type: ignore + current_version: str + created_at: datetime.datetime + deactivated_at: Optional[datetime.datetime] + + current: "NodeRevision" + revisions: List["NodeRevision"] + + tags: List["TagBase"] + created_by: User + + @strawberry.field + def edited_by(self, root: "DBNode") -> List[str]: + """ + The users who edited this node + """ + return root.edited_by + + @strawberry.field + async def upstreams( + self, + root: "DBNodeRevision", + info: Info, + depth: int | None = 1, + ) -> list["Node"]: + """ + The upstream nodes of this node + """ + from datajunction_server.api.graphql.resolvers.nodes import ( + resolve_node_upstreams, + ) + + return await resolve_node_upstreams(info, root) # type: ignore + + @strawberry.field + async def downstreams( + self, + root: "DBNodeRevision", + info: Info, + depth: int | None = 1, + ) -> list["Node"]: + """ + The downstream nodes of this node + """ + from datajunction_server.api.graphql.resolvers.nodes import ( + resolve_node_downstreams, + ) + + return await resolve_node_downstreams(info, root, depth) # type: ignore + + @strawberry.type class NodeRevision: """ @@ -289,30 +348,3 @@ class TagBase: description: str | None display_name: str | None tag_metadata: JSON | None = strawberry.field(default_factory=dict) - - -@strawberry.type -class Node: - """ - A DJ node - """ - - id: BigInt - name: str - type: NodeType # type: ignore - current_version: str - created_at: datetime.datetime - deactivated_at: Optional[datetime.datetime] - - current: NodeRevision - revisions: List[NodeRevision] - - tags: List[TagBase] - created_by: User - - @strawberry.field - def edited_by(self, root: "DBNode") -> List[str]: - """ - The users who edited this node - """ - return root.edited_by diff --git a/datajunction-server/datajunction_server/api/graphql/schema.graphql b/datajunction-server/datajunction_server/api/graphql/schema.graphql index e2f902ba7..eeb554615 100644 --- a/datajunction-server/datajunction_server/api/graphql/schema.graphql +++ b/datajunction-server/datajunction_server/api/graphql/schema.graphql @@ -231,6 +231,8 @@ type Node { tags: [TagBase!]! createdBy: User! editedBy: [String!]! + upstreams(depth: Int = 1): [Node!]! + downstreams(depth: Int = 1): [Node!]! } type NodeConnection { diff --git a/datajunction-server/datajunction_server/sql/dag.py b/datajunction-server/datajunction_server/sql/dag.py index f5c20a43e..11d9b7dc3 100644 --- a/datajunction-server/datajunction_server/sql/dag.py +++ b/datajunction-server/datajunction_server/sql/dag.py @@ -61,6 +61,7 @@ async def get_downstream_nodes( include_deactivated: bool = True, include_cubes: bool = True, depth: int = -1, + node_output_options: Optional[List] = None, ) -> List[Node]: """ Gets all downstream children of the given node, filterable by node type. @@ -147,6 +148,7 @@ async def get_upstream_nodes( node_name: str, node_type: NodeType = None, include_deactivated: bool = True, + node_output_options: Optional[List] = None, ) -> List[Node]: """ Gets all upstreams of the given node, filterable by node type. @@ -216,7 +218,7 @@ async def get_upstream_nodes( (Node.current_version == NodeRevision.version) & (Node.id == NodeRevision.node_id), ) - .options(*_node_output_options()) + .options(*(node_output_options or _node_output_options())) ) results = (await session.execute(statement)).unique().scalars().all()