Skip to content

Commit 45c90a0

Browse files
committed
Read modules.json when dense_paths is None
As that might imply that the user originally provided a local path rather than a Hugging Face Hub ID, meaning that the `dense_paths` variable won't be filled, meaning that we need to read those from `modules.json` Note that this is just a premature quick solution, ideally this should be handled within `backends/src/lib.rs` rather than directly within the `CandleBackend` as otherwise we end up duplicating a lot of unnecessary code
1 parent 9ef569d commit 45c90a0

File tree

1 file changed

+98
-47
lines changed

1 file changed

+98
-47
lines changed

backends/candle/src/lib.rs

Lines changed: 98 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,46 @@ enum Config {
115115
XlmRoberta(BertConfig),
116116
}
117117

118+
#[derive(Debug, Clone, Deserialize, PartialEq)]
119+
enum ModuleType {
120+
#[serde(rename = "sentence_transformers.models.Dense")]
121+
Dense,
122+
#[serde(rename = "sentence_transformers.models.Normalize")]
123+
Normalize,
124+
#[serde(rename = "sentence_transformers.models.Pooling")]
125+
Pooling,
126+
#[serde(rename = "sentence_transformers.models.Transformer")]
127+
Transformer,
128+
}
129+
130+
#[derive(Debug, Clone, Deserialize)]
131+
struct ModuleConfig {
132+
#[allow(dead_code)]
133+
idx: usize,
134+
#[allow(dead_code)]
135+
name: String,
136+
path: String,
137+
#[serde(rename = "type")]
138+
module_type: ModuleType,
139+
}
140+
141+
fn parse_dense_paths_from_modules(model_path: &Path) -> Result<Vec<String>, std::io::Error> {
142+
let modules_path = model_path.join("modules.json");
143+
if !modules_path.exists() {
144+
return Ok(vec![]);
145+
}
146+
147+
let content = std::fs::read_to_string(&modules_path)?;
148+
let modules: Vec<ModuleConfig> = serde_json::from_str(&content)
149+
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
150+
151+
Ok(modules
152+
.into_iter()
153+
.filter(|module| module.module_type == ModuleType::Dense)
154+
.map(|module| module.path)
155+
.collect::<Vec<String>>())
156+
}
157+
118158
pub struct CandleBackend {
119159
device: Device,
120160
model: Box<dyn Model + Send>,
@@ -511,55 +551,66 @@ impl CandleBackend {
511551
}
512552
};
513553

514-
// Load Dense layers from the provided Dense paths
554+
// Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555+
// in such case simply use the `dense_paths`
556+
// 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557+
// content to read the paths of the default Dense paths, useful when the model directory
558+
// is provided as the `model-id` rather than the ID from the Hugging Face Hub
559+
// 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560+
// read from modules.json, this allows users to explicitly disable dense layers
515561
let mut dense_layers = Vec::new();
516-
if let Some(dense_paths) = &dense_paths {
517-
if !dense_paths.is_empty() {
518-
tracing::info!("Loading Dense module/s from path/s: {dense_paths:?}");
519-
520-
for dense_path in dense_paths.iter() {
521-
let dense_safetensors =
522-
model_path.join(format!("{dense_path}/model.safetensors"));
523-
let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin"));
524-
525-
if dense_safetensors.exists() || dense_pytorch.exists() {
526-
let dense_config_path =
527-
model_path.join(format!("{dense_path}/config.json"));
528-
529-
let dense_config_str = std::fs::read_to_string(&dense_config_path)
530-
.map_err(|err| {
531-
BackendError::Start(format!(
532-
"Unable to read `{dense_path}/config.json` file: {err:?}",
533-
))
534-
})?;
535-
let dense_config: DenseConfig = serde_json::from_str(&dense_config_str)
536-
.map_err(|err| {
537-
BackendError::Start(format!(
538-
"Unable to parse `{dense_path}/config.json`: {err:?}",
539-
))
540-
})?;
541-
542-
let dense_vb = if dense_safetensors.exists() {
543-
unsafe {
544-
VarBuilder::from_mmaped_safetensors(
545-
&[dense_safetensors],
546-
dtype,
547-
&device,
548-
)
549-
}
550-
.s()?
551-
} else {
552-
VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
553-
};
554-
555-
let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?)
556-
as Box<dyn DenseLayer + Send>;
557-
dense_layers.push(dense_layer);
558-
559-
tracing::info!("Loaded Dense module from path: {dense_path}");
562+
563+
let paths_to_load = if let Some(dense_paths) = &dense_paths {
564+
// If dense_paths is explicitly provided (even if empty), respect that choice
565+
dense_paths.clone()
566+
} else {
567+
// Try to parse modules.json only if dense_paths is None
568+
parse_dense_paths_from_modules(model_path).unwrap_or_default()
569+
};
570+
571+
if !paths_to_load.is_empty() {
572+
tracing::info!("Loading Dense module/s from path/s: {paths_to_load:?}");
573+
574+
for dense_path in paths_to_load.iter() {
575+
let dense_safetensors = model_path.join(format!("{dense_path}/model.safetensors"));
576+
let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin"));
577+
578+
if dense_safetensors.exists() || dense_pytorch.exists() {
579+
let dense_config_path = model_path.join(format!("{dense_path}/config.json"));
580+
581+
let dense_config_str =
582+
std::fs::read_to_string(&dense_config_path).map_err(|err| {
583+
BackendError::Start(format!(
584+
"Unable to read `{dense_path}/config.json` file: {err:?}",
585+
))
586+
})?;
587+
let dense_config: DenseConfig = serde_json::from_str(&dense_config_str)
588+
.map_err(|err| {
589+
BackendError::Start(format!(
590+
"Unable to parse `{dense_path}/config.json`: {err:?}",
591+
))
592+
})?;
593+
594+
let dense_vb = if dense_safetensors.exists() {
595+
unsafe {
596+
VarBuilder::from_mmaped_safetensors(
597+
&[dense_safetensors],
598+
dtype,
599+
&device,
600+
)
601+
}
602+
.s()?
560603
} else {
561-
tracing::warn!("Dense module files not found for path: {dense_path}",);
562-
}
604+
VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
605+
};
606+
607+
let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?)
608+
as Box<dyn DenseLayer + Send>;
609+
dense_layers.push(dense_layer);
610+
611+
tracing::info!("Loaded Dense module from path: {dense_path}");
612+
} else {
613+
tracing::warn!("Dense module files not found for path: {dense_path}",);
563614
}
564615
}
565616
}

0 commit comments

Comments
 (0)