Skip to content

Commit 1e6313b

Browse files
committed
implement LtxvEditVideo API node
1 parent f17251b commit 1e6313b

File tree

3 files changed

+123
-25
lines changed

3 files changed

+123
-25
lines changed

comfy_api_nodes/nodes_ltxv.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
from pydantic import BaseModel, Field
66
from typing_extensions import override
77

8-
from comfy_api.input_impl import VideoFromFile
9-
from comfy_api.latest import IO, ComfyExtension
8+
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
109
from comfy_api_nodes.util import (
1110
ApiEndpoint,
1211
get_number_of_images,
1312
sync_op_raw,
1413
upload_images_to_comfyapi,
1514
validate_string,
15+
validate_video_duration,
16+
validate_video_dimensions,
17+
validate_video_frame_count,
18+
upload_video_to_comfyapi,
1619
)
1720

1821
MODELS_MAP = {
@@ -31,6 +34,14 @@ class ExecuteTaskRequest(BaseModel):
3134
image_uri: Optional[str] = Field(None)
3235

3336

37+
class VideoEditRequest(BaseModel):
38+
video_uri: str = Field(...)
39+
prompt: str = Field(...)
40+
start_time: int = Field(...)
41+
duration: int = Field(...)
42+
mode: str = Field(...)
43+
44+
3445
class TextToVideoNode(IO.ComfyNode):
3546
@classmethod
3647
def define_schema(cls):
@@ -103,7 +114,7 @@ async def execute(
103114
as_binary=True,
104115
max_retries=1,
105116
)
106-
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
117+
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
107118

108119

109120
class ImageToVideoNode(IO.ComfyNode):
@@ -183,7 +194,76 @@ async def execute(
183194
as_binary=True,
184195
max_retries=1,
185196
)
186-
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
197+
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
198+
199+
200+
class EditVideoNode(IO.ComfyNode):
201+
@classmethod
202+
def define_schema(cls):
203+
return IO.Schema(
204+
node_id="LtxvApiEditVideoNode",
205+
display_name="LTXV Video To Video",
206+
category="api node/video/LTXV",
207+
description="Edit a specific section of a video by replacing audio, video, or both using AI generation.",
208+
inputs=[
209+
IO.Video.Input("video"),
210+
IO.String.Input(
211+
"prompt",
212+
multiline=True,
213+
default="",
214+
),
215+
IO.Combo.Input("mode", options=["replace_video", "replace_audio", "replace_audio_and_video"]),
216+
IO.Float.Input("start_time", min=0.0, default=0.0),
217+
IO.Float.Input("duration", min=1.0, max=20.0, default=3),
218+
],
219+
outputs=[
220+
IO.Video.Output(),
221+
],
222+
hidden=[
223+
IO.Hidden.auth_token_comfy_org,
224+
IO.Hidden.api_key_comfy_org,
225+
IO.Hidden.unique_id,
226+
],
227+
is_api_node=True,
228+
)
229+
230+
@classmethod
231+
async def execute(
232+
cls,
233+
video: Input.Video,
234+
prompt: str,
235+
mode: str,
236+
start_time: float,
237+
duration: float,
238+
) -> IO.NodeOutput:
239+
validate_string(prompt, min_length=1, max_length=10000)
240+
validate_video_dimensions(video, max_width=3840, max_height=2160)
241+
validate_video_duration(video, max_duration=20)
242+
validate_video_frame_count(video, max_frame_count=505)
243+
video_duration = video.get_duration()
244+
if start_time >= video_duration:
245+
raise ValueError(
246+
f"Invalid start_time ({start_time}). Start time is greater than input video duration ({video_duration})"
247+
)
248+
response = await sync_op_raw(
249+
cls,
250+
# ApiEndpoint(
251+
# "https://api.ltx.video/v1/retake",
252+
# "POST",
253+
# headers={"Authorization": "Bearer PLACE_YOUR_API_KEY"},
254+
# ),
255+
ApiEndpoint("/proxy/ltx/v1/retake", "POST"),
256+
data=VideoEditRequest(
257+
video_uri=await upload_video_to_comfyapi(cls, video),
258+
prompt=prompt,
259+
mode=mode,
260+
start_time=int(start_time),
261+
duration=int(duration),
262+
),
263+
as_binary=True,
264+
max_retries=1,
265+
)
266+
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
187267

188268

189269
class LtxvApiExtension(ComfyExtension):
@@ -192,6 +272,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
192272
return [
193273
TextToVideoNode,
194274
ImageToVideoNode,
275+
EditVideoNode,
195276
]
196277

197278

comfy_api_nodes/util/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
validate_string,
4848
validate_video_dimensions,
4949
validate_video_duration,
50+
validate_video_frame_count,
5051
)
5152

5253
__all__ = [
@@ -94,6 +95,7 @@
9495
"validate_string",
9596
"validate_video_dimensions",
9697
"validate_video_duration",
98+
"validate_video_frame_count",
9799
# Misc functions
98100
"get_fs_object_size",
99101
]

comfy_api_nodes/util/validation_utils.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import logging
2-
from typing import Optional
32

43
import torch
54

6-
from comfy_api.input.video_types import VideoInput
75
from comfy_api.latest import Input
86

97

@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
1816

1917
def validate_image_dimensions(
2018
image: torch.Tensor,
21-
min_width: Optional[int] = None,
22-
max_width: Optional[int] = None,
23-
min_height: Optional[int] = None,
24-
max_height: Optional[int] = None,
19+
min_width: int | None = None,
20+
max_width: int | None = None,
21+
min_height: int | None = None,
22+
max_height: int | None = None,
2523
):
2624
height, width = get_image_dimensions(image)
2725

@@ -37,8 +35,8 @@ def validate_image_dimensions(
3735

3836
def validate_image_aspect_ratio(
3937
image: torch.Tensor,
40-
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
41-
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
38+
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
39+
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
4240
*,
4341
strict: bool = True, # True -> (min, max); False -> [min, max]
4442
) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
8482

8583
def validate_aspect_ratio_string(
8684
aspect_ratio: str,
87-
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
88-
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
85+
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
86+
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
8987
*,
9088
strict: bool = False, # True -> (min, max); False -> [min, max]
9189
) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
9795

9896
def validate_video_dimensions(
9997
video: Input.Video,
100-
min_width: Optional[int] = None,
101-
max_width: Optional[int] = None,
102-
min_height: Optional[int] = None,
103-
max_height: Optional[int] = None,
98+
min_width: int | None = None,
99+
max_width: int | None = None,
100+
min_height: int | None = None,
101+
max_height: int | None = None,
104102
):
105103
try:
106104
width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(
120118

121119
def validate_video_duration(
122120
video: Input.Video,
123-
min_duration: Optional[float] = None,
124-
max_duration: Optional[float] = None,
121+
min_duration: float | None = None,
122+
max_duration: float | None = None,
125123
):
126124
try:
127125
duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
136134
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
137135

138136

137+
def validate_video_frame_count(
138+
video: Input.Video,
139+
min_frame_count: int | None = None,
140+
max_frame_count: int | None = None,
141+
):
142+
try:
143+
frame_count = video.get_frame_count()
144+
except Exception as e:
145+
logging.error("Error getting frame count of video: %s", e)
146+
return
147+
148+
if min_frame_count is not None and min_frame_count > frame_count:
149+
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
150+
if max_frame_count is not None and frame_count > max_frame_count:
151+
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
152+
153+
139154
def get_number_of_images(images):
140155
if isinstance(images, torch.Tensor):
141156
return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):
144159

145160
def validate_audio_duration(
146161
audio: Input.Audio,
147-
min_duration: Optional[float] = None,
148-
max_duration: Optional[float] = None,
162+
min_duration: float | None = None,
163+
max_duration: float | None = None,
149164
) -> None:
150165
sr = int(audio["sample_rate"])
151166
dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
177192
)
178193

179194

180-
def validate_container_format_is_mp4(video: VideoInput) -> None:
195+
def validate_container_format_is_mp4(video: Input.Video) -> None:
181196
"""Validates video container format is MP4."""
182197
container_format = video.get_container_format()
183198
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
194209
def _assert_ratio_bounds(
195210
ar: float,
196211
*,
197-
min_ratio: Optional[tuple[float, float]] = None,
198-
max_ratio: Optional[tuple[float, float]] = None,
212+
min_ratio: tuple[float, float] | None = None,
213+
max_ratio: tuple[float, float] | None = None,
199214
strict: bool = True,
200215
) -> None:
201216
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""

0 commit comments

Comments
 (0)