diff --git a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs index e051e1722..665f6b2e7 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs @@ -1,8 +1,8 @@ use crate::PreprocessingError; use linfa::ParamGuard; use regex::Regex; -use std::cell::{Ref, RefCell}; use std::collections::HashSet; +use std::sync::OnceLock; #[cfg(feature = "serde")] use serde_crate::{Deserialize, Serialize}; @@ -35,8 +35,7 @@ impl SerdeRegex { } fn as_re(&self) -> &Regex { - use std::ops::Deref; - &self.0.deref() + &self.0 } } @@ -68,7 +67,8 @@ impl SerdeRegex { pub struct CountVectorizerValidParams { convert_to_lowercase: bool, split_regex_expr: String, - split_regex: RefCell>, + #[cfg_attr(feature = "serde", serde(skip, default = "OnceLock::new"))] + split_regex: OnceLock, n_gram_range: (usize, usize), normalize: bool, document_frequency: (f32, f32), @@ -92,8 +92,11 @@ impl CountVectorizerValidParams { self.convert_to_lowercase } - pub fn split_regex(&self) -> Ref<'_, Regex> { - Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re()) + pub fn split_regex(&self) -> &Regex { + self.split_regex + .get() + .expect("Regex not initialized; call `check_ref()` first") + .as_re() } pub fn n_gram_range(&self) -> (usize, usize) { @@ -121,12 +124,12 @@ impl CountVectorizerValidParams { #[derive(Clone, Debug)] pub struct CountVectorizerParams(CountVectorizerValidParams); -impl std::default::Default for CountVectorizerParams { +impl Default for CountVectorizerParams { fn default() -> Self { Self(CountVectorizerValidParams { convert_to_lowercase: true, split_regex_expr: r"\b\w\w+\b".to_string(), - split_regex: RefCell::new(None), + split_regex: OnceLock::new(), n_gram_range: (1, 1), normalize: true, document_frequency: (0., 1.), @@ -224,8 +227,7 @@ impl ParamGuard for CountVectorizerParams { min_freq, max_freq, )) } else { - *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?); - + let _ = self.0.split_regex.set(SerdeRegex::new(&self.0.split_regex_expr)?); Ok(&self.0) } } diff --git a/algorithms/linfa-preprocessing/src/countgrams/mod.rs b/algorithms/linfa-preprocessing/src/countgrams/mod.rs index 57f8df27c..ad78eb48a 100644 --- a/algorithms/linfa-preprocessing/src/countgrams/mod.rs +++ b/algorithms/linfa-preprocessing/src/countgrams/mod.rs @@ -45,7 +45,7 @@ impl CountVectorizerValidParams { // word, (integer mapping for word, document frequency for word) let mut vocabulary: HashMap = HashMap::new(); for string in x.iter().map(|s| transform_string(s.to_string(), self)) { - self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary); + self.read_document_into_vocabulary(string, self.split_regex(), &mut vocabulary); } let mut vocabulary = self.filter_vocabulary(vocabulary, x.len()); @@ -92,7 +92,7 @@ impl CountVectorizerValidParams { } // safe unwrap now that error has been handled let document = transform_string(document.unwrap(), self); - self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary); + self.read_document_into_vocabulary(document, self.split_regex(), &mut vocabulary); } let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count); @@ -340,7 +340,7 @@ impl CountVectorizer { sprs_vectorized.reserve_outer_dim_exact(x.len()); let regex = self.properties.split_regex(); for string in x.into_iter().map(|s| s.to_string()) { - let row = self.analyze_document(string, ®ex, document_frequencies.view_mut()); + let row = self.analyze_document(string, regex, document_frequencies.view_mut()); sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view()); } (sprs_vectorized, document_frequencies) @@ -364,7 +364,7 @@ impl CountVectorizer { file.read_to_end(&mut document_bytes).unwrap(); let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap(); sprs_vectorized = sprs_vectorized.append_outer_csvec( - self.analyze_document(document, ®ex, document_frequencies.view_mut()) + self.analyze_document(document, regex, document_frequencies.view_mut()) .view(), ); }