Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions crates/core_simd/examples/dot_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ pub fn dot_prod_simd_5(a: &[f32], b: &[f32]) -> f32 {
.reduce_sum()
}

// Using the dot() API - clearer and more expressive than manual multiply + reduce_sum.
//
// Elements are paired up to the length of the shorter slice. Every full group of four
// pairs goes through `f32x4::dot`; the leftover tail (fewer than four pairs) is folded
// in with scalar multiplies, so inputs whose length is not a multiple of 4 are no
// longer silently truncated.
pub fn dot_prod_with_api(a: &[f32], b: &[f32]) -> f32 {
    // Clamp to the common prefix so both slices split into the same number of
    // chunks and equally long remainders.
    let len = a.len().min(b.len());
    let (a_chunks, a_tail) = a[..len].as_chunks::<4>();
    let (b_chunks, b_tail) = b[..len].as_chunks::<4>();

    let simd_sum: f32 = a_chunks
        .iter()
        .zip(b_chunks)
        .map(|(&a, &b)| f32x4::from_array(a).dot(f32x4::from_array(b)))
        .sum();

    // Seeding the fold with `simd_sum` keeps the result unchanged when the tail is empty.
    a_tail
        .iter()
        .zip(b_tail)
        .fold(simd_sum, |acc, (x, y)| acc + x * y)
}

fn main() {
// Intentionally empty: this entry point exists only to make cargo happy.
// The dot-product functions above are exercised by the tests in this file.
}
Expand All @@ -169,6 +180,7 @@ mod tests {
assert_eq!(0.0, dot_prod_simd_3(&a, &b));
assert_eq!(0.0, dot_prod_simd_4(&a, &b));
assert_eq!(0.0, dot_prod_simd_5(&a, &b));
assert_eq!(0.0, dot_prod_with_api(&a, &b));

// We can handle vectors that are non-multiples of 4
assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
Expand Down
4 changes: 0 additions & 4 deletions crates/core_simd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@
any(target_arch = "powerpc", target_arch = "powerpc64"),
feature(stdarch_powerpc)
)]
#![cfg_attr(
all(target_arch = "x86_64", target_feature = "avx512f"),
feature(stdarch_x86_avx512)
)]
#![warn(missing_docs, clippy::missing_inline_in_public_items)] // basically all items, really
#![deny(
unsafe_op_in_unsafe_fn,
Expand Down
83 changes: 83 additions & 0 deletions crates/core_simd/src/simd/num/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,70 @@ pub trait SimdFloat: Copy + Sealed {
/// assert!(v.reduce_min().is_nan());
/// ```
fn reduce_min(self) -> Self::Scalar;

/// Computes the dot product of two vectors by multiplying corresponding elements
/// and summing the results.
///
/// Every element of the vector takes part in the sum. The implementations provided
/// in this module compute exactly `(self * rhs).reduce_sum()`; this method exists
/// to express the intent more clearly at the call site.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::prelude::*;
/// let a = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
/// let b = f32x4::from_array([5.0, 6.0, 7.0, 8.0]);
/// assert_eq!(a.dot(b), 70.0); // 1*5 + 2*6 + 3*7 + 4*8
/// ```
#[must_use = "method returns the dot product and does not mutate the original value"]
fn dot(self, rhs: Self) -> Self::Scalar;

/// Computes the dot product of the first 3 elements, ignoring the rest.
///
/// Computes `self[0]*rhs[0] + self[1]*rhs[1] + self[2]*rhs[2]`.
///
/// Every element from index 3 onward (for example the `w` component of a
/// 4-element vector) is ignored. This is useful for 3D vector operations where
/// vectors are stored in 4-element SIMD registers for alignment.
///
/// # Compile-time requirements
///
/// The vector must have at least 3 elements. The implementations provided in this
/// module check this with a `const` assertion, so calling `dot3` on a shorter
/// vector is rejected at compile time instead of panicking at run time.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::prelude::*;
/// let a = f32x4::from_array([1.0, 2.0, 3.0, 999.0]);
/// let b = f32x4::from_array([4.0, 5.0, 6.0, 888.0]);
/// assert_eq!(a.dot3(b), 32.0); // 1*4 + 2*5 + 3*6 = 32
/// // Note: w component (999.0, 888.0) is ignored
/// ```
#[must_use = "method returns the dot product and does not mutate the original value"]
fn dot3(self, rhs: Self) -> Self::Scalar;

/// Computes the dot product of the first 4 elements, ignoring the rest.
///
/// Computes `self[0]*rhs[0] + self[1]*rhs[1] + self[2]*rhs[2] + self[3]*rhs[3]`.
///
/// Any elements beyond index 3 are ignored. For `Simd<T, 4>` types, this covers
/// the same elements as [`dot`](Self::dot).
///
/// NOTE(review): `dot` sums through `reduce_sum`, while the provided `dot4`
/// implementation adds the four products left to right; confirm the two agree
/// bit-for-bit before documenting them as strictly equivalent.
///
/// # Compile-time requirements
///
/// The vector must have at least 4 elements. The implementations provided in this
/// module check this with a `const` assertion, so calling `dot4` on a shorter
/// vector is rejected at compile time instead of panicking at run time.
///
/// # Examples
///
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::prelude::*;
/// let a = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
/// let b = f32x4::from_array([5.0, 6.0, 7.0, 8.0]);
/// assert_eq!(a.dot4(b), 70.0); // 1*5 + 2*6 + 3*7 + 4*8
/// ```
#[must_use = "method returns the dot product and does not mutate the original value"]
fn dot4(self, rhs: Self) -> Self::Scalar;
}

macro_rules! impl_trait {
Expand Down Expand Up @@ -439,6 +503,25 @@ macro_rules! impl_trait {
// Safety: `self` is a float vector
unsafe { core::intrinsics::simd::simd_reduce_min(self) }
}

#[inline]
fn dot(self, rhs: Self) -> Self::Scalar {
    // Element-wise products first, then a horizontal add across every element.
    let products = self * rhs;
    products.reduce_sum()
}

#[inline]
fn dot3(self, rhs: Self) -> Self::Scalar {
    // Evaluated during const evaluation: a vector with fewer than 3 elements is
    // rejected at build time, so the indexing below cannot go out of bounds.
    const { assert!(N >= 3, "dot3 requires at least 3 elements") };
    let lanes = self * rhs;
    // Same left-to-right addition order as `lanes[0] + lanes[1] + lanes[2]`.
    let first_two = lanes[0] + lanes[1];
    first_two + lanes[2]
}

#[inline]
fn dot4(self, rhs: Self) -> Self::Scalar {
    // Evaluated during const evaluation: a vector with fewer than 4 elements is
    // rejected at build time, so the indexing below cannot go out of bounds.
    const { assert!(N >= 4, "dot4 requires at least 4 elements") };
    let lanes = self * rhs;
    // Same left-to-right addition order as the single four-term expression.
    let first_two = lanes[0] + lanes[1];
    let first_three = first_two + lanes[2];
    first_three + lanes[3]
}
}
)*
}
Expand Down
Loading