Skip to content

Commit 1cafbbd

Browse files
Add alignment parameter: handle models that have an input shape alignment requirement
1 parent 3dcfbd3 commit 1cafbbd

File tree

4 files changed

+92
-20
lines changed

4 files changed

+92
-20
lines changed

main.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ DEFINE_int32(tile_width, 512, "tile width");
9797
DEFINE_int32(tile_height, 512, "tile height");
9898
DEFINE_int32(tile_pad, 16, "tile pad border to reduce tile block discontinuity");
9999
DEFINE_int32(extend_grace, 0, "grace limit to not split another tile");
100+
DECLARE_int32(alignment);
100101

101102
void verify_flags() {
102103
if (!exists(std::filesystem::path(FLAGS_model_path))) {
@@ -116,6 +117,11 @@ void verify_flags() {
116117
LOG(FATAL) << "Invalid tile extend grace.";
117118
}
118119

120+
if (FLAGS_alignment < 1 || FLAGS_tile_width % FLAGS_alignment != 0 || FLAGS_tile_height % FLAGS_alignment != 0
121+
|| FLAGS_tile_pad % FLAGS_alignment != 0 || FLAGS_extend_grace % FLAGS_alignment != 0) {
122+
LOG(FATAL) << "Invalid tile alignment.";
123+
}
124+
119125
auto ext_count = std::count(FLAGS_extensions.begin(), FLAGS_extensions.end(), ',');
120126
exts.reserve(ext_count + 1);
121127
exts.emplace_back(FLAGS_extensions);

reformat/reformat.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ std::string pixel_exporter_cpu_crop::fetch_color(md_view<const float, 3> src, cu
3030
return std::string("CUDA error: ") + cudaGetErrorName(err);
3131
}
3232

33+
current_buffer_shape = src.shape;
3334
return "";
3435
}
3536

@@ -40,6 +41,10 @@ std::string pixel_exporter_cpu_crop::fetch_alpha(md_view<const float, 3> src, cu
4041
return "dimension too big";
4142
}
4243

44+
if (current_buffer_shape != src.shape) {
45+
return "incompatible color buffer shape";
46+
}
47+
4348
auto err = cudaMemcpyAsync(buffer_alpha.get(), src.data, h * w * 4, cudaMemcpyDeviceToHost, stream);
4449
if (err != cudaSuccess) {
4550
return std::string("CUDA error: ") + cudaGetErrorName(err);

reformat/reformat.h

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct pad_descriptor {
3737
class pixel_exporter_cpu_crop {
3838
std::unique_ptr<float[]> buffer{};
3939
std::unique_ptr<float[]> buffer_alpha{};
40+
shape_t<3> current_buffer_shape;
4041
size_t max_size;
4142

4243
public:
@@ -82,21 +83,23 @@ std::string pixel_importer_cpu::import_color(md_view<float, 3> dst,
8283
md_uview<const U, 3> src,
8384
cudaStream_t stream,
8485
float quant) {
85-
if (dst.shape.slice<1, 2>() != src.shape.template slice<0, 2>()) {
86-
return "dimension mismatch";
87-
}
88-
8986
auto [h, w, c] = src.shape;
87+
auto [dc, dh, dw] = dst.shape;
9088

91-
if (h * w > max_size) {
89+
if (dh * dw > max_size) {
9290
return "dimension too big";
9391
}
9492

93+
if (h > dh || w > dw) {
94+
return "incompatible dimension";
95+
}
96+
9597
if (quant == 0.0) {
9698
quant = 1.0 / float(std::numeric_limits<U>::max());
9799
}
98100

99101
md_view<float, 3> tmp{buffer.get(), dst.shape};
102+
md_view<float, 2> tmp_alpha{buffer_alpha.get(), {h, w}};
100103

101104
if (c == 3) {
102105
for (size_t y = 0; y < h; ++y) {
@@ -118,7 +121,6 @@ std::string pixel_importer_cpu::import_color(md_view<float, 3> dst,
118121
}
119122
}
120123
else {
121-
md_view<float, 2> tmp_alpha{buffer_alpha.get(), {h, w}};
122124
for (size_t y = 0; y < h; ++y) {
123125
for (size_t x = 0; x < w; ++x) {
124126
tmp.at(0, y, x) = static_cast<float>(src.at(y, x, 2)) * quant;
@@ -133,6 +135,39 @@ std::string pixel_importer_cpu::import_color(md_view<float, 3> dst,
133135
assert(false);
134136
}
135137

138+
for (size_t y = h; y < dh; ++y) {
139+
for (size_t x = 0; x < w; ++x) {
140+
tmp.at(0, y, x) = tmp.at(0, h - 1, x);
141+
tmp.at(1, y, x) = tmp.at(1, h - 1, x);
142+
tmp.at(2, y, x) = tmp.at(2, h - 1, x);
143+
if (c == 4 && buffer_alpha) {
144+
tmp_alpha.at(y, x) = tmp_alpha.at(h - 1, x);
145+
}
146+
}
147+
}
148+
149+
for (size_t y = 0; y < h; ++y) {
150+
for (size_t x = w; x < dw; ++x) {
151+
tmp.at(0, y, x) = tmp.at(0, y, w - 1);
152+
tmp.at(1, y, x) = tmp.at(1, y, w - 1);
153+
tmp.at(2, y, x) = tmp.at(2, y, w - 1);
154+
if (c == 4 && buffer_alpha) {
155+
tmp_alpha.at(y, x) = tmp_alpha.at(y, w - 1);
156+
}
157+
}
158+
}
159+
160+
for (size_t y = h; y < dh; ++y) {
161+
for (size_t x = w; x < dw; ++x) {
162+
tmp.at(0, y, x) = tmp.at(0, h - 1, w - 1);
163+
tmp.at(1, y, x) = tmp.at(1, h - 1, w - 1);
164+
tmp.at(2, y, x) = tmp.at(2, h - 1, w - 1);
165+
if (c == 4 && buffer_alpha) {
166+
tmp_alpha.at(y, x) = tmp_alpha.at(h - 1, w - 1);
167+
}
168+
}
169+
}
170+
136171
auto err = cudaMemcpyAsync(dst.data, tmp.data, dst.size() * 4, cudaMemcpyHostToDevice, stream);
137172
if (err != cudaSuccess) {
138173
return std::string("CUDA error: ") + cudaGetErrorName(err);
@@ -148,8 +183,13 @@ std::string pixel_exporter_cpu_crop::export_data(md_uview<U, 3> dst, pad_descrip
148183
}
149184

150185
auto [he, we, c] = dst.shape;
151-
md_uview<float, 3> tmp = md_view<float, 3>{buffer.get(), {c, he, we}};
152-
md_uview<float, 2> tmp_alpha = md_view<float, 2>{buffer_alpha.get(), {he, we}};
186+
auto [_, hs, ws] = current_buffer_shape;
187+
if (he > hs || we > ws) {
188+
return "incompatible dimension";
189+
}
190+
191+
md_uview<float, 3> tmp = md_view<float, 3>{buffer.get(), current_buffer_shape};
192+
md_uview<float, 2> tmp_alpha = md_view<float, 2>{buffer_alpha.get(), current_buffer_shape.slice<1, 2>()};
153193

154194
offset_t shrink = pad.pad / 2;
155195
offset_t hb = pad.top ? 0 : shrink;

workers.cpp

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#include <array>
66
#include <memory>
77
#include <iostream>
8-
#include <cmath>
98

109
#include "nn-scaler.h"
1110
#include "reformat/reformat.h"
@@ -178,6 +177,12 @@ DECLARE_int32(tile_width);
178177
DECLARE_int32(tile_height);
179178
DECLARE_int32(tile_pad);
180179
DECLARE_int32(extend_grace);
180+
DEFINE_int32(alignment, 1, "model input alignment requirement");
181+
182+
/// Rounds `n` up to a multiple of `alignment` (values that are already a
/// multiple are returned unchanged), for non-negative `n`.
///
/// `alignment` must be non-zero, since it is used as a modulus.
static offset_t align(offset_t n, size_t alignment) {
    // Bias by (alignment - 1) so that truncating to a multiple rounds upward.
    offset_t biased = n;
    biased += alignment - 1;
    const auto excess = biased % alignment;
    return biased - excess;
}
181186

182187
static void pixel_import_worker(ichan &in, ichan &out) {
183188
bool nn_alpha = FLAGS_alpha == "nn";
@@ -192,19 +197,23 @@ static void pixel_import_worker(ichan &in, ichan &out) {
192197

193198
auto [h, w, c] = ctx.in_image.shape;
194199
auto process_alpha = nn_alpha && c == 4;
200+
offset_t h_split = align(h, FLAGS_alignment), w_split = align(w, FLAGS_alignment);
195201

196202
split_range<offset_t>(
197-
h, FLAGS_tile_height, FLAGS_tile_pad, FLAGS_extend_grace,
198-
[&, w = w](offset_t y, offset_t th, bool h_beg, bool h_end) {
203+
h_split, FLAGS_tile_height, FLAGS_tile_pad, FLAGS_extend_grace,
204+
[&, h = h, w = w](offset_t y, offset_t th, bool h_beg, bool h_end) {
199205
return split_range<offset_t>(
200-
w, FLAGS_tile_width, FLAGS_tile_pad, FLAGS_extend_grace,
206+
w_split, FLAGS_tile_width, FLAGS_tile_pad, FLAGS_extend_grace,
201207
[&](offset_t x, offset_t tw, bool w_beg, bool w_end) -> bool {
202208
auto tile_start = hr_clock::now();
203209

204210
md_view<float, 3> input_tensor = {reinterpret_cast<float *>(session->input), {3, th, tw}};
205-
importer->import_color(input_tensor,
206-
ctx.in_image.slice<0>(y, y + th).slice<1>(x, x + tw),
211+
auto ret = importer->import_color(input_tensor,
212+
ctx.in_image.slice<0>(y, std::min(y + th, h)).slice<1>(x, std::min(x + tw, w)),
207213
session->stream);
214+
if (!ret.empty()) {
215+
LOG(FATAL) << "Unexpected error importing pixel: " << ret;
216+
}
208217

209218
WorkContextInternal tile_ctx{
210219
.tile_start = tile_start,
@@ -244,7 +253,11 @@ static void pixel_import_worker(ichan &in, ichan &out) {
244253

245254
if (process_alpha) {
246255
auto alpha_start = hr_clock::now();
247-
importer->import_alpha(input_tensor, session->stream);
256+
ret = importer->import_alpha(input_tensor, session->stream);
257+
if (!ret.empty()) {
258+
LOG(FATAL) << "Unexpected error importing pixel: " << ret;
259+
}
260+
248261
tile_ctx = {
249262
.tile_start = alpha_start,
250263
.y = y, .x = x, .th = th, .tw = tw,
@@ -331,11 +344,15 @@ static void pixel_export_worker(ichan &in, ichan &out) {
331344
auto ctx = std::move(*i);
332345
md_view<float, 3> output_tensor =
333346
{reinterpret_cast<float *>(session->output), {3, ctx.th * h_scale, ctx.tw * w_scale}};
347+
std::string ret;
334348
if (ctx.is_alpha) {
335-
exporter->fetch_alpha(output_tensor, session->stream);
349+
ret = exporter->fetch_alpha(output_tensor, session->stream);
336350
}
337351
else {
338-
exporter->fetch_color(output_tensor, session->stream);
352+
ret = exporter->fetch_color(output_tensor, session->stream);
353+
}
354+
if (!ret.empty()) {
355+
LOG(FATAL) << "Unexpected error fetching result pixel: " << ret;
339356
}
340357

341358
auto err = cudaStreamSynchronize(session->stream);
@@ -358,11 +375,15 @@ static void pixel_export_worker(ichan &in, ichan &out) {
358375
}
359376

360377
pad_descriptor pad_desc{FLAGS_tile_pad * h_scale, ctx.h_beg, ctx.h_end, ctx.w_beg, ctx.w_end};
378+
auto [h, w, _] = ctx.out_image.shape;
361379
auto out_tile = ctx.out_image
362-
.slice<0>(h_scale * ctx.y, h_scale * (ctx.y + ctx.th))
363-
.slice<1>(w_scale * ctx.x, w_scale * (ctx.x + ctx.tw));
380+
.slice<0>(h_scale * ctx.y, std::min(h_scale * (ctx.y + ctx.th), h))
381+
.slice<1>(w_scale * ctx.x, std::min(w_scale * (ctx.x + ctx.tw), w));
364382
if (!ctx.has_alpha || ctx.is_alpha) {
365-
exporter->export_data(out_tile, pad_desc);
383+
ret = exporter->export_data(out_tile, pad_desc);
384+
if (!ret.empty()) {
385+
LOG(FATAL) << "Unexpected error exporting pixel: " << ret;
386+
}
366387
}
367388

368389
VLOG(1) << "Tile "

0 commit comments

Comments
 (0)