Skip to content

Commit 5bbefa9

Browse files
Implement FP16 mode
1 parent a2548d8 commit 5bbefa9

File tree

10 files changed

+722
-223
lines changed

10 files changed

+722
-223
lines changed

CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules")
44
set(CMAKE_CXX_STANDARD 20)
55
set(CMAKE_CUDA_STANDARD 20)
66

7-
project(TRT-NNScaler LANGUAGES CXX)
7+
project(TRT-NNScaler LANGUAGES CXX CUDA)
88

99
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
1010
if (CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64)
11-
set(CMAKE_CUDA_ARCHITECTURES 53 62 72 87)
11+
set(CMAKE_CUDA_ARCHITECTURES 62 72 87)
1212
else ()
1313
set(CMAKE_CUDA_ARCHITECTURES 61 70 75 80 86 89 90)
1414
endif ()
@@ -50,10 +50,10 @@ endif ()
5050

5151
add_subdirectory(libyuv)
5252

53-
#add_library(reformat_cuda OBJECT reformat/reformat_cuda.h reformat/reformat.cu)
53+
add_library(reformat_cuda STATIC reformat/reformat_cuda.h reformat/reformat.cu)
5454

55-
add_library(reformat OBJECT reformat/reformat.h reformat/reformat_cuda.h reformat/reformat.cpp)
56-
target_link_libraries(reformat PUBLIC CUDA::cudart)
55+
add_library(reformat INTERFACE reformat/reformat.h reformat/reformat_cuda.h)
56+
target_link_libraries(reformat INTERFACE CUDA::cudart reformat_cuda)
5757

5858
set(SOURCE_FILES
5959
md_view.h

main.cpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <cstdlib>
1212

1313
#include "gflags/gflags.h"
14+
#include "cuda_fp16.h"
1415

1516
#include "nn-scaler.h"
1617
#include "logging.h"
@@ -20,8 +21,17 @@
2021
DEFINE_string(model_path, "models", "path to the folder to save model files");
2122

2223
InferenceSession *session = nullptr;
23-
pixel_importer_cpu *importer = nullptr;
24-
pixel_exporter_cpu_crop *exporter = nullptr;
24+
25+
int using_io = 0;
26+
27+
pixel_importer_cpu *importer_cpu = nullptr;
28+
pixel_exporter_cpu *exporter_cpu = nullptr;
29+
30+
pixel_importer_gpu<float> *importer_gpu = nullptr;
31+
pixel_exporter_gpu<float> *exporter_gpu = nullptr;
32+
33+
pixel_importer_gpu<half> *importer_gpu_fp16 = nullptr;
34+
pixel_exporter_gpu<half> *exporter_gpu_fp16 = nullptr;
2535

2636
static uint64_t total_processed = 0;
2737

@@ -88,9 +98,10 @@ static std::string handle_folder(const std::filesystem::path &input, chan &works
8898

8999
static Logger gLogger;
90100

91-
DEFINE_bool(fp16, false, "Use FP16 processing");
92-
DEFINE_bool(external, false, "Use external algorithms from cuDNN and cuBLAS");
93-
DEFINE_bool(low_mem, false, "Tweak configs to reduce memory consumption");
101+
DEFINE_bool(fp16, false, "use FP16 processing");
102+
DEFINE_bool(external, false, "use external algorithms from cuDNN and cuBLAS");
103+
DEFINE_bool(low_mem, false, "tweak configs to reduce memory consumption");
104+
DEFINE_string(reformatter, "auto", "reformatter used to import and export pixels: cpu, gpu, auto");
94105

95106
DECLARE_string(alpha);
96107
DEFINE_int32(tile_width, 512, "tile width");
@@ -154,7 +165,7 @@ void custom_prefix(std::ostream &s, const google::LogMessageInfo &l, void *) {
154165

155166
DECLARE_string(flagfile);
156167

157-
DEFINE_bool(cuda_lazy_load, true, "Enable CUDA lazying load.");
168+
DEFINE_bool(cuda_lazy_load, true, "enable CUDA lazying load.");
158169

159170
int32_t h_scale, w_scale;
160171

@@ -286,16 +297,34 @@ int wmain(int argc, wchar_t **wargv) {
286297
LOG(FATAL) << "different width and height scale ratio unimplemented.";
287298
}
288299

289-
if (FLAGS_fp16) {
290-
LOG(FATAL) << "FP16 mode unimplemented.";
291-
}
292-
293300
// ------------------------------
294301
// Import & Export
295302
auto max_size = size_t(max_width) * max_height;
296303

297-
importer = new pixel_importer_cpu(max_size, FLAGS_alpha != "ignore");
298-
exporter = new pixel_exporter_cpu_crop(h_scale * w_scale * max_size);
304+
if (FLAGS_reformatter == "auto") {
305+
FLAGS_reformatter = FLAGS_fp16 ? "gpu" : "cpu";
306+
}
307+
if (FLAGS_fp16 && FLAGS_reformatter == "cpu") {
308+
LOG(FATAL) << "CPU reformatter can not handle FP16.";
309+
}
310+
311+
if (FLAGS_reformatter == "cpu") {
312+
importer_cpu = new pixel_importer_cpu(max_size, FLAGS_alpha != "ignore");
313+
exporter_cpu = new pixel_exporter_cpu(h_scale * w_scale * max_size, FLAGS_alpha != "ignore");
314+
using_io = 0;
315+
} else if (FLAGS_reformatter == "gpu") {
316+
if (FLAGS_fp16) {
317+
importer_gpu_fp16 = new pixel_importer_gpu<half>(max_size, FLAGS_alpha != "ignore");
318+
exporter_gpu_fp16 = new pixel_exporter_gpu<half>(h_scale * w_scale * max_size, FLAGS_alpha != "ignore");
319+
using_io = 2;
320+
} else {
321+
importer_gpu = new pixel_importer_gpu<float>(max_size, FLAGS_alpha != "ignore");
322+
exporter_gpu = new pixel_exporter_gpu<float>(h_scale * w_scale * max_size, FLAGS_alpha != "ignore");
323+
using_io = 1;
324+
}
325+
} else {
326+
LOG(FATAL) << "Unknown reformatter.";
327+
}
299328

300329
chan works;
301330
std::thread pipeline(launch_pipeline, std::ref(works));

md_view.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,13 @@ void util_attrs copy(const md_view<T, DIMS> &dst, const md_view<T, DIMS> &src) {
442442
}
443443

444444
namespace detail {
445+
template<class T, class Memcpy>
446+
void util_attrs copy_impl(const md_uview<T, 1> &dst, const md_uview<const T, 1> &src, Memcpy cp) {
447+
for (int i = 0; i < dst.shape[0]; ++i) {
448+
cp(&dst.at(i), &src.at(i), sizeof(T));
449+
}
450+
}
451+
445452
template<class T, std::size_t DIMS, class Memcpy>
446453
void util_attrs copy_impl(const md_uview<T, DIMS> &dst, const md_uview<const T, DIMS> &src, Memcpy cp) {
447454
if (dst.at(0).is_contiguous() && src.at(0).is_contiguous()) {
@@ -455,13 +462,6 @@ void util_attrs copy_impl(const md_uview<T, DIMS> &dst, const md_uview<const T,
455462
}
456463
}
457464
}
458-
459-
template<class T, class Memcpy>
460-
void util_attrs copy_impl(const md_uview<T, 1> &dst, const md_uview<const T, 1> &src, Memcpy cp) {
461-
for (int i = 0; i < dst.shape[0]; ++i) {
462-
cp(dst.at(i).data, src.at(i).data, sizeof(T));
463-
}
464-
}
465465
}
466466

467467
template<class T, std::size_t DIMS, class Memcpy>

optimize.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,11 @@ std::string OptimizationContext::optimize() {
118118

119119
network->getInput(0)->setName("input");
120120
network->getOutput(0)->setName("output");
121-
network->getInput(0)->setType(nvinfer1::DataType::kFLOAT);
122-
network->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
123121

124-
// auto ioDataType = config.use_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
122+
auto ioDataType = config.use_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
123+
network->getInput(0)->setType(ioDataType);
124+
network->getOutput(0)->setType(ioDataType);
125+
125126
auto height = config.input_height;
126127
auto width = config.input_width;
127128
auto batch = config.batch;

reformat/reformat.cpp

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments (0)