-
Notifications
You must be signed in to change notification settings - Fork 300
Open
Description
Hi, first of all, a big thank-you for publishing this work.
I am trying to use a trained model and query it with a new probe image.
This seems to me a very important piece of functionality — after all, that is what you train the network for, right?
But I couldn't find it anywhere. I tried writing something myself, but I get poor results.
Here is what I came up with (code below).
Any insights would be most appreciated.
Thanks,
Omer
import os
import cv2
import numpy as np
from model import DCGAN
from utils import get_image, image_save, save_images
import tensorflow as tf
from scipy.misc import imresize
# Command-line configuration for running a trained model over a directory of
# images. The path defaults below are machine-specific.
flags = tf.app.flags

# Numeric options.
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("image_size", 128, "The size of image to use")

# Path options as (name, default, help), defined in this order.
_PATH_FLAGS = (
    ("checkpoint_dir", "/home/omer/work/sub_pixel/models",
     "Directory name to read the checkpoints [checkpoint]"),
    ("test_image_dir", "/home/omer/work/sub_pixel/data/celebA/valid",
     "Directory name of the images to evaluate"),
    ("out_dir", "/home/omer/work/sub_pixel/out",
     "Directory name of to save results in"),
)
for _flag_name, _flag_default, _flag_help in _PATH_FLAGS:
    flags.DEFINE_string(_flag_name, _flag_default, _flag_help)

FLAGS = flags.FLAGS
def doresize(x, shape):
    """Rescale an image to uint8 pixel values and resize it.

    The image is mapped with ``(x + 1) * 127.5`` (so -1 -> 0 and 1 -> 255)
    and truncated to uint8 before being passed to ``imresize``.

    Args:
        x: image array, presumably scaled to [-1, 1] by get_image
            (TODO confirm); values outside that range wrap when cast to uint8.
        shape: target size forwarded unchanged to ``imresize``.

    Returns:
        Whatever ``imresize`` returns for the uint8 image.
    """
    scaled = (x + 1.) * 127.5
    pixels = np.copy(scaled).astype("uint8")
    return imresize(pixels, shape)
def main():
    """Run a trained model over every image in FLAGS.test_image_dir.

    Builds the DCGAN graph and restores weights from FLAGS.checkpoint_dir
    (returning early with a message if that fails). It then feeds the
    directory's images in batches of FLAGS.batch_size to batch_ready(), which
    runs the generator and saves the outputs.

    Sub-directories are ignored. Images that fail to load are reported and
    skipped without occupying a batch slot, so the i-th file name always
    matches the i-th row of its batch. The last batch may be partial; its
    unused rows stay zero.
    """
    with tf.Session() as sess:
        dcgan = DCGAN(sess, image_size=FLAGS.image_size,
                      image_shape=[FLAGS.image_size, FLAGS.image_size, 3],
                      batch_size=FLAGS.batch_size,
                      dataset_name='celebA', is_crop=False,
                      checkpoint_dir=FLAGS.checkpoint_dir)
        res = dcgan.load(FLAGS.checkpoint_dir)
        if not res:
            print("failed loading model from path:" + FLAGS.checkpoint_dir)
            return

        # List the directory once rather than on every iteration.
        image_files = _list_image_files(FLAGS.test_image_dir)
        # Ceiling division so a trailing partial batch is counted. This is an
        # upper bound: files that fail to load can reduce the real count.
        num_batches = (len(image_files) + FLAGS.batch_size - 1) // FLAGS.batch_size

        completed_batches = 0
        for input_images, files in _iter_batches(
                FLAGS.test_image_dir, image_files,
                FLAGS.batch_size, FLAGS.image_size):
            # Errors here (inference or saving) propagate instead of being
            # swallowed: they are systematic, not per-image.
            batch_ready(dcgan, input_images, sess, files)
            completed_batches += 1
            print('done batch {0} out of {1}'.format(completed_batches, num_batches))


def _list_image_files(directory):
    """Return the sorted names of the non-directory entries in *directory*."""
    return sorted(name for name in os.listdir(directory)
                  if not os.path.isdir(os.path.join(directory, name)))


def _iter_batches(directory, file_names, batch_size, image_size):
    """Yield (input_images, files) batches of successfully loaded images.

    input_images always has shape (batch_size, image_size, image_size, 3).
    Rows beyond len(files) are left as zeros, which only happens in the final,
    partial batch. files[k] is the name of the image stored in row k.
    """
    def new_buffer():
        return np.zeros(shape=(batch_size, image_size, image_size, 3))

    input_images = new_buffer()
    files = []
    for f in file_names:
        try:
            # Write into the next free row; on failure the row stays free.
            input_images[len(files)] = get_image(
                os.path.join(directory, f), image_size, False)
        except Exception as e:
            # Per-file best effort: report and move on without consuming a slot.
            print("problem working on:" + f)
            print(str(e))
            continue
        files.append(f)
        if len(files) == batch_size:
            yield input_images, files
            input_images = new_buffer()
            files = []
    # Flush the trailing partial batch, which the original loop never processed.
    if files:
        yield input_images, files
def batch_ready(dcgan, input_images, sess, files):
    """Run one batch through the generator and save the generated images.

    Each image is downscaled with doresize() to 32x32 to form the
    low-resolution input fed to ``dcgan.inputs``. The full-size images are
    fed to ``dcgan.images``. The fetched ``[dcgan.G]`` result is passed to
    save_results() together with *files*.

    Args:
        dcgan: the model whose ``inputs``, ``images`` and ``G`` tensors are used.
        input_images: the batch of full-size images.
        sess: the session the model was loaded into.
        files: source file names; only these rows are saved downstream.
    """
    low_res = np.array(
        [doresize(image, (32, 32, 3)) for image in input_images]
    ).astype(np.float32)
    full_res = np.array(input_images).astype(np.float32)
    feed = {dcgan.inputs: low_res, dcgan.images: full_res}
    generated = sess.run(fetches=[dcgan.G], feed_dict=feed)
    save_results(generated, files)
def save_results(output_images, files):
    """Write the generated images of one batch into FLAGS.out_dir.

    Args:
        output_images: the list returned by ``sess.run(fetches=[dcgan.G], ...)``.
            Element 0 holds the generator output for the whole batch.
        files: source file names, in the same order as the batch rows. Only
            the first ``len(files)`` rows are written, so zero-padded rows of a
            partial batch are skipped.

    Each image is saved as ``<source name>_.png``. The raw generator output
    (presumably in [-1, 1] — TODO confirm) is passed to image_save unchanged.
    """
    # Create the output directory on demand instead of failing on every save.
    if not os.path.isdir(FLAGS.out_dir):
        os.makedirs(FLAGS.out_dir)
    generated = output_images[0]
    for k, name in enumerate(files):
        out_path = os.path.join(FLAGS.out_dir, name + '_.png')
        image_save(generated[k], out_path)
if __name__ == '__main__':
    # Script entry point. main() takes no arguments and reads its settings
    # from FLAGS.
    main()
Metadata
Metadata
Assignees
Labels
No labels