|
11 | 11 | #include <cstdlib>
|
12 | 12 |
|
13 | 13 | #include "gflags/gflags.h"
|
| 14 | +#include "cuda_fp16.h" |
14 | 15 |
|
15 | 16 | #include "nn-scaler.h"
|
16 | 17 | #include "logging.h"
|
|
20 | 21 | DEFINE_string(model_path, "models", "path to the folder to save model files");
|
21 | 22 |
|
22 | 23 | InferenceSession *session = nullptr;
|
23 |
| -pixel_importer_cpu *importer = nullptr; |
24 |
| -pixel_exporter_cpu_crop *exporter = nullptr; |
| 24 | + |
| 25 | +int using_io = 0; |
| 26 | + |
| 27 | +pixel_importer_cpu *importer_cpu = nullptr; |
| 28 | +pixel_exporter_cpu *exporter_cpu = nullptr; |
| 29 | + |
| 30 | +pixel_importer_gpu<float> *importer_gpu = nullptr; |
| 31 | +pixel_exporter_gpu<float> *exporter_gpu = nullptr; |
| 32 | + |
| 33 | +pixel_importer_gpu<half> *importer_gpu_fp16 = nullptr; |
| 34 | +pixel_exporter_gpu<half> *exporter_gpu_fp16 = nullptr; |
25 | 35 |
|
26 | 36 | static uint64_t total_processed = 0;
|
27 | 37 |
|
@@ -88,9 +98,10 @@ static std::string handle_folder(const std::filesystem::path &input, chan &works
|
88 | 98 |
|
89 | 99 | static Logger gLogger;
|
90 | 100 |
|
91 |
| -DEFINE_bool(fp16, false, "Use FP16 processing"); |
92 |
| -DEFINE_bool(external, false, "Use external algorithms from cuDNN and cuBLAS"); |
93 |
| -DEFINE_bool(low_mem, false, "Tweak configs to reduce memory consumption"); |
| 101 | +DEFINE_bool(fp16, false, "use FP16 processing"); |
| 102 | +DEFINE_bool(external, false, "use external algorithms from cuDNN and cuBLAS"); |
| 103 | +DEFINE_bool(low_mem, false, "tweak configs to reduce memory consumption"); |
| 104 | +DEFINE_string(reformatter, "auto", "reformatter used to import and export pixels: cpu, gpu, auto"); |
94 | 105 |
|
95 | 106 | DECLARE_string(alpha);
|
96 | 107 | DEFINE_int32(tile_width, 512, "tile width");
|
@@ -154,7 +165,7 @@ void custom_prefix(std::ostream &s, const google::LogMessageInfo &l, void *) {
|
154 | 165 |
|
155 | 166 | DECLARE_string(flagfile);
|
156 | 167 |
|
157 |
| -DEFINE_bool(cuda_lazy_load, true, "Enable CUDA lazying load."); |
| 168 | +DEFINE_bool(cuda_lazy_load, true, "enable CUDA lazying load."); |
158 | 169 |
|
159 | 170 | int32_t h_scale, w_scale;
|
160 | 171 |
|
@@ -286,16 +297,34 @@ int wmain(int argc, wchar_t **wargv) {
|
286 | 297 | LOG(FATAL) << "different width and height scale ratio unimplemented.";
|
287 | 298 | }
|
288 | 299 |
|
289 |
| - if (FLAGS_fp16) { |
290 |
| - LOG(FATAL) << "FP16 mode unimplemented."; |
291 |
| - } |
292 |
| - |
293 | 300 | // ------------------------------
|
294 | 301 | // Import & Export
|
295 | 302 | auto max_size = size_t(max_width) * max_height;
|
296 | 303 |
|
297 |
| - importer = new pixel_importer_cpu(max_size, FLAGS_alpha != "ignore"); |
298 |
| - exporter = new pixel_exporter_cpu_crop(h_scale * w_scale * max_size); |
| 304 | + if (FLAGS_reformatter == "auto") { |
| 305 | + FLAGS_reformatter = FLAGS_fp16 ? "gpu" : "cpu"; |
| 306 | + } |
| 307 | + if (FLAGS_fp16 && FLAGS_reformatter == "cpu") { |
| 308 | + LOG(FATAL) << "CPU reformatter can not handle FP16."; |
| 309 | + } |
| 310 | + |
| 311 | + if (FLAGS_reformatter == "cpu") { |
| 312 | + importer_cpu = new pixel_importer_cpu(max_size, FLAGS_alpha != "ignore"); |
| 313 | + exporter_cpu = new pixel_exporter_cpu(h_scale * w_scale * max_size, FLAGS_alpha != "ignore"); |
| 314 | + using_io = 0; |
| 315 | + } else if (FLAGS_reformatter == "gpu") { |
| 316 | + if (FLAGS_fp16) { |
| 317 | + importer_gpu_fp16 = new pixel_importer_gpu<half>(max_size, FLAGS_alpha != "ignore"); |
| 318 | + exporter_gpu_fp16 = new pixel_exporter_gpu<half>(h_scale * w_scale * max_size, FLAGS_alpha != "ignore"); |
| 319 | + using_io = 2; |
| 320 | + } else { |
| 321 | + importer_gpu = new pixel_importer_gpu<float>(max_size, FLAGS_alpha != "ignore"); |
| 322 | + exporter_gpu = new pixel_exporter_gpu<float>(h_scale * w_scale * max_size, FLAGS_alpha != "ignore"); |
| 323 | + using_io = 1; |
| 324 | + } |
| 325 | + } else { |
| 326 | + LOG(FATAL) << "Unknown reformatter."; |
| 327 | + } |
299 | 328 |
|
300 | 329 | chan works;
|
301 | 330 | std::thread pipeline(launch_pipeline, std::ref(works));
|
|
0 commit comments