From d3d14fcda38f52c428e822f41cf989706aa3a6df Mon Sep 17 00:00:00 2001 From: Vasily Zorin Date: Tue, 22 Apr 2025 01:52:36 +0700 Subject: [PATCH 1/4] linfa-preprocessing::countgrams: Multi-threading made possible --- .../src/countgrams/hyperparams.rs | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs index e051e1722..4f0401d8d 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs @@ -1,8 +1,8 @@ use crate::PreprocessingError; use linfa::ParamGuard; use regex::Regex; -use std::cell::{Ref, RefCell}; use std::collections::HashSet; +use std::sync::OnceLock; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; @@ -68,7 +68,7 @@ impl SerdeRegex { pub struct CountVectorizerValidParams { convert_to_lowercase: bool, split_regex_expr: String, - split_regex: RefCell<Option<SerdeRegex>>, + split_regex: OnceLock<SerdeRegex>, n_gram_range: (usize, usize), normalize: bool, document_frequency: (f32, f32), @@ -92,8 +92,11 @@ impl CountVectorizerValidParams { self.convert_to_lowercase } - pub fn split_regex(&self) -> Ref<'_, Regex> { - Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re()) + pub fn split_regex(&self) -> &Regex { + self.split_regex + .get() + .expect("Regex not initialized") + .as_re() } pub fn n_gram_range(&self) -> (usize, usize) { @@ -126,7 +129,7 @@ impl std::default::Default for CountVectorizerParams { Self(CountVectorizerValidParams { convert_to_lowercase: true, split_regex_expr: r"\b\w\w+\b".to_string(), - split_regex: RefCell::new(None), + split_regex: OnceLock::new(), n_gram_range: (1, 1), normalize: true, document_frequency: (0., 1.), @@ -224,7 +227,8 @@ impl ParamGuard for CountVectorizerParams { min_freq, max_freq, )) } else { - *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?); + let 
regex = SerdeRegex::new(&self.0.split_regex_expr)?; + let _ = self.0.split_regex.set(regex); Ok(&self.0) } From 776e31195f38582e355ab04530c8e99f32a2a69d Mon Sep 17 00:00:00 2001 From: Vasily Zorin Date: Wed, 11 Jun 2025 03:15:39 +0700 Subject: [PATCH 2/4] CountVectorizerValidParams: serde for split_regex --- .../linfa-preprocessing/src/countgrams/hyperparams.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs index 4f0401d8d..10c66eda6 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs @@ -35,8 +35,7 @@ impl SerdeRegex { } fn as_re(&self) -> &Regex { - use std::ops::Deref; - &self.0.deref() + &self.0 } } @@ -68,6 +67,7 @@ impl SerdeRegex { pub struct CountVectorizerValidParams { convert_to_lowercase: bool, split_regex_expr: String, + #[cfg_attr(feature = "serde", serde(skip, default = "OnceLock::new"))] split_regex: OnceLock<SerdeRegex>, n_gram_range: (usize, usize), normalize: bool, @@ -95,7 +95,7 @@ impl CountVectorizerValidParams { pub fn split_regex(&self) -> &Regex { self.split_regex .get() - .expect("Regex not initialized") + .expect("Regex not initialized; call `check_ref()` first") .as_re() } @@ -124,7 +124,7 @@ impl CountVectorizerValidParams { #[derive(Clone, Debug)] pub struct CountVectorizerParams(CountVectorizerValidParams); -impl std::default::Default for CountVectorizerParams { +impl Default for CountVectorizerParams { fn default() -> Self { Self(CountVectorizerValidParams { convert_to_lowercase: true, From 0a23fc7f0d21cbd564ebe5c34b5268356dd012a0 Mon Sep 17 00:00:00 2001 From: Vasily Zorin Date: Wed, 11 Jun 2025 05:07:58 +0700 Subject: [PATCH 3/4] Code quality --- algorithms/linfa-preprocessing/src/countgrams/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/algorithms/linfa-preprocessing/src/countgrams/mod.rs b/algorithms/linfa-preprocessing/src/countgrams/mod.rs index 57f8df27c..ad78eb48a 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/mod.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/mod.rs @@ -45,7 +45,7 @@ impl CountVectorizerValidParams { // word, (integer mapping for word, document frequency for word) let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new(); for string in x.iter().map(|s| transform_string(s.to_string(), self)) { - self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary); + self.read_document_into_vocabulary(string, self.split_regex(), &mut vocabulary); } let mut vocabulary = self.filter_vocabulary(vocabulary, x.len()); @@ -92,7 +92,7 @@ impl CountVectorizerValidParams { } // safe unwrap now that error has been handled let document = transform_string(document.unwrap(), self); - self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary); + self.read_document_into_vocabulary(document, self.split_regex(), &mut vocabulary); } let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count); @@ -340,7 +340,7 @@ impl CountVectorizer { sprs_vectorized.reserve_outer_dim_exact(x.len()); let regex = self.properties.split_regex(); for string in x.into_iter().map(|s| s.to_string()) { - let row = self.analyze_document(string, &regex, document_frequencies.view_mut()); + let row = self.analyze_document(string, regex, document_frequencies.view_mut()); sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view()); } (sprs_vectorized, document_frequencies) @@ -364,7 +364,7 @@ impl CountVectorizer { file.read_to_end(&mut document_bytes).unwrap(); let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap(); sprs_vectorized = sprs_vectorized.append_outer_csvec( - self.analyze_document(document, &regex, document_frequencies.view_mut()) + self.analyze_document(document, regex, document_frequencies.view_mut()) .view(), ); } From 
2351e3319813b7a9e4bd376a790a9fc194c38bbc Mon Sep 17 00:00:00 2001 From: Vasily Zorin Date: Wed, 11 Jun 2025 11:16:20 +0700 Subject: [PATCH 4/4] Code quality --- algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs index 10c66eda6..665f6b2e7 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs @@ -227,9 +227,7 @@ impl ParamGuard for CountVectorizerParams { min_freq, max_freq, )) } else { - let regex = SerdeRegex::new(&self.0.split_regex_expr)?; - let _ = self.0.split_regex.set(regex); - + let _ = self.0.split_regex.set(SerdeRegex::new(&self.0.split_regex_expr)?); Ok(&self.0) } }