algorithms/linfa-preprocessing/src/countgrams/hyperparams.rs (+12 −10)
```diff
@@ -1,8 +1,8 @@
 use crate::PreprocessingError;
 use linfa::ParamGuard;
 use regex::Regex;
-use std::cell::{Ref, RefCell};
 use std::collections::HashSet;
+use std::sync::OnceLock;
 
 #[cfg(feature = "serde")]
 use serde_crate::{Deserialize, Serialize};
@@ -35,8 +35,7 @@ impl SerdeRegex {
     }
 
     fn as_re(&self) -> &Regex {
-        use std::ops::Deref;
-        &self.0.deref()
+        &self.0
     }
 }
 
@@ -68,7 +67,8 @@ impl SerdeRegex {
 pub struct CountVectorizerValidParams {
     convert_to_lowercase: bool,
     split_regex_expr: String,
-    split_regex: RefCell<Option<SerdeRegex>>,
+    #[cfg_attr(feature = "serde", serde(skip, default = "OnceLock::new"))]
+    split_regex: OnceLock<SerdeRegex>,
     n_gram_range: (usize, usize),
     normalize: bool,
     document_frequency: (f32, f32),
@@ -92,8 +92,11 @@ impl CountVectorizerValidParams {
         self.convert_to_lowercase
     }
 
-    pub fn split_regex(&self) -> Ref<'_, Regex> {
-        Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
+    pub fn split_regex(&self) -> &Regex {
+        self.split_regex
+            .get()
+            .expect("Regex not initialized; call `check_ref()` first")
+            .as_re()
     }
 
     pub fn n_gram_range(&self) -> (usize, usize) {
@@ -121,12 +124,12 @@ impl CountVectorizerValidParams {
 #[derive(Clone, Debug)]
 pub struct CountVectorizerParams(CountVectorizerValidParams);
 
-impl std::default::Default for CountVectorizerParams {
+impl Default for CountVectorizerParams {
     fn default() -> Self {
         Self(CountVectorizerValidParams {
             convert_to_lowercase: true,
             split_regex_expr: r"\b\w\w+\b".to_string(),
-            split_regex: RefCell::new(None),
+            split_regex: OnceLock::new(),
             n_gram_range: (1, 1),
             normalize: true,
             document_frequency: (0., 1.),
@@ -224,8 +227,7 @@ impl ParamGuard for CountVectorizerParams {
                 min_freq, max_freq,
             ))
         } else {
-            *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);
-
+            let _ = self.0.split_regex.set(SerdeRegex::new(&self.0.split_regex_expr)?);
             Ok(&self.0)
         }
     }
```
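For context, the heart of this change is replacing `RefCell<Option<SerdeRegex>>` with `std::sync::OnceLock<SerdeRegex>`: the regex is still compiled lazily during parameter validation, but the accessor can now hand out a plain `&Regex` instead of a `Ref<'_, Regex>` guard, and the struct stays `Sync` (a `RefCell` field makes the containing type `!Sync`). The `serde(skip, default = "OnceLock::new")` attribute is needed because the cached regex is a derived value that should be rebuilt after deserialization rather than serialized. Below is a minimal standalone sketch of the pattern; the `Tokenizer` type and its `check`/`regex` methods are hypothetical stand-ins for illustration, not linfa API:

```rust
use regex::Regex;
use std::sync::OnceLock;

/// Lazily-compiled regex holder. `OnceLock` gives set-once interior
/// mutability through `&self`, so the accessor can return `&Regex`
/// directly instead of a `Ref<'_, Regex>` borrow guard.
struct Tokenizer {
    pattern: String,
    compiled: OnceLock<Regex>,
}

impl Tokenizer {
    fn new(pattern: &str) -> Self {
        Self {
            pattern: pattern.to_string(),
            compiled: OnceLock::new(),
        }
    }

    /// Validation step, mirroring how `check_ref()` populates
    /// `split_regex` in the diff. `set` returns `Err` if a value is
    /// already stored; that value is equivalent, so it is ignored.
    fn check(&self) -> Result<(), regex::Error> {
        let _ = self.compiled.set(Regex::new(&self.pattern)?);
        Ok(())
    }

    /// Panics if `check` was never called, like the `expect` in the diff.
    fn regex(&self) -> &Regex {
        self.compiled
            .get()
            .expect("regex not initialized; call `check()` first")
    }
}

fn main() -> Result<(), regex::Error> {
    let tok = Tokenizer::new(r"\b\w\w+\b");
    tok.check()?;
    let words: Vec<&str> = tok.regex().find_iter("one a two").map(|m| m.as_str()).collect();
    assert_eq!(words, ["one", "two"]);
    Ok(())
}
```

One subtlety of the diff itself: the pattern is still re-compiled on every `check_ref()` call, but only the first successful result is stored, since `OnceLock::set` rejects later values.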
algorithms/linfa-preprocessing/src/countgrams/mod.rs (+4 −4)
```diff
@@ -45,7 +45,7 @@ impl CountVectorizerValidParams {
         // word, (integer mapping for word, document frequency for word)
         let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
         for string in x.iter().map(|s| transform_string(s.to_string(), self)) {
-            self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary);
+            self.read_document_into_vocabulary(string, self.split_regex(), &mut vocabulary);
         }
 
         let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());
@@ -92,7 +92,7 @@ impl CountVectorizerValidParams {
             }
             // safe unwrap now that error has been handled
             let document = transform_string(document.unwrap(), self);
-            self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary);
+            self.read_document_into_vocabulary(document, self.split_regex(), &mut vocabulary);
         }
 
         let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count);
@@ -340,7 +340,7 @@ impl CountVectorizer {
         sprs_vectorized.reserve_outer_dim_exact(x.len());
         let regex = self.properties.split_regex();
         for string in x.into_iter().map(|s| s.to_string()) {
-            let row = self.analyze_document(string, &regex, document_frequencies.view_mut());
+            let row = self.analyze_document(string, regex, document_frequencies.view_mut());
             sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view());
         }
         (sprs_vectorized, document_frequencies)
@@ -364,7 +364,7 @@ impl CountVectorizer {
             file.read_to_end(&mut document_bytes).unwrap();
             let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap();
             sprs_vectorized = sprs_vectorized.append_outer_csvec(
-                self.analyze_document(document, &regex, document_frequencies.view_mut())
+                self.analyze_document(document, regex, document_frequencies.view_mut())
                     .view(),
             );
         }
```
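The `mod.rs` changes follow mechanically from the new signature: `split_regex()` now returns `&Regex` rather than a `Ref<'_, Regex>` guard, so the extra `&` at each call site is a needless borrow and is dropped. A hypothetical call-site sketch (not linfa code) of why the old borrows existed:

```rust
use regex::Regex;

// Previously, `split_regex()` returned a `Ref<'_, Regex>` guard, so call
// sites wrote `&self.split_regex()` and leaned on deref coercion to turn
// `&Ref<'_, Regex>` into `&Regex`. With the accessor returning `&Regex`
// directly, the reference is passed through unchanged.
fn count_tokens(re: &Regex, text: &str) -> usize {
    re.find_iter(text).count()
}

fn main() {
    let re = Regex::new(r"\b\w\w+\b").unwrap();
    assert_eq!(count_tokens(&re, "a bb ccc"), 2); // "bb" and "ccc" match
}
```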