diff --git a/tutorial_dataset.py b/tutorial_dataset.py index 616d874..9894ab1 100644 --- a/tutorial_dataset.py +++ b/tutorial_dataset.py @@ -201,6 +201,7 @@ def __getitem__(self, idx): bbx_instance = torch.tensor(bbx_instance) shadowfree_img = cv2.cvtColor(shadowfree_img, cv2.COLOR_BGR2RGB) target = cv2.cvtColor(shadowfree_img, cv2.COLOR_BGR2RGB) + zt = target source = np.concatenate((shadowfree_img, object_mask[:, :, np.newaxis]), axis=-1) cls_input = np.concatenate((shadowfree_img, object_mask[:, :, np.newaxis]), axis=-1) cls_input = cls_input.astype(np.float32) / 255.0 @@ -211,4 +212,4 @@ def __getitem__(self, idx): mask_embeddings = torch.zeros((64, 2048), dtype=torch.float32) bbx_region = torch.zeros((512, 512), dtype=torch.float32) - return dict(jpg=target, cls=cls_input, fg=bbx_instance, bbx=bbx_region, embeddings=mask_embeddings, img_name=pic_name, txt=prompt, hint=source) + return dict(zt=zt, jpg=target, cls=cls_input, fg=bbx_instance, bbx=bbx_region, embeddings=mask_embeddings, img_name=pic_name, txt=prompt, hint=source)