Skip to content

Commit 967fa23

Browse files
authored
Solari GI: Balance heuristic for spatial resampling (#20259)
Use the balance heuristic for spatial resampling. This greatly reduces artifacts from the jacobian, and matches the reference PT way closer. I also tried it for DI, but doesn't make much of a difference there (at least for static scenes). To test, edit line 82 of restir_gi.wgsl to go from `+=` to `=` so that you only see the GI. Then try comparing the code in this PR against main. Setting CONFIDENCE_WEIGHT_CAP to 300 (longer temporal history) can also make it more obvious.
1 parent dd57875 commit 967fa23

File tree

4 files changed

+141
-53
lines changed

4 files changed

+141
-53
lines changed

crates/bevy_solari/src/realtime/restir_di.wgsl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
9292
textureStore(view_output, global_id.xy, vec4(pixel_color, 1.0));
9393
}
9494

95-
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, diffuse_brdf: vec3<f32>, workgroup_id: vec2<u32>, rng: ptr<function, u32>) -> Reservoir{
95+
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, diffuse_brdf: vec3<f32>, workgroup_id: vec2<u32>, rng: ptr<function, u32>) -> Reservoir {
9696
var workgroup_rng = (workgroup_id.x * 5782582u) + workgroup_id.y;
9797
let light_tile_start = rand_range_u(128u, &workgroup_rng) * 1024u;
9898

@@ -266,7 +266,6 @@ fn merge_reservoirs(
266266
diffuse_brdf: vec3<f32>,
267267
rng: ptr<function, u32>,
268268
) -> ReservoirMergeResult {
269-
// TODO: Balance heuristic MIS weights
270269
let mis_weight_denominator = 1.0 / (canonical_reservoir.confidence_weight + other_reservoir.confidence_weight);
271270

272271
let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;

crates/bevy_solari/src/realtime/restir_gi.wgsl

Lines changed: 138 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ fn initial_and_temporal(@builtin(global_invocation_id) global_id: vec3<u32>) {
4444

4545
let initial_reservoir = generate_initial_reservoir(world_position, world_normal, &rng);
4646
let temporal_reservoir = load_temporal_reservoir(global_id.xy, depth, world_position, world_normal);
47-
let merge_result = merge_reservoirs(initial_reservoir, temporal_reservoir, vec3(1.0), vec3(1.0), &rng);
47+
let combined_reservoir = merge_reservoirs(initial_reservoir, temporal_reservoir, &rng);
4848

49-
gi_reservoirs_b[pixel_index] = merge_result.merged_reservoir;
49+
gi_reservoirs_b[pixel_index] = combined_reservoir;
5050
}
5151

5252
@compute @workgroup_size(8, 8, 1)
@@ -68,12 +68,9 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
6868
let diffuse_brdf = base_color / PI;
6969

7070
let input_reservoir = gi_reservoirs_b[pixel_index];
71-
let spatial_reservoir = load_spatial_reservoir(global_id.xy, depth, world_position, world_normal, &rng);
72-
73-
let input_factor = dot(normalize(input_reservoir.sample_point_world_position - world_position), world_normal) * diffuse_brdf;
74-
let spatial_factor = dot(normalize(spatial_reservoir.sample_point_world_position - world_position), world_normal) * diffuse_brdf;
75-
76-
let merge_result = merge_reservoirs(input_reservoir, spatial_reservoir, input_factor, spatial_factor, &rng);
71+
let spatial = load_spatial_reservoir(global_id.xy, depth, world_position, world_normal, &rng);
72+
let merge_result = merge_reservoirs_spatial(input_reservoir, world_position, world_normal, diffuse_brdf,
73+
spatial.reservoir, spatial.world_position, spatial.world_normal, spatial.diffuse_brdf, &rng);
7774
let combined_reservoir = merge_result.merged_reservoir;
7875

7976
gi_reservoirs_a[pixel_index] = combined_reservoir;
@@ -83,7 +80,7 @@ fn spatial_and_shade(@builtin(global_invocation_id) global_id: vec3<u32>) {
8380
textureStore(view_output, global_id.xy, pixel_color);
8481
}
8582

86-
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir{
83+
fn generate_initial_reservoir(world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
8784
var reservoir = empty_reservoir();
8885

8986
let ray_direction = sample_uniform_hemisphere(world_normal, rng);
@@ -141,34 +138,30 @@ fn load_temporal_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3
141138
return temporal_reservoir;
142139
}
143140

144-
fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> Reservoir {
141+
struct SpatialInfo {
142+
reservoir: Reservoir,
143+
world_position: vec3<f32>,
144+
world_normal: vec3<f32>,
145+
diffuse_brdf: vec3<f32>,
146+
}
147+
148+
fn load_spatial_reservoir(pixel_id: vec2<u32>, depth: f32, world_position: vec3<f32>, world_normal: vec3<f32>, rng: ptr<function, u32>) -> SpatialInfo {
145149
let spatial_pixel_id = get_neighbor_pixel_id(pixel_id, rng);
146150

147151
let spatial_depth = textureLoad(depth_buffer, spatial_pixel_id, 0);
148152
let spatial_gpixel = textureLoad(gbuffer, spatial_pixel_id, 0);
149153
let spatial_world_position = reconstruct_world_position(spatial_pixel_id, spatial_depth);
150154
let spatial_world_normal = octahedral_decode(unpack_24bit_normal(spatial_gpixel.a));
155+
let spatial_base_color = pow(unpack4x8unorm(spatial_gpixel.r).rgb, vec3(2.2));
156+
let spatial_diffuse_brdf = spatial_base_color / PI;
151157
if pixel_dissimilar(depth, world_position, spatial_world_position, world_normal, spatial_world_normal) {
152-
return empty_reservoir();
158+
return SpatialInfo(empty_reservoir(), spatial_world_position, spatial_world_normal, spatial_diffuse_brdf);
153159
}
154160

155161
let spatial_pixel_index = spatial_pixel_id.x + spatial_pixel_id.y * u32(view.viewport.z);
156-
var spatial_reservoir = gi_reservoirs_b[spatial_pixel_index];
162+
let spatial_reservoir = gi_reservoirs_b[spatial_pixel_index];
157163

158-
var jacobian = jacobian(
159-
world_position,
160-
spatial_world_position,
161-
spatial_reservoir.sample_point_world_position,
162-
spatial_reservoir.sample_point_world_normal
163-
);
164-
if jacobian > 10.0 || jacobian < 0.1 {
165-
return empty_reservoir();
166-
}
167-
spatial_reservoir.unbiased_contribution_weight *= jacobian;
168-
169-
spatial_reservoir.unbiased_contribution_weight *= trace_point_visibility(world_position, spatial_reservoir.sample_point_world_position);
170-
171-
return spatial_reservoir;
164+
return SpatialInfo(spatial_reservoir, spatial_world_position, spatial_world_normal, spatial_diffuse_brdf);
172165
}
173166

174167
fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) -> vec2<u32> {
@@ -178,13 +171,13 @@ fn get_neighbor_pixel_id(center_pixel_id: vec2<u32>, rng: ptr<function, u32>) ->
178171
}
179172

180173
fn jacobian(
181-
world_position: vec3<f32>,
182-
spatial_world_position: vec3<f32>,
174+
new_world_position: vec3<f32>,
175+
original_world_position: vec3<f32>,
183176
sample_point_world_position: vec3<f32>,
184177
sample_point_world_normal: vec3<f32>,
185178
) -> f32 {
186-
let r = world_position - sample_point_world_position;
187-
let q = spatial_world_position - sample_point_world_position;
179+
let r = new_world_position - sample_point_world_position;
180+
let q = original_world_position - sample_point_world_position;
188181
let rl = length(r);
189182
let ql = length(q);
190183
let phi_r = saturate(dot(r / rl, sample_point_world_normal));
@@ -256,34 +249,22 @@ fn empty_reservoir() -> Reservoir {
256249
);
257250
}
258251

259-
struct ReservoirMergeResult {
260-
merged_reservoir: Reservoir,
261-
selected_sample_radiance: vec3<f32>,
262-
}
263-
264252
fn merge_reservoirs(
265253
canonical_reservoir: Reservoir,
266254
other_reservoir: Reservoir,
267-
canonical_factor: vec3<f32>,
268-
other_factor: vec3<f32>,
269255
rng: ptr<function, u32>,
270-
) -> ReservoirMergeResult {
256+
) -> Reservoir {
271257
var combined_reservoir = empty_reservoir();
272258
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;
273259

274-
if combined_reservoir.confidence_weight == 0.0 { return ReservoirMergeResult(combined_reservoir, vec3(0.0)); }
275-
276-
// TODO: Balance heuristic MIS weights
277-
let mis_weight_denominator = 1.0 / combined_reservoir.confidence_weight;
260+
let mis_weight_denominator = select(0.0, 1.0 / combined_reservoir.confidence_weight, combined_reservoir.confidence_weight > 0.0);
278261

279262
let canonical_mis_weight = canonical_reservoir.confidence_weight * mis_weight_denominator;
280-
let canonical_radiance = canonical_reservoir.radiance * canonical_factor;
281-
let canonical_target_function = luminance(canonical_radiance);
263+
let canonical_target_function = luminance(canonical_reservoir.radiance);
282264
let canonical_resampling_weight = canonical_mis_weight * (canonical_target_function * canonical_reservoir.unbiased_contribution_weight);
283265

284266
let other_mis_weight = other_reservoir.confidence_weight * mis_weight_denominator;
285-
let other_radiance = other_reservoir.radiance * other_factor;
286-
let other_target_function = luminance(other_radiance);
267+
let other_target_function = luminance(other_reservoir.radiance);
287268
let other_resampling_weight = other_mis_weight * (other_target_function * other_reservoir.unbiased_contribution_weight);
288269

289270
combined_reservoir.weight_sum = canonical_resampling_weight + other_resampling_weight;
@@ -295,16 +276,124 @@ fn merge_reservoirs(
295276

296277
let inverse_target_function = select(0.0, 1.0 / other_target_function, other_target_function > 0.0);
297278
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;
298-
299-
return ReservoirMergeResult(combined_reservoir, other_radiance);
300279
} else {
301280
combined_reservoir.sample_point_world_position = canonical_reservoir.sample_point_world_position;
302281
combined_reservoir.sample_point_world_normal = canonical_reservoir.sample_point_world_normal;
303282
combined_reservoir.radiance = canonical_reservoir.radiance;
304283

305284
let inverse_target_function = select(0.0, 1.0 / canonical_target_function, canonical_target_function > 0.0);
306285
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;
286+
}
287+
288+
return combined_reservoir;
289+
}
290+
291+
struct ReservoirMergeResult {
292+
merged_reservoir: Reservoir,
293+
selected_sample_radiance: vec3<f32>,
294+
}
295+
296+
fn merge_reservoirs_spatial(
297+
canonical_reservoir: Reservoir,
298+
canonical_world_position: vec3<f32>,
299+
canonical_world_normal: vec3<f32>,
300+
canonical_diffuse_brdf: vec3<f32>,
301+
other_reservoir: Reservoir,
302+
other_world_position: vec3<f32>,
303+
other_world_normal: vec3<f32>,
304+
other_diffuse_brdf: vec3<f32>,
305+
rng: ptr<function, u32>,
306+
) -> ReservoirMergeResult {
307+
// Radiances for resampling
308+
let canonical_sample_radiance =
309+
canonical_reservoir.radiance *
310+
saturate(dot(normalize(canonical_reservoir.sample_point_world_position - canonical_world_position), canonical_world_normal)) *
311+
canonical_diffuse_brdf;
312+
let other_sample_radiance =
313+
other_reservoir.radiance *
314+
saturate(dot(normalize(other_reservoir.sample_point_world_position - canonical_world_position), canonical_world_normal)) *
315+
canonical_diffuse_brdf *
316+
trace_point_visibility(canonical_world_position, other_reservoir.sample_point_world_position);
317+
318+
// Target functions for resampling and MIS
319+
let canonical_target_function_canonical_sample = luminance(canonical_sample_radiance);
320+
let canonical_target_function_other_sample = luminance(other_sample_radiance);
321+
322+
// Extra target functions for MIS
323+
let other_target_function_canonical_sample = luminance(
324+
canonical_reservoir.radiance *
325+
saturate(dot(normalize(canonical_reservoir.sample_point_world_position - other_world_position), other_world_normal)) *
326+
other_diffuse_brdf
327+
);
328+
let other_target_function_other_sample = luminance(
329+
other_reservoir.radiance *
330+
saturate(dot(normalize(other_reservoir.sample_point_world_position - other_world_position), other_world_normal)) *
331+
other_diffuse_brdf
332+
);
333+
334+
// Jacobians for resampling and MIS
335+
let canonical_target_function_other_sample_jacobian = jacobian(
336+
canonical_world_position,
337+
other_world_position,
338+
other_reservoir.sample_point_world_position,
339+
other_reservoir.sample_point_world_normal
340+
);
341+
let other_target_function_canonical_sample_jacobian = jacobian(
342+
other_world_position,
343+
canonical_world_position,
344+
canonical_reservoir.sample_point_world_position,
345+
canonical_reservoir.sample_point_world_normal
346+
);
347+
348+
// Resampling weight for canonical sample
349+
let canonical_sample_mis_weight = balance_heuristic(
350+
canonical_reservoir.confidence_weight * canonical_target_function_canonical_sample,
351+
other_reservoir.confidence_weight * other_target_function_canonical_sample * other_target_function_canonical_sample_jacobian,
352+
);
353+
let canonical_sample_resampling_weight = canonical_sample_mis_weight *
354+
canonical_target_function_canonical_sample *
355+
canonical_reservoir.unbiased_contribution_weight;
356+
357+
// Resampling weight for other sample
358+
let other_sample_mis_weight = balance_heuristic(
359+
other_reservoir.confidence_weight * other_target_function_other_sample,
360+
canonical_reservoir.confidence_weight * canonical_target_function_other_sample * canonical_target_function_other_sample_jacobian,
361+
);
362+
let other_sample_resampling_weight = other_sample_mis_weight *
363+
canonical_target_function_other_sample *
364+
other_reservoir.unbiased_contribution_weight *
365+
canonical_target_function_other_sample_jacobian;
366+
367+
// Perform resampling
368+
var combined_reservoir = empty_reservoir();
369+
combined_reservoir.confidence_weight = canonical_reservoir.confidence_weight + other_reservoir.confidence_weight;
370+
combined_reservoir.weight_sum = canonical_sample_resampling_weight + other_sample_resampling_weight;
371+
372+
if rand_f(rng) < other_sample_resampling_weight / combined_reservoir.weight_sum {
373+
combined_reservoir.sample_point_world_position = other_reservoir.sample_point_world_position;
374+
combined_reservoir.sample_point_world_normal = other_reservoir.sample_point_world_normal;
375+
combined_reservoir.radiance = other_reservoir.radiance;
376+
377+
let inverse_target_function = select(0.0, 1.0 / canonical_target_function_other_sample, canonical_target_function_other_sample > 0.0);
378+
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;
379+
380+
return ReservoirMergeResult(combined_reservoir, other_sample_radiance);
381+
} else {
382+
combined_reservoir.sample_point_world_position = canonical_reservoir.sample_point_world_position;
383+
combined_reservoir.sample_point_world_normal = canonical_reservoir.sample_point_world_normal;
384+
combined_reservoir.radiance = canonical_reservoir.radiance;
385+
386+
let inverse_target_function = select(0.0, 1.0 / canonical_target_function_canonical_sample, canonical_target_function_canonical_sample > 0.0);
387+
combined_reservoir.unbiased_contribution_weight = combined_reservoir.weight_sum * inverse_target_function;
388+
389+
return ReservoirMergeResult(combined_reservoir, canonical_sample_radiance);
390+
}
391+
}
307392

308-
return ReservoirMergeResult(combined_reservoir, canonical_radiance);
393+
fn balance_heuristic(x: f32, y: f32) -> f32 {
394+
let sum = x + y;
395+
if sum == 0.0 {
396+
return 0.0;
309397
}
398+
return x / sum;
310399
}

crates/bevy_solari/src/scene/binder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ struct GpuLightSource {
363363
impl GpuLightSource {
364364
fn new_emissive_mesh_light(instance_id: u32, triangle_count: u32) -> GpuLightSource {
365365
if triangle_count > u16::MAX as u32 {
366-
panic!("Too triangles in an emissive mesh, maximum is 65535.");
366+
panic!("Too many triangles ({triangle_count}) in an emissive mesh, maximum is 65535.");
367367
}
368368

369369
Self {

release-content/release-notes/bevy_solari.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
title: Initial raytraced lighting progress (bevy_solari)
33
authors: ["@JMS55"]
4-
pull_requests: [19058, 19620, 19790, 20020, 20113, 20213]
4+
pull_requests: [19058, 19620, 19790, 20020, 20113, 20213, 20259]
55
---
66

77
(TODO: Embed solari example screenshot here)

0 commit comments

Comments
 (0)