Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions vggt/utils/load_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_and_preprocess_images_square(image_path_list, target_size=1024):
return images, original_coords


def load_and_preprocess_images(image_path_list, mode="crop"):
def load_and_preprocess_images(image_path_list, mode="crop", grayscale=False):
"""
A quick start function to load and preprocess images for model input.
This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
Expand All @@ -105,6 +105,7 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
- "crop" (default): Sets width to 518px and center crops height if needed.
- "pad": Preserves all pixels by making the largest dimension 518px
and padding the smaller dimension to reach a square shape.
grayscale (bool, optional): If True, convert images to grayscale. Default is False.

Returns:
        torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W);
            when grayscale=True the images are converted via mode "L", so the channel
            count may be 1 unless replicated downstream — verify against the tensor-stacking code.
Expand Down Expand Up @@ -140,14 +141,14 @@ def load_and_preprocess_images(image_path_list, mode="crop"):
img = Image.open(image_path)

# If there's an alpha channel, blend onto white background:
if img.mode == "RGBA":
if not grayscale and img.mode == "RGBA":
# Create white background
background = Image.new("RGBA", img.size, (255, 255, 255, 255))
# Alpha composite onto the white background
img = Image.alpha_composite(background, img)

# Now convert to "RGB" (this step assigns white for transparent areas)
img = img.convert("RGB")
img = img.convert("L" if grayscale else "RGB")

width, height = img.size

Expand Down