Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion diffsynth/trainers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,17 @@ def get_num_frames(self, reader):

def load_video(self, file_path):
reader = imageio.get_reader(file_path)
total_frames = int(reader.count_frames())
num_frames = self.get_num_frames(reader)
start_idx = 0
import random
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Imports should be at the top of the file, not inside a function. This follows PEP 8 style guidelines and makes dependencies clear. Please move import random to the top of diffsynth/trainers/utils.py.

if total_frames>num_frames:
# 计算随机起始位置(确保能截取到81帧)
max_start_idx = total_frames - num_frames
start_idx = random.randint(0, max_start_idx)

Comment on lines +269 to +273
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There are a few style improvements that can be made here to align with PEP 8:

  • The comment should be in English for consistency.
  • The blank line after the if block is unnecessary.
  • Comparison operators should have a single space on either side (e.g., total_frames > num_frames).
Suggested change
if total_frames>num_frames:
# 计算随机起始位置(确保能截取到81帧)
max_start_idx = total_frames - num_frames
start_idx = random.randint(0, max_start_idx)
if total_frames > num_frames:
# Calculate a random start index to ensure `num_frames` can be sampled.
max_start_idx = total_frames - num_frames
start_idx = random.randint(0, max_start_idx)

frames = []
for frame_id in range(num_frames):
for frame_id in range(start_idx,start_idx+num_frames):
frame = reader.get_data(frame_id)
frame = Image.fromarray(frame)
frame = self.crop_and_resize(frame, *self.get_height_width(frame))
Expand Down