|
1 | 1 | import datetime |
2 | 2 | import json |
3 | 3 | from itertools import chain, product |
4 | | -from typing import Generator, Literal, Optional, TypeAlias, Union, cast |
| 4 | +from typing import Generator, Literal, Optional, TypeAlias, Union |
5 | 5 |
|
6 | 6 | from pydantic import ( |
7 | 7 | AliasChoices, |
|
15 | 15 | ) |
16 | 16 | from pydantic_core import to_jsonable_python |
17 | 17 |
|
18 | | -from invokeai.app.invocations.baseinvocation import BaseInvocation |
19 | 18 | from invokeai.app.invocations.fields import ImageField |
20 | 19 | from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError |
21 | 20 | from invokeai.app.services.workflow_records.workflow_records_common import ( |
@@ -137,20 +136,18 @@ def validate_unique_field_mappings(cls, v: Optional[BatchDataCollection]): |
137 | 136 | return v |
138 | 137 |
|
139 | 138 | @model_validator(mode="after") |
140 | | - def validate_batch_nodes_and_edges(cls, values): |
141 | | - batch_data_collection = cast(Optional[BatchDataCollection], values.data) |
142 | | - if batch_data_collection is None: |
143 | | - return values |
144 | | - graph = cast(Graph, values.graph) |
145 | | - for batch_data_list in batch_data_collection: |
| 139 | + def validate_batch_nodes_and_edges(self): |
| 140 | + if self.data is None: |
| 141 | + return self |
| 142 | + for batch_data_list in self.data: |
146 | 143 | for batch_data in batch_data_list: |
147 | 144 | try: |
148 | | - node = cast(BaseInvocation, graph.get_node(batch_data.node_path)) |
| 145 | + node = self.graph.get_node(batch_data.node_path) |
149 | 146 | except NodeNotFoundError: |
150 | 147 | raise NodeNotFoundError(f"Node {batch_data.node_path} not found in graph") |
151 | 148 | if batch_data.field_name not in type(node).model_fields: |
152 | 149 | raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}") |
153 | | - return values |
| 150 | + return self |
154 | 151 |
|
155 | 152 | @field_validator("graph") |
156 | 153 | def validate_graph(cls, v: Graph): |
|
0 commit comments