Skip to content

Commit 4e3500c

Browse files
committed
support noisy preview via API
1 parent 31d36b2 commit 4e3500c

File tree

5 files changed

+36
-14
lines changed

5 files changed

+36
-14
lines changed

examples/cli/main.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,7 +1506,10 @@ bool load_images_from_dir(const std::string dir,
15061506
const char* preview_path;
15071507
float preview_fps;
15081508

1509-
void step_callback(int step, int frame_count, sd_image_t* image) {
1509+
void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) {
1510+
(void)is_noisy;
1511+
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
1512+
// unused in this app, it will either be always noisy or always denoised here
15101513
if (frame_count == 1) {
15111514
stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0);
15121515
} else {
@@ -1541,7 +1544,7 @@ int main(int argc, const char* argv[]) {
15411544
params.high_noise_sample_params.guidance.slg.layer_count = params.high_noise_skip_layers.size();
15421545

15431546
sd_set_log_callback(sd_log_cb, (void*)&params);
1544-
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
1547+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval, true, false);
15451548

15461549
if (params.verbose) {
15471550
print_params(params);

stable-diffusion.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,8 @@ class StableDiffusionGGML {
11481148
enum SDVersion version,
11491149
preview_t preview_mode,
11501150
ggml_tensor* result,
1151-
std::function<void(int, int, sd_image_t*)> step_callback) {
1151+
std::function<void(int, int, sd_image_t*, bool)> step_callback,
1152+
bool is_noisy) {
11521153
const uint32_t channel = 3;
11531154
uint32_t width = latents->ne[0];
11541155
uint32_t height = latents->ne[1];
@@ -1218,7 +1219,7 @@ class StableDiffusionGGML {
12181219
for (int i = 0; i < frames; i++) {
12191220
images[i] = {width, height, channel, data + i * width * height * channel};
12201221
}
1221-
step_callback(step, frames, images);
1222+
step_callback(step, frames, images, is_noisy);
12221223
free(data);
12231224
free(images);
12241225
} else {
@@ -1272,7 +1273,7 @@ class StableDiffusionGGML {
12721273
images[i].data = sd_tensor_to_image(result, i, ggml_n_dims(latents) == 4);
12731274
}
12741275

1275-
step_callback(step, frames, images);
1276+
step_callback(step, frames, images, is_noisy);
12761277

12771278
ggml_tensor_scale(result, 0);
12781279
for (int i = 0; i < frames; i++) {
@@ -1384,6 +1385,8 @@ class StableDiffusionGGML {
13841385
}
13851386

13861387
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
1388+
auto sd_preview_cb = sd_get_preview_callback();
1389+
auto sd_preview_mode = sd_get_preview_mode();
13871390
if (step == 1 || step == -1) {
13881391
pretty_progress(0, (int)steps, 0);
13891392
}
@@ -1418,6 +1421,11 @@ class StableDiffusionGGML {
14181421
if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
14191422
apply_mask(noised_input, init_latent, denoise_mask);
14201423
}
1424+
if (sd_preview_cb != NULL && sd_should_preview_noisy()) {
1425+
if (step % sd_get_preview_interval() == 0) {
1426+
preview_image(work_ctx, step, noised_input, version, sd_preview_mode, preview_tensor, sd_preview_cb, true);
1427+
}
1428+
}
14211429

14221430
std::vector<struct ggml_tensor*> controls;
14231431

@@ -1542,14 +1550,13 @@ class StableDiffusionGGML {
15421550
if (denoise_mask != nullptr) {
15431551
apply_mask(denoised, init_latent, denoise_mask);
15441552
}
1545-
auto sd_preview_cb = sd_get_preview_callback();
1546-
auto sd_preview_mode = sd_get_preview_mode();
1547-
if (sd_preview_cb != NULL) {
1553+
1554+
if (sd_preview_cb != NULL && sd_should_preview_denoised()) {
15481555
if (step % sd_get_preview_interval() == 0) {
1549-
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
1556+
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb, false);
15501557
}
15511558
}
1552-
1559+
15531560
int64_t t1 = ggml_time_us();
15541561
if (step > 0 || step == -(int)steps) {
15551562
int showstep = std::abs(step);

stable-diffusion.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,11 @@ typedef struct sd_ctx_t sd_ctx_t;
263263

264264
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
265265
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
266-
typedef void (*sd_preview_cb_t)(int, int, sd_image_t*);
266+
typedef void (*sd_preview_cb_t)(int step, int frame_count, sd_image_t* frames, bool is_noisy);
267267

268268
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
269269
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
270-
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode, int interval);
270+
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode, int interval, bool denoised, bool noisy);
271271
SD_API int32_t get_num_physical_cores();
272272
SD_API const char* sd_get_system_info();
273273

util.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,10 @@ static sd_progress_cb_t sd_progress_cb = nullptr;
189189
void* sd_progress_cb_data = nullptr;
190190

191191
static sd_preview_cb_t sd_preview_cb = NULL;
192-
preview_t sd_preview_mode = PREVIEW_NONE;
192+
preview_t sd_preview_mode = PREVIEW_NONE;
193193
int sd_preview_interval = 1;
194+
bool sd_preview_denoised = true;
195+
bool sd_preview_noisy = false;
194196

195197
std::u32string utf8_to_utf32(const std::string& utf8_str) {
196198
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
@@ -335,10 +337,12 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
335337
sd_progress_cb = cb;
336338
sd_progress_cb_data = data;
337339
}
338-
void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode = PREVIEW_PROJ, int interval = 1) {
340+
// Stores the preview callback and its settings in the file-level preview globals.
//
// cb        callback pointer saved in sd_preview_cb
// mode      preview mode saved in sd_preview_mode
// interval  step interval saved in sd_preview_interval; the sampler uses it as the
//           divisor in `step % sd_get_preview_interval()`, so 0 is stored as 1
// denoised  saved in sd_preview_denoised (read by sd_should_preview_denoised())
// noisy     saved in sd_preview_noisy (read by sd_should_preview_noisy())
void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode = PREVIEW_PROJ, int interval = 1, bool denoised = true, bool noisy = false) {
    sd_preview_cb   = cb;
    sd_preview_mode = mode;
    // Integer remainder by zero is undefined behaviour; fall back to "every step".
    sd_preview_interval = (interval == 0) ? 1 : interval;
    sd_preview_denoised = denoised;
    sd_preview_noisy    = noisy;
}
343347

344348
sd_preview_cb_t sd_get_preview_callback() {
@@ -351,6 +355,12 @@ preview_t sd_get_preview_mode() {
351355
// Returns the preview step interval held in sd_preview_interval (initialised to 1 and
// overwritten by sd_set_preview_callback()). The sampler uses the result as the divisor
// in `step % sd_get_preview_interval()`, so it must not be 0.
int sd_get_preview_interval() {
352356
return sd_preview_interval;
353357
}
358+
// Returns the `denoised` flag last stored by sd_set_preview_callback() (true by default).
// The sampler checks it before previewing the denoised latents of a step.
bool sd_should_preview_denoised() {
359+
return sd_preview_denoised;
360+
}
361+
// Returns the `noisy` flag last stored by sd_set_preview_callback() (false by default).
// The sampler checks it before previewing the noised input of a step.
bool sd_should_preview_noisy() {
362+
return sd_preview_noisy;
363+
}
354364

355365
sd_progress_cb_t sd_get_progress_callback() {
356366
return sd_progress_cb;

util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ void* sd_get_progress_callback_data();
6060
sd_preview_cb_t sd_get_preview_callback();
6161
preview_t sd_get_preview_mode();
6262
int sd_get_preview_interval();
63+
bool sd_should_preview_denoised();
64+
bool sd_should_preview_noisy();
6365

6466
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
6567
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)

0 commit comments

Comments
 (0)