Skip to content

Commit 5cddaf8

Browse files
committed
This makes ChannelCount NonZero<u16> and channels not zero asserts
I ran into a lot of bugs while adding tests that had to do with channel being set to zero somewhere. While this change makes the API slightly less easy to use it prevents very hard to debug crashes/underflows etc. Performance might drop in decoders, the current implementation makes the bound check every time `channels` is called which is once per span. This could be cached to alleviate that.
1 parent f3e1c9b commit 5cddaf8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+298
-235
lines changed

benches/pipeline.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::time::Duration;
22

33
use divan::Bencher;
4+
use rodio::ChannelCount;
45
use rodio::{source::UniformSourceIterator, Source};
56

67
mod shared;
@@ -30,7 +31,8 @@ fn long(bencher: Bencher) {
3031
.buffered()
3132
.reverb(Duration::from_secs_f32(0.05), 0.3)
3233
.skippable();
33-
let resampled = UniformSourceIterator::new(effects_applied, 2, 40_000);
34+
let resampled =
35+
UniformSourceIterator::new(effects_applied, ChannelCount::new(2).unwrap(), 40_000);
3436
resampled.for_each(divan::black_box_drop)
3537
})
3638
}

benches/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rodio::{ChannelCount, Sample, SampleRate, Source};
66

77
pub struct TestSource {
88
samples: vec::IntoIter<Sample>,
9-
channels: u16,
9+
channels: ChannelCount,
1010
sample_rate: u32,
1111
total_duration: Duration,
1212
}

examples/mix_multiple_sources.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rodio::mixer;
22
use rodio::source::{SineWave, Source};
33
use std::error::Error;
4+
use std::num::NonZero;
45
use std::time::Duration;
56

67
fn main() -> Result<(), Box<dyn Error>> {
78
// Construct a dynamic controller and mixer, stream_handle, and sink.
8-
let (controller, mixer) = mixer::mixer(2, 44_100);
9+
let (controller, mixer) = mixer::mixer(NonZero::new(2).unwrap(), 44_100);
910
let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
1011
let sink = rodio::Sink::connect_new(&stream_handle.mixer());
1112

src/buffer.rs

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ impl SamplesBuffer {
3030
///
3131
/// # Panic
3232
///
33-
/// - Panics if the number of channels is zero.
3433
/// - Panics if the samples rate is zero.
3534
/// - Panics if the length of the buffer is larger than approximately 16 billion elements.
3635
/// This is because the calculation of the duration would overflow.
@@ -39,13 +38,12 @@ impl SamplesBuffer {
3938
where
4039
D: Into<Vec<Sample>>,
4140
{
42-
assert!(channels >= 1);
4341
assert!(sample_rate >= 1);
4442

4543
let data = data.into();
4644
let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap()
4745
/ sample_rate as u64
48-
/ channels as u64;
46+
/ channels.get() as u64;
4947
let duration = Duration::new(
5048
duration_ns / 1_000_000_000,
5149
(duration_ns % 1_000_000_000) as u32,
@@ -89,14 +87,14 @@ impl Source for SamplesBuffer {
8987
// and due to the constant sample_rate we can jump to the right
9088
// sample directly.
9189

92-
let curr_channel = self.pos % self.channels() as usize;
93-
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels() as f32;
90+
let curr_channel = self.pos % self.channels().get() as usize;
91+
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels().get() as f32;
9492
// saturate pos at the end of the source
9593
let new_pos = new_pos as usize;
9694
let new_pos = new_pos.min(self.data.len());
9795

9896
// make sure the next sample is for the right channel
99-
let new_pos = new_pos.next_multiple_of(self.channels() as usize);
97+
let new_pos = new_pos.next_multiple_of(self.channels().get() as usize);
10098
let new_pos = new_pos - curr_channel;
10199

102100
self.pos = new_pos;
@@ -123,36 +121,31 @@ impl Iterator for SamplesBuffer {
123121
#[cfg(test)]
124122
mod tests {
125123
use crate::buffer::SamplesBuffer;
124+
use crate::math::ch;
126125
use crate::source::Source;
127126

128127
#[test]
129128
fn basic() {
130-
let _ = SamplesBuffer::new(1, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
131-
}
132-
133-
#[test]
134-
#[should_panic]
135-
fn panic_if_zero_channels() {
136-
SamplesBuffer::new(0, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
129+
let _ = SamplesBuffer::new(ch!(1), 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
137130
}
138131

139132
#[test]
140133
#[should_panic]
141134
fn panic_if_zero_sample_rate() {
142-
SamplesBuffer::new(1, 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
135+
SamplesBuffer::new(ch!(1), 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
143136
}
144137

145138
#[test]
146139
fn duration_basic() {
147-
let buf = SamplesBuffer::new(2, 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
140+
let buf = SamplesBuffer::new(ch!(2), 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
148141
let dur = buf.total_duration().unwrap();
149142
assert_eq!(dur.as_secs(), 1);
150143
assert_eq!(dur.subsec_nanos(), 500_000_000);
151144
}
152145

153146
#[test]
154147
fn iteration() {
155-
let mut buf = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
148+
let mut buf = SamplesBuffer::new(ch!(1), 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
156149
assert_eq!(buf.next(), Some(1.0));
157150
assert_eq!(buf.next(), Some(2.0));
158151
assert_eq!(buf.next(), Some(3.0));
@@ -172,7 +165,7 @@ mod tests {
172165
#[test]
173166
fn channel_order_stays_correct() {
174167
const SAMPLE_RATE: SampleRate = 100;
175-
const CHANNELS: ChannelCount = 2;
168+
const CHANNELS: ChannelCount = ch!(2);
176169
let mut buf = SamplesBuffer::new(
177170
CHANNELS,
178171
SAMPLE_RATE,
@@ -182,7 +175,10 @@ mod tests {
182175
.collect::<Vec<_>>(),
183176
);
184177
buf.try_seek(Duration::from_secs(5)).unwrap();
185-
assert_eq!(buf.next(), Some(5.0 * SAMPLE_RATE as f32 * CHANNELS as f32));
178+
assert_eq!(
179+
buf.next(),
180+
Some(5.0 * SAMPLE_RATE as f32 * CHANNELS.get() as f32)
181+
);
186182

187183
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 1));
188184
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 0));

src/common.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use std::num::NonZero;
2+
13
/// Stream sample rate (a frame rate or samples per second per channel).
24
pub type SampleRate = u32;
35

4-
/// Number of channels in a stream.
5-
pub type ChannelCount = u16;
6+
/// Number of channels in a stream. Can never be Zero
7+
pub type ChannelCount = NonZero<u16>;
68

79
/// Represents value of a single sample.
810
/// Silence corresponds to the value `0.0`. The expected amplitude range is -1.0...1.0.

src/conversions/channels.rs

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ where
1111
from: ChannelCount,
1212
to: ChannelCount,
1313
sample_repeat: Option<Sample>,
14-
next_output_sample_pos: ChannelCount,
14+
next_output_sample_pos: u16,
1515
}
1616

1717
impl<I> ChannelCountConverter<I>
@@ -26,9 +26,6 @@ where
2626
///
2727
#[inline]
2828
pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29-
assert!(from >= 1);
30-
assert!(to >= 1);
31-
3229
ChannelCountConverter {
3330
input,
3431
from,
@@ -65,7 +62,7 @@ where
6562
self.sample_repeat = value;
6663
value
6764
}
68-
x if x < self.from => self.input.next(),
65+
x if x < self.from.get() => self.input.next(),
6966
1 => self.sample_repeat,
7067
_ => Some(0.0),
7168
};
@@ -74,11 +71,11 @@ where
7471
self.next_output_sample_pos += 1;
7572
}
7673

77-
if self.next_output_sample_pos == self.to {
74+
if self.next_output_sample_pos == self.to.get() {
7875
self.next_output_sample_pos = 0;
7976

8077
if self.from > self.to {
81-
for _ in self.to..self.from {
78+
for _ in self.to.get()..self.from.get() {
8279
self.input.next(); // discarding extra input
8380
}
8481
}
@@ -91,9 +88,9 @@ where
9188
fn size_hint(&self) -> (usize, Option<usize>) {
9289
let (min, max) = self.input.size_hint();
9390

94-
let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
91+
let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;
9592
let calculate = |size| {
96-
(size + consumed) / self.from as usize * self.to as usize
93+
(size + consumed) / self.from.get() as usize * self.to.get() as usize
9794
- self.next_output_sample_pos as usize
9895
};
9996

@@ -110,38 +107,45 @@ impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterato
110107
mod test {
111108
use super::ChannelCountConverter;
112109
use crate::common::ChannelCount;
110+
use crate::math::ch;
113111
use crate::Sample;
114112

115113
#[test]
116114
fn remove_channels() {
117115
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
118-
let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::<Vec<_>>();
116+
let output =
117+
ChannelCountConverter::new(input.into_iter(), ch!(3), ch!(2)).collect::<Vec<_>>();
119118
assert_eq!(output, [1.0, 2.0, 4.0, 5.0]);
120119

121120
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
122-
let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::<Vec<_>>();
121+
let output =
122+
ChannelCountConverter::new(input.into_iter(), ch!(4), ch!(1)).collect::<Vec<_>>();
123123
assert_eq!(output, [1.0, 5.0]);
124124
}
125125

126126
#[test]
127127
fn add_channels() {
128128
let input = vec![1.0, 2.0, 3.0, 4.0];
129-
let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::<Vec<_>>();
129+
let output =
130+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(2)).collect::<Vec<_>>();
130131
assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
131132

132133
let input = vec![1.0, 2.0];
133-
let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::<Vec<_>>();
134+
let output =
135+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(4)).collect::<Vec<_>>();
134136
assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]);
135137

136138
let input = vec![1.0, 2.0, 3.0, 4.0];
137-
let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::<Vec<_>>();
139+
let output =
140+
ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(4)).collect::<Vec<_>>();
138141
assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
139142
}
140143

141144
#[test]
142145
fn size_hint() {
143146
fn test(input: &[Sample], from: ChannelCount, to: ChannelCount) {
144-
let mut converter = ChannelCountConverter::new(input.iter().copied(), from, to);
147+
let mut converter =
148+
ChannelCountConverter::new(input.iter().copied(), from, to);
145149
let count = converter.clone().count();
146150
for left_in_iter in (0..=count).rev() {
147151
println!("left_in_iter = {left_in_iter}");
@@ -151,24 +155,24 @@ mod test {
151155
assert_eq!(converter.size_hint(), (0, Some(0)));
152156
}
153157

154-
test(&[1.0, 2.0, 3.0], 1, 2);
155-
test(&[1.0, 2.0, 3.0, 4.0], 2, 4);
156-
test(&[1.0, 2.0, 3.0, 4.0], 4, 2);
157-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
158-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
158+
test(&[1.0, 2.0, 3.0], ch!(1), ch!(2));
159+
test(&[1.0, 2.0, 3.0, 4.0], ch!(2), ch!(4));
160+
test(&[1.0, 2.0, 3.0, 4.0], ch!(4), ch!(2));
161+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ch!(3), ch!(8));
162+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], ch!(4), ch!(1));
159163
}
160164

161165
#[test]
162166
fn len_more() {
163167
let input = vec![1.0, 2.0, 3.0, 4.0];
164-
let output = ChannelCountConverter::new(input.into_iter(), 2, 3);
168+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(3));
165169
assert_eq!(output.len(), 6);
166170
}
167171

168172
#[test]
169173
fn len_less() {
170174
let input = vec![1.0, 2.0, 3.0, 4.0];
171-
let output = ChannelCountConverter::new(input.into_iter(), 2, 1);
175+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(1));
172176
assert_eq!(output.len(), 2);
173177
}
174178
}

0 commit comments

Comments
 (0)