From 0d64ccb44a939db8ba0ae52280e9861ed06c9351 Mon Sep 17 00:00:00 2001 From: loliconce <76913252+loliconce@users.noreply.github.com> Date: Wed, 12 Apr 2023 19:50:28 +0800 Subject: [PATCH] Add test function Add test function to generate the pic of val_data --- ML/Pytorch/GANs/CycleGAN/test.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 ML/Pytorch/GANs/CycleGAN/test.py diff --git a/ML/Pytorch/GANs/CycleGAN/test.py b/ML/Pytorch/GANs/CycleGAN/test.py new file mode 100644 index 00000000..93c28a01 --- /dev/null +++ b/ML/Pytorch/GANs/CycleGAN/test.py @@ -0,0 +1,67 @@ +import torch +import config +from tqdm import tqdm +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision.utils import save_image +from dataset import HorseZebraDataset +from generator_model import Generator +from utils import load_checkpoint + + + +def test_fn(gen_Z, gen_H, loader): + + loop = tqdm(loader, leave=True) + + for idx, (zebra, horse) in enumerate(loop): + zebra = zebra.to(config.DEVICE) + horse = horse.to(config.DEVICE) + + with torch.cuda.amp.autocast(): + fake_horse = gen_H(zebra) + fake_zebra = gen_Z(horse) + + save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png") + save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png") + +def main(): + gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE) + gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE) + + + opt_gen = optim.Adam( + list(gen_Z.parameters()) + list(gen_H.parameters()), + lr=config.LEARNING_RATE, + betas=(0.5, 0.999), + ) + load_checkpoint( + config.CHECKPOINT_GEN_H, + gen_H, + opt_gen, + config.LEARNING_RATE, + ) + load_checkpoint( + config.CHECKPOINT_GEN_Z, + gen_Z, + opt_gen, + config.LEARNING_RATE, + ) + + val_dataset = HorseZebraDataset( + root_horse=config.VAL_DIR + "/testA", + root_zebra=config.VAL_DIR + "/testB", + transform=config.transforms, + ) + + loader = DataLoader( + val_dataset, + batch_size=config.BATCH_SIZE, + shuffle=False, + num_workers=config.NUM_WORKERS, + pin_memory=True, + ) + test_fn(gen_Z, gen_H, loader) + +if __name__ == "__main__": + main() \ No newline at end of file