Skip to content

Commit 3e4952f

Browse files
committed
This makes ChannelCount NonZero<u16> and channels not zero asserts
I ran into a lot of bugs while adding tests that had to do with channel being set to zero somewhere. While this change makes the API slightly less easy to use it prevents very hard to debug crashes/underflows etc. Performance might drop in decoders, the current implementation makes the bound check every time `channels` is called which is once per span. This could be cached to alleviate that.
1 parent f2cac6e commit 3e4952f

40 files changed

+244
-204
lines changed

benches/pipeline.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::time::Duration;
22

33
use divan::Bencher;
4+
use rodio::ChannelCount;
45
use rodio::{source::UniformSourceIterator, Source};
56

67
mod shared;
@@ -31,7 +32,8 @@ fn long(bencher: Bencher) {
3132
.buffered()
3233
.reverb(Duration::from_secs_f32(0.05), 0.3)
3334
.skippable();
34-
let resampled = UniformSourceIterator::new(effects_applied, 2, 40_000);
35+
let resampled =
36+
UniformSourceIterator::new(effects_applied, ChannelCount::new(2).unwrap(), 40_000);
3537
resampled.for_each(divan::black_box_drop)
3638
})
3739
}

benches/shared.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rodio::{ChannelCount, Sample, SampleRate, Source};
66

77
pub struct TestSource {
88
samples: vec::IntoIter<Sample>,
9-
channels: u16,
9+
channels: ChannelCount,
1010
sample_rate: u32,
1111
total_duration: Duration,
1212
}

examples/mix_multiple_sources.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
use rodio::mixer;
22
use rodio::source::{SineWave, Source};
33
use std::error::Error;
4+
use std::num::NonZero;
45
use std::time::Duration;
56

67
fn main() -> Result<(), Box<dyn Error>> {
78
// Construct a dynamic controller and mixer, stream_handle, and sink.
8-
let (controller, mixer) = mixer::mixer(2, 44_100);
9+
let (controller, mixer) = mixer::mixer(NonZero::new(2).unwrap(), 44_100);
910
let stream_handle = rodio::OutputStreamBuilder::open_default_stream()?;
1011
let sink = rodio::Sink::connect_new(&stream_handle.mixer());
1112

src/buffer.rs

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
//!
77
//! ```
88
//! use rodio::buffer::SamplesBuffer;
9-
//! let _ = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
9+
//! use rodio::ChannelCount;
10+
//! let _ = SamplesBuffer::new(ChannelCount::new(1).unwrap(), 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1011
//! ```
1112
//!
1213
@@ -30,7 +31,6 @@ impl SamplesBuffer {
3031
///
3132
/// # Panic
3233
///
33-
/// - Panics if the number of channels is zero.
3434
/// - Panics if the samples rate is zero.
3535
/// - Panics if the length of the buffer is larger than approximately 16 billion elements.
3636
/// This is because the calculation of the duration would overflow.
@@ -39,13 +39,12 @@ impl SamplesBuffer {
3939
where
4040
D: Into<Vec<Sample>>,
4141
{
42-
assert!(channels >= 1);
4342
assert!(sample_rate >= 1);
4443

4544
let data = data.into();
4645
let duration_ns = 1_000_000_000u64.checked_mul(data.len() as u64).unwrap()
4746
/ sample_rate as u64
48-
/ channels as u64;
47+
/ channels.get() as u64;
4948
let duration = Duration::new(
5049
duration_ns / 1_000_000_000,
5150
(duration_ns % 1_000_000_000) as u32,
@@ -89,14 +88,14 @@ impl Source for SamplesBuffer {
8988
// and due to the constant sample_rate we can jump to the right
9089
// sample directly.
9190

92-
let curr_channel = self.pos % self.channels() as usize;
93-
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels() as f32;
91+
let curr_channel = self.pos % self.channels().get() as usize;
92+
let new_pos = pos.as_secs_f32() * self.sample_rate() as f32 * self.channels().get() as f32;
9493
// saturate pos at the end of the source
9594
let new_pos = new_pos as usize;
9695
let new_pos = new_pos.min(self.data.len());
9796

9897
// make sure the next sample is for the right channel
99-
let new_pos = new_pos.next_multiple_of(self.channels() as usize);
98+
let new_pos = new_pos.next_multiple_of(self.channels().get() as usize);
10099
let new_pos = new_pos - curr_channel;
101100

102101
self.pos = new_pos;
@@ -123,36 +122,31 @@ impl Iterator for SamplesBuffer {
123122
#[cfg(test)]
124123
mod tests {
125124
use crate::buffer::SamplesBuffer;
125+
use crate::math::ch;
126126
use crate::source::Source;
127127

128128
#[test]
129129
fn basic() {
130-
let _ = SamplesBuffer::new(1, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
131-
}
132-
133-
#[test]
134-
#[should_panic]
135-
fn panic_if_zero_channels() {
136-
SamplesBuffer::new(0, 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
130+
let _ = SamplesBuffer::new(ch!(1), 44100, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
137131
}
138132

139133
#[test]
140134
#[should_panic]
141135
fn panic_if_zero_sample_rate() {
142-
SamplesBuffer::new(1, 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
136+
SamplesBuffer::new(ch!(1), 0, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
143137
}
144138

145139
#[test]
146140
fn duration_basic() {
147-
let buf = SamplesBuffer::new(2, 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
141+
let buf = SamplesBuffer::new(ch!(2), 2, vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
148142
let dur = buf.total_duration().unwrap();
149143
assert_eq!(dur.as_secs(), 1);
150144
assert_eq!(dur.subsec_nanos(), 500_000_000);
151145
}
152146

153147
#[test]
154148
fn iteration() {
155-
let mut buf = SamplesBuffer::new(1, 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
149+
let mut buf = SamplesBuffer::new(ch!(1), 44100, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
156150
assert_eq!(buf.next(), Some(1.0));
157151
assert_eq!(buf.next(), Some(2.0));
158152
assert_eq!(buf.next(), Some(3.0));
@@ -172,7 +166,7 @@ mod tests {
172166
#[test]
173167
fn channel_order_stays_correct() {
174168
const SAMPLE_RATE: SampleRate = 100;
175-
const CHANNELS: ChannelCount = 2;
169+
const CHANNELS: ChannelCount = ch!(2);
176170
let mut buf = SamplesBuffer::new(
177171
CHANNELS,
178172
SAMPLE_RATE,
@@ -182,7 +176,10 @@ mod tests {
182176
.collect::<Vec<_>>(),
183177
);
184178
buf.try_seek(Duration::from_secs(5)).unwrap();
185-
assert_eq!(buf.next(), Some(5.0 * SAMPLE_RATE as f32 * CHANNELS as f32));
179+
assert_eq!(
180+
buf.next(),
181+
Some(5.0 * SAMPLE_RATE as f32 * CHANNELS.get() as f32)
182+
);
186183

187184
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 1));
188185
assert!(buf.next().is_some_and(|s| s.trunc() as i32 % 2 == 0));

src/common.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use std::num::NonZero;
2+
13
/// Stream sample rate (a frame rate or samples per second per channel).
24
pub type SampleRate = u32;
35

4-
/// Number of channels in a stream.
5-
pub type ChannelCount = u16;
6+
/// Number of channels in a stream. Can never be Zero
7+
pub type ChannelCount = NonZero<u16>;
68

79
/// Represents value of a single sample.
810
/// Silence corresponds to the value `0.0`. The expected amplitude range is -1.0...1.0.

src/conversions/channels.rs

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ where
1111
from: ChannelCount,
1212
to: ChannelCount,
1313
sample_repeat: Option<Sample>,
14-
next_output_sample_pos: ChannelCount,
14+
next_output_sample_pos: u16,
1515
}
1616

1717
impl<I> ChannelCountConverter<I>
@@ -26,9 +26,6 @@ where
2626
///
2727
#[inline]
2828
pub fn new(input: I, from: ChannelCount, to: ChannelCount) -> ChannelCountConverter<I> {
29-
assert!(from >= 1);
30-
assert!(to >= 1);
31-
3229
ChannelCountConverter {
3330
input,
3431
from,
@@ -65,7 +62,7 @@ where
6562
self.sample_repeat = value;
6663
value
6764
}
68-
x if x < self.from => self.input.next(),
65+
x if x < self.from.get() => self.input.next(),
6966
1 => self.sample_repeat,
7067
_ => Some(0.0),
7168
};
@@ -74,11 +71,11 @@ where
7471
self.next_output_sample_pos += 1;
7572
}
7673

77-
if self.next_output_sample_pos == self.to {
74+
if self.next_output_sample_pos == self.to.get() {
7875
self.next_output_sample_pos = 0;
7976

8077
if self.from > self.to {
81-
for _ in self.to..self.from {
78+
for _ in self.to.get()..self.from.get() {
8279
self.input.next(); // discarding extra input
8380
}
8481
}
@@ -91,13 +88,13 @@ where
9188
fn size_hint(&self) -> (usize, Option<usize>) {
9289
let (min, max) = self.input.size_hint();
9390

94-
let consumed = std::cmp::min(self.from, self.next_output_sample_pos) as usize;
91+
let consumed = std::cmp::min(self.from.get(), self.next_output_sample_pos) as usize;
9592

96-
let min = ((min + consumed) / self.from as usize * self.to as usize)
93+
let min = ((min + consumed) / self.from.get() as usize * self.to.get() as usize)
9794
.saturating_sub(self.next_output_sample_pos as usize);
9895

9996
let max = max.map(|max| {
100-
((max + consumed) / self.from as usize * self.to as usize)
97+
((max + consumed) / self.from.get() as usize * self.to.get() as usize)
10198
.saturating_sub(self.next_output_sample_pos as usize)
10299
});
103100

@@ -111,31 +108,37 @@ impl<I> ExactSizeIterator for ChannelCountConverter<I> where I: ExactSizeIterato
111108
mod test {
112109
use super::ChannelCountConverter;
113110
use crate::common::ChannelCount;
111+
use crate::math::ch;
114112
use crate::Sample;
115113

116114
#[test]
117115
fn remove_channels() {
118116
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
119-
let output = ChannelCountConverter::new(input.into_iter(), 3, 2).collect::<Vec<_>>();
117+
let output =
118+
ChannelCountConverter::new(input.into_iter(), ch!(3), ch!(2)).collect::<Vec<_>>();
120119
assert_eq!(output, [1.0, 2.0, 4.0, 5.0]);
121120

122121
let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
123-
let output = ChannelCountConverter::new(input.into_iter(), 4, 1).collect::<Vec<_>>();
122+
let output =
123+
ChannelCountConverter::new(input.into_iter(), ch!(4), ch!(1)).collect::<Vec<_>>();
124124
assert_eq!(output, [1.0, 5.0]);
125125
}
126126

127127
#[test]
128128
fn add_channels() {
129129
let input = vec![1.0, 2.0, 3.0, 4.0];
130-
let output = ChannelCountConverter::new(input.into_iter(), 1, 2).collect::<Vec<_>>();
130+
let output =
131+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(2)).collect::<Vec<_>>();
131132
assert_eq!(output, [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]);
132133

133134
let input = vec![1.0, 2.0];
134-
let output = ChannelCountConverter::new(input.into_iter(), 1, 4).collect::<Vec<_>>();
135+
let output =
136+
ChannelCountConverter::new(input.into_iter(), ch!(1), ch!(4)).collect::<Vec<_>>();
135137
assert_eq!(output, [1.0, 1.0, 0.0, 0.0, 2.0, 2.0, 0.0, 0.0]);
136138

137139
let input = vec![1.0, 2.0, 3.0, 4.0];
138-
let output = ChannelCountConverter::new(input.into_iter(), 2, 4).collect::<Vec<_>>();
140+
let output =
141+
ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(4)).collect::<Vec<_>>();
139142
assert_eq!(output, [1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
140143
}
141144

@@ -152,24 +155,24 @@ mod test {
152155
assert_eq!(converter.size_hint(), (0, Some(0)));
153156
}
154157

155-
test(&[1.0, 2.0, 3.0], 1, 2);
156-
test(&[1.0, 2.0, 3.0, 4.0], 2, 4);
157-
test(&[1.0, 2.0, 3.0, 4.0], 4, 2);
158-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 8);
159-
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 4, 1);
158+
test(&[1.0, 2.0, 3.0], ch!(1), ch!(2));
159+
test(&[1.0, 2.0, 3.0, 4.0], ch!(2), ch!(4));
160+
test(&[1.0, 2.0, 3.0, 4.0], ch!(4), ch!(2));
161+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], ch!(3), ch!(8));
162+
test(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], ch!(4), ch!(1));
160163
}
161164

162165
#[test]
163166
fn len_more() {
164167
let input = vec![1.0, 2.0, 3.0, 4.0];
165-
let output = ChannelCountConverter::new(input.into_iter(), 2, 3);
168+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(3));
166169
assert_eq!(output.len(), 6);
167170
}
168171

169172
#[test]
170173
fn len_less() {
171174
let input = vec![1.0, 2.0, 3.0, 4.0];
172-
let output = ChannelCountConverter::new(input.into_iter(), 2, 1);
175+
let output = ChannelCountConverter::new(input.into_iter(), ch!(2), ch!(1));
173176
assert_eq!(output.len(), 2);
174177
}
175178
}

0 commit comments

Comments
 (0)