diff --git a/demo_gradio.py b/demo_gradio.py
index 8dcef991..ae73f571 100644
--- a/demo_gradio.py
+++ b/demo_gradio.py
@@ -49,8 +49,9 @@ def run_model(target_dir, model) -> dict:
     # Device check
     device = "cuda" if torch.cuda.is_available() else "cpu"
+
     if not torch.cuda.is_available():
-        raise ValueError("CUDA is not available. Check your environment.")
+        print("CUDA is not available. Running on CPU.")
 
     # Move model to device
     model = model.to(device)
@@ -68,27 +69,39 @@ def run_model(target_dir, model) -> dict:
     # Run inference
     print("Running inference...")
-    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
-
+    if device == "cuda":
+        dtype = (
+            torch.bfloat16
+            if torch.cuda.get_device_capability()[0] >= 8
+            else torch.float16
+        )
+    else:
+        dtype = torch.bfloat16  # CPU autocast prefers bfloat16
     with torch.no_grad():
-        with torch.cuda.amp.autocast(dtype=dtype):
+        with torch.amp.autocast(device_type=device, dtype=dtype):
             predictions = model(images)
 
     # Convert pose encoding to extrinsic and intrinsic matrices
     print("Converting pose encoding to extrinsic and intrinsic matrices...")
-    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
+    extrinsic, intrinsic = pose_encoding_to_extri_intri(
+        predictions["pose_enc"], images.shape[-2:]
+    )
     predictions["extrinsic"] = extrinsic
     predictions["intrinsic"] = intrinsic
 
     # Convert tensors to numpy
     for key in predictions.keys():
         if isinstance(predictions[key], torch.Tensor):
-            predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension
+            predictions[key] = (
+                predictions[key].cpu().numpy().squeeze(0)
+            )  # remove batch dimension
 
     # Generate world points from depth map
     print("Computing world points from depth map...")
     depth_map = predictions["depth"]  # (S, H, W, 1)
-    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
+    world_points = unproject_depth_map_to_point_map(
+        depth_map, predictions["extrinsic"], predictions["intrinsic"]
+    )
     predictions["world_points_from_depth"] = world_points
 
     # Clean up
@@ -99,7 +112,7 @@ def run_model(target_dir, model) -> dict:
 # -------------------------------------------------------------------------
 # 2) Handle uploaded video/images --> produce target_dir + images
 # -------------------------------------------------------------------------
-def handle_uploads(input_video, input_images):
+def handle_uploads(input_video, input_images, input_masks):
     """
     Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
     images or extracted frames from video into it. Return (target_dir, image_paths).
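Note on the CPU fallback in `run_model` above: `torch.cuda.amp.autocast` is CUDA-only, and CPU autocast is built around bfloat16 rather than float16, which is why the hunk selects the dtype per device and enters autocast through the device-agnostic API. A minimal sketch of that selection logic (the helper name `pick_autocast_dtype` is illustrative, not part of this patch):

```python
import torch

def pick_autocast_dtype(device: str) -> torch.dtype:
    if device == "cuda":
        # bfloat16 requires compute capability >= 8.0 (Ampere or newer)
        major, _ = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16
    # CPU autocast prefers bfloat16
    return torch.bfloat16

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = pick_autocast_dtype(device)
with torch.no_grad(), torch.amp.autocast(device_type=device, dtype=dtype):
    pass  # predictions = model(images) would run here
```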
@@ -132,6 +145,21 @@ def handle_uploads(input_video, input_images):
             shutil.copy(file_path, dst_path)
             image_paths.append(dst_path)
 
+    mask_paths = []
+
+    # --- Handle masks ---
+    if input_masks is not None:
+        target_dir_masks = os.path.join(target_dir, "masks")
+        os.makedirs(target_dir_masks, exist_ok=True)
+        for file_data in input_masks:
+            if isinstance(file_data, dict) and "name" in file_data:
+                file_path = file_data["name"]
+            else:
+                file_path = file_data
+            dst_path = os.path.join(target_dir_masks, os.path.basename(file_path))
+            shutil.copy(file_path, dst_path)
+            mask_paths.append(dst_path)
+
     # --- Handle video ---
     if input_video is not None:
         if isinstance(input_video, dict) and "name" in input_video:
@@ -151,7 +179,9 @@
                 break
             count += 1
             if count % frame_interval == 0:
-                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
+                image_path = os.path.join(
+                    target_dir_images, f"{video_frame_num:06}.png"
+                )
                 cv2.imwrite(image_path, frame)
                 image_paths.append(image_path)
                 video_frame_num += 1
@@ -160,14 +190,16 @@
     image_paths = sorted(image_paths)
 
     end_time = time.time()
-    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
+    print(
+        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
+    )
     return target_dir, image_paths
 
 
 # -------------------------------------------------------------------------
 # 3) Update gallery on upload
 # -------------------------------------------------------------------------
-def update_gallery_on_upload(input_video, input_images):
+def update_gallery_on_upload(input_video, input_images, input_masks):
     """
     Whenever user uploads or changes files, immediately handle them
     and show in the gallery. Return (target_dir, image_paths).
@@ -175,8 +207,13 @@
     """
     if not input_video and not input_images:
         return None, None, None, None
-    target_dir, image_paths = handle_uploads(input_video, input_images)
-    return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
+    target_dir, image_paths = handle_uploads(input_video, input_images, input_masks)
+    return (
+        None,
+        target_dir,
+        image_paths,
+        "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+    )
 
 
 # -------------------------------------------------------------------------
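The `isinstance(file_data, dict) and "name" in file_data` check above appears three times (images, masks, video) because Gradio, depending on version, hands back uploads either as dicts with a `name` key or as bare file paths. A small helper could fold the copies into one; `as_filepath` is a hypothetical name, not part of this patch:

```python
def as_filepath(file_data):
    """Normalize a Gradio upload payload to a plain filesystem path."""
    if isinstance(file_data, dict) and "name" in file_data:
        return file_data["name"]
    return file_data
```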
@@ -204,7 +241,11 @@ def gradio_demo(
 
     # Prepare frame_filter dropdown
     target_dir_images = os.path.join(target_dir, "images")
-    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
+    all_files = (
+        sorted(os.listdir(target_dir_images))
+        if os.path.isdir(target_dir_images)
+        else []
+    )
     all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
     frame_filter_choices = ["All"] + all_files
@@ -226,6 +267,17 @@
         f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
     )
 
+    # Check for masks saved alongside the frames
+    target_dir_masks = os.path.join(target_dir, "masks")
+    if os.path.exists(target_dir_masks):
+        image_masks = (
+            sorted(os.listdir(target_dir_masks))
+            if os.path.isdir(target_dir_masks)
+            else []
+        )
+    else:
+        image_masks = None
+
     # Convert predictions to GLB
     glbscene = predictions_to_glb(
@@ -237,6 +289,7 @@
         mask_sky=mask_sky,
         target_dir=target_dir,
         prediction_mode=prediction_mode,
+        image_masks=image_masks,
     )
     glbscene.export(file_obj=glbfile)
@@ -247,9 +300,15 @@
     end_time = time.time()
     print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
 
-    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+    log_msg = (
+        f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+    )
 
-    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
+    return (
+        glbfile,
+        log_msg,
+        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
+    )
 
 
 # -------------------------------------------------------------------------
@@ -270,7 +329,15 @@
 def update_visualization(
-    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
+    target_dir,
+    conf_thres,
+    frame_filter,
+    mask_black_bg,
+    mask_white_bg,
+    show_cam,
+    mask_sky,
+    prediction_mode,
+    is_example,
 ):
     """
     Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
@@ -279,14 +346,23 @@
 
     # If it's an example click, skip as requested
     if is_example == "True":
-        return None, "No reconstruction available. Please click the Reconstruct button first."
+        return (
+            None,
+            "No reconstruction available. Please click the Reconstruct button first.",
+        )
 
     if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
-        return None, "No reconstruction available. Please click the Reconstruct button first."
+        return (
+            None,
+            "No reconstruction available. Please click the Reconstruct button first.",
+        )
 
     predictions_path = os.path.join(target_dir, "predictions.npz")
     if not os.path.exists(predictions_path):
-        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
+        return (
+            None,
+            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
+        )
 
     loaded = np.load(predictions_path, allow_pickle=True)
     predictions = {key: loaded[key] for key in loaded.keys()}
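For context, `gradio_demo` above keys the exported GLB on every visualization parameter, so toggling any option writes a new file while repeated views with the same settings reuse the cached one. A trimmed restatement of that naming scheme (the helper name `glb_cache_path` is illustrative, not part of this patch):

```python
import os

def glb_cache_path(target_dir, conf_thres, frame_filter, mask_black_bg,
                   mask_white_bg, show_cam, mask_sky, prediction_mode):
    # Sanitize the parameters that may contain dots, colons, or spaces
    safe_filter = frame_filter.replace(".", "_").replace(":", "").replace(" ", "_")
    safe_mode = prediction_mode.replace(" ", "_")
    return os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{safe_filter}_maskb{mask_black_bg}"
        f"_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{safe_mode}.glb",
    )
```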
@@ -424,7 +500,9 @@ def update_visualization(
     with gr.Row():
         with gr.Column(scale=2):
             input_video = gr.Video(label="Upload Video", interactive=True)
-            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
+            input_images = gr.File(
+                file_count="multiple", label="Upload Images", interactive=True
+            )
 
             image_gallery = gr.Gallery(
                 label="Preview",
@@ -434,19 +512,33 @@
                 object_fit="contain",
                 preview=True,
             )
+            input_masks = gr.File(
+                file_count="multiple", label="Upload Masks", interactive=True
+            )
 
         with gr.Column(scale=4):
             with gr.Column():
                 gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                 log_output = gr.Markdown(
-                    "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
+                    "Please upload a video or images, then click Reconstruct.",
+                    elem_classes=["custom-log"],
+                )
+                reconstruction_output = gr.Model3D(
+                    height=520, zoom_speed=0.5, pan_speed=0.5
                 )
-                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
 
             with gr.Row():
                 submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                 clear_btn = gr.ClearButton(
-                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
+                    [
+                        input_video,
+                        input_images,
+                        input_masks,
+                        reconstruction_output,
+                        log_output,
+                        target_dir_output,
+                        image_gallery,
+                    ],
                     scale=1,
                 )
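`input_masks` is added to the `gr.ClearButton` component list above so stale masks cannot leak into the next reconstruction; `ClearButton` resets every component it is handed. A minimal sketch of the pattern, assuming Gradio's Blocks API:

```python
import gradio as gr

with gr.Blocks() as demo:
    images = gr.File(file_count="multiple", label="Upload Images")
    masks = gr.File(file_count="multiple", label="Upload Masks")
    gr.ClearButton([images, masks])  # one click resets both inputs
```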
+ "1", + None, + 15.0, + False, + False, + True, + False, + "Depthmap and Camera Branch", + "True", + ], + [ + single_oil_painting_video, + "1", + None, + 20.0, + False, + False, + True, + True, + "Depthmap and Camera Branch", + "True", + ], + [ + room_video, + "8", + None, + 5.0, + False, + False, + True, + False, + "Depthmap and Camera Branch", + "True", + ], + [ + kitchen_video, + "25", + None, + 50.0, + False, + False, + True, + False, + "Depthmap and Camera Branch", + "True", + ], + [ + fern_video, + "20", + None, + 45.0, + False, + False, + True, + False, + "Depthmap and Camera Branch", + "True", + ], ] def example_pipeline( @@ -501,7 +682,14 @@ def example_pipeline( # Always use "All" for frame_filter in examples frame_filter = "All" glbfile, log_msg, dropdown = gradio_demo( - target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode + target_dir, + conf_thres, + frame_filter, + mask_black_bg, + mask_white_bg, + show_cam, + mask_sky, + prediction_mode, ) return glbfile, log_msg, target_dir, dropdown, image_paths @@ -673,12 +861,17 @@ def example_pipeline( # ------------------------------------------------------------------------- input_video.change( fn=update_gallery_on_upload, - inputs=[input_video, input_images], + inputs=[input_video, input_images, input_masks], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) input_images.change( fn=update_gallery_on_upload, - inputs=[input_video, input_images], + inputs=[input_video, input_images, input_masks], + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + ) + input_masks.change( + fn=update_gallery_on_upload, + inputs=[input_video, input_images, input_masks], outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], ) diff --git a/requirements_demo.txt b/requirements_demo.txt index efd6f5e9..2b4b1efb 100644 --- a/requirements_demo.txt +++ b/requirements_demo.txt @@ -9,3 +9,4 @@ onnxruntime requests trimesh matplotlib +pydantic==2.10.6 \ No newline at end of file diff --git a/visual_util.py b/visual_util.py index 3c624c92..362aa469 100644 --- a/visual_util.py +++ b/visual_util.py @@ -25,6 +25,7 @@ def predictions_to_glb( mask_sky=False, target_dir=None, prediction_mode="Predicted Pointmap", + image_masks=None ) -> trimesh.Scene: """ Converts VGGT predictions to a 3D scene represented as a GLB file. 
diff --git a/visual_util.py b/visual_util.py
index 3c624c92..362aa469 100644
--- a/visual_util.py
+++ b/visual_util.py
@@ -25,6 +25,7 @@ def predictions_to_glb(
     mask_sky=False,
     target_dir=None,
     prediction_mode="Predicted Pointmap",
+    image_masks=None,
 ) -> trimesh.Scene:
     """
     Converts VGGT predictions to a 3D scene represented as a GLB file.
@@ -148,6 +149,22 @@ def predictions_to_glb(
         colors_rgb = images
     colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
 
+    if image_masks is not None:
+        target_dir_masks = os.path.join(target_dir, "masks")
+        mask_list = sorted(os.listdir(target_dir_masks))
+
+        if len(mask_list) != len(images):
+            print("Number of masks does not match number of images -> using rgb colors")
+            image_masks = None
+        else:
+            image_mask_list = []
+            for mask_name in mask_list:
+                mask_filepath = os.path.join(target_dir_masks, mask_name)
+                # cv2 loads BGR; convert so mask colors match the RGB point colors
+                mask = cv2.cvtColor(cv2.imread(mask_filepath), cv2.COLOR_BGR2RGB)
+                mask = cv2.resize(mask, (W, H))
+                image_mask_list.append(mask)
+            mask_rgb = np.array(image_mask_list).reshape(-1, 3).astype(np.uint8)
     conf = pred_world_points_conf.reshape(-1)
     # Convert percentage threshold to actual confidence value
     if conf_thres == 0.0:
@@ -170,6 +187,9 @@ def predictions_to_glb(
         vertices_3d = vertices_3d[conf_mask]
         colors_rgb = colors_rgb[conf_mask]
 
+        if image_masks is not None:
+            mask_rgb = mask_rgb[conf_mask]
+
     if vertices_3d is None or np.asarray(vertices_3d).size == 0:
         vertices_3d = np.array([[1, 0, 0]])
         colors_rgb = np.array([[255, 255, 255]])
@@ -187,8 +207,11 @@ def predictions_to_glb(
     # Initialize a 3D scene
     scene_3d = trimesh.Scene()
 
-    # Add point cloud data to the scene
-    point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+    # Visualize the masks if given; otherwise colorize the point cloud from the frames
+    if image_masks is not None:
+        point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=mask_rgb)
+    else:
+        point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
 
     scene_3d.add_geometry(point_cloud_data)
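End to end, the mask branch above amounts to swapping the color array handed to `trimesh.PointCloud`; the geometry itself is untouched. A self-contained sketch of that export path (the data and filename are synthetic):

```python
import numpy as np
import trimesh

vertices_3d = np.random.rand(100, 3)
mask_rgb = np.random.randint(0, 255, (100, 3), dtype=np.uint8)

scene_3d = trimesh.Scene()
scene_3d.add_geometry(trimesh.PointCloud(vertices=vertices_3d, colors=mask_rgb))
scene_3d.export(file_obj="scene.glb")  # same export call the demo uses
```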