Skip to content

Commit 9894e2a

Browse files
committed
Move modules.json logic to backends/src/lib.rs instead
1 parent 45c90a0 commit 9894e2a

File tree

2 files changed

+78
-99
lines changed

2 files changed

+78
-99
lines changed

backends/candle/src/lib.rs

Lines changed: 46 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -115,46 +115,6 @@ enum Config {
115115
XlmRoberta(BertConfig),
116116
}
117117

118-
#[derive(Debug, Clone, Deserialize, PartialEq)]
119-
enum ModuleType {
120-
#[serde(rename = "sentence_transformers.models.Dense")]
121-
Dense,
122-
#[serde(rename = "sentence_transformers.models.Normalize")]
123-
Normalize,
124-
#[serde(rename = "sentence_transformers.models.Pooling")]
125-
Pooling,
126-
#[serde(rename = "sentence_transformers.models.Transformer")]
127-
Transformer,
128-
}
129-
130-
#[derive(Debug, Clone, Deserialize)]
131-
struct ModuleConfig {
132-
#[allow(dead_code)]
133-
idx: usize,
134-
#[allow(dead_code)]
135-
name: String,
136-
path: String,
137-
#[serde(rename = "type")]
138-
module_type: ModuleType,
139-
}
140-
141-
fn parse_dense_paths_from_modules(model_path: &Path) -> Result<Vec<String>, std::io::Error> {
142-
let modules_path = model_path.join("modules.json");
143-
if !modules_path.exists() {
144-
return Ok(vec![]);
145-
}
146-
147-
let content = std::fs::read_to_string(&modules_path)?;
148-
let modules: Vec<ModuleConfig> = serde_json::from_str(&content)
149-
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
150-
151-
Ok(modules
152-
.into_iter()
153-
.filter(|module| module.module_type == ModuleType::Dense)
154-
.map(|module| module.path)
155-
.collect::<Vec<String>>())
156-
}
157-
158118
pub struct CandleBackend {
159119
device: Device,
160120
model: Box<dyn Model + Send>,
@@ -551,66 +511,54 @@ impl CandleBackend {
551511
}
552512
};
553513

554-
// Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555-
// in such case simply use the `dense_paths`
556-
// 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557-
// content to read the paths of the default Dense paths, useful when the model directory
558-
// is provided as the `model-id` rather than the ID from the Hugging Face Hub
559-
// 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560-
// read from modules.json, this allows users to explicitly disable dense layers
561514
let mut dense_layers = Vec::new();
562-
563-
let paths_to_load = if let Some(dense_paths) = &dense_paths {
564-
// If dense_paths is explicitly provided (even if empty), respect that choice
565-
dense_paths.clone()
566-
} else {
567-
// Try to parse modules.json only if dense_paths is None
568-
parse_dense_paths_from_modules(model_path).unwrap_or_default()
569-
};
570-
571-
if !paths_to_load.is_empty() {
572-
tracing::info!("Loading Dense module/s from path/s: {paths_to_load:?}");
573-
574-
for dense_path in paths_to_load.iter() {
575-
let dense_safetensors = model_path.join(format!("{dense_path}/model.safetensors"));
576-
let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin"));
577-
578-
if dense_safetensors.exists() || dense_pytorch.exists() {
579-
let dense_config_path = model_path.join(format!("{dense_path}/config.json"));
580-
581-
let dense_config_str =
582-
std::fs::read_to_string(&dense_config_path).map_err(|err| {
583-
BackendError::Start(format!(
584-
"Unable to read `{dense_path}/config.json` file: {err:?}",
585-
))
586-
})?;
587-
let dense_config: DenseConfig = serde_json::from_str(&dense_config_str)
588-
.map_err(|err| {
589-
BackendError::Start(format!(
590-
"Unable to parse `{dense_path}/config.json`: {err:?}",
591-
))
592-
})?;
593-
594-
let dense_vb = if dense_safetensors.exists() {
595-
unsafe {
596-
VarBuilder::from_mmaped_safetensors(
597-
&[dense_safetensors],
598-
dtype,
599-
&device,
600-
)
601-
}
602-
.s()?
515+
if let Some(dense_paths) = dense_paths {
516+
if !dense_paths.is_empty() {
517+
tracing::info!("Loading Dense module/s from path/s: {dense_paths:?}");
518+
519+
for dense_path in dense_paths.iter() {
520+
let dense_safetensors =
521+
model_path.join(format!("{dense_path}/model.safetensors"));
522+
let dense_pytorch = model_path.join(format!("{dense_path}/pytorch_model.bin"));
523+
524+
if dense_safetensors.exists() || dense_pytorch.exists() {
525+
let dense_config_path =
526+
model_path.join(format!("{dense_path}/config.json"));
527+
528+
let dense_config_str = std::fs::read_to_string(&dense_config_path)
529+
.map_err(|err| {
530+
BackendError::Start(format!(
531+
"Unable to read `{dense_path}/config.json` file: {err:?}",
532+
))
533+
})?;
534+
let dense_config: DenseConfig = serde_json::from_str(&dense_config_str)
535+
.map_err(|err| {
536+
BackendError::Start(format!(
537+
"Unable to parse `{dense_path}/config.json`: {err:?}",
538+
))
539+
})?;
540+
541+
let dense_vb = if dense_safetensors.exists() {
542+
unsafe {
543+
VarBuilder::from_mmaped_safetensors(
544+
&[dense_safetensors],
545+
dtype,
546+
&device,
547+
)
548+
}
549+
.s()?
550+
} else {
551+
VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
552+
};
553+
554+
let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?)
555+
as Box<dyn DenseLayer + Send>;
556+
dense_layers.push(dense_layer);
557+
558+
tracing::info!("Loaded Dense module from path: {dense_path}");
603559
} else {
604-
VarBuilder::from_pth(&dense_pytorch, dtype, &device).s()?
605-
};
606-
607-
let dense_layer = Box::new(Dense::load(dense_vb, &dense_config).s()?)
608-
as Box<dyn DenseLayer + Send>;
609-
dense_layers.push(dense_layer);
610-
611-
tracing::info!("Loaded Dense module from path: {dense_path}");
612-
} else {
613-
tracing::warn!("Dense module files not found for path: {dense_path}",);
560+
tracing::warn!("Dense module files not found for path: {dense_path}",);
561+
}
614562
}
615563
}
616564
}

backends/src/lib.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,38 @@ async fn init_backend(
433433
tracing::info!("Dense modules downloaded in {:?}", start.elapsed());
434434
Some(dense_paths)
435435
} else {
436-
None
436+
// For local models, try to parse modules.json and handle dense_path logic
437+
let modules_json_path = model_path.join("modules.json");
438+
if modules_json_path.exists() {
439+
match parse_dense_paths_from_modules(&modules_json_path).await {
440+
Ok(module_paths) => match module_paths.len() {
441+
0 => Some(vec![]),
442+
1 => {
443+
let path_to_use = if let Some(ref user_path) = dense_path {
444+
if user_path != &module_paths[0] {
445+
tracing::info!("`{}` found in `modules.json`, but using provided `--dense-path={user_path}` instead", module_paths[0]);
446+
}
447+
user_path.clone()
448+
} else {
449+
module_paths[0].clone()
450+
};
451+
Some(vec![path_to_use])
452+
}
453+
_ => {
454+
if dense_path.is_some() {
455+
tracing::warn!("A value for `--dense-path` was provided, but since there's more than one subsequent Dense module, then the provided value will be ignored.");
456+
}
457+
Some(module_paths)
458+
}
459+
},
460+
Err(err) => {
461+
tracing::warn!("Failed to parse local modules.json: {err}");
462+
None
463+
}
464+
}
465+
} else {
466+
None
467+
}
437468
};
438469

439470
let backend = CandleBackend::new(

0 commit comments

Comments
 (0)