@@ -115,6 +115,46 @@ enum Config {
115
115
XlmRoberta ( BertConfig ) ,
116
116
}
117
117
118
+ #[ derive( Debug , Clone , Deserialize , PartialEq ) ]
119
+ enum ModuleType {
120
+ #[ serde( rename = "sentence_transformers.models.Dense" ) ]
121
+ Dense ,
122
+ #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
123
+ Normalize ,
124
+ #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
125
+ Pooling ,
126
+ #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
127
+ Transformer ,
128
+ }
129
+
130
+ #[ derive( Debug , Clone , Deserialize ) ]
131
+ struct ModuleConfig {
132
+ #[ allow( dead_code) ]
133
+ idx : usize ,
134
+ #[ allow( dead_code) ]
135
+ name : String ,
136
+ path : String ,
137
+ #[ serde( rename = "type" ) ]
138
+ module_type : ModuleType ,
139
+ }
140
+
141
+ fn parse_dense_paths_from_modules ( model_path : & Path ) -> Result < Vec < String > , std:: io:: Error > {
142
+ let modules_path = model_path. join ( "modules.json" ) ;
143
+ if !modules_path. exists ( ) {
144
+ return Ok ( vec ! [ ] ) ;
145
+ }
146
+
147
+ let content = std:: fs:: read_to_string ( & modules_path) ?;
148
+ let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
149
+ . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
150
+
151
+ Ok ( modules
152
+ . into_iter ( )
153
+ . filter ( |module| module. module_type == ModuleType :: Dense )
154
+ . map ( |module| module. path )
155
+ . collect :: < Vec < String > > ( ) )
156
+ }
157
+
118
158
pub struct CandleBackend {
119
159
device : Device ,
120
160
model : Box < dyn Model + Send > ,
@@ -511,55 +551,66 @@ impl CandleBackend {
511
551
}
512
552
} ;
513
553
514
- // Load Dense layers from the provided Dense paths
554
+ // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555
+ // in such case simply use the `dense_paths`
556
+ // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557
+ // content to read the paths of the default Dense paths, useful when the model directory
558
+ // is provided as the `model-id` rather than the ID from the Hugging Face Hub
559
+ // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560
+ // read from modules.json, this allows users to explicitly disable dense layers
515
561
let mut dense_layers = Vec :: new ( ) ;
516
- if let Some ( dense_paths) = & dense_paths {
517
- if !dense_paths. is_empty ( ) {
518
- tracing:: info!( "Loading Dense module/s from path/s: {dense_paths:?}" ) ;
519
-
520
- for dense_path in dense_paths. iter ( ) {
521
- let dense_safetensors =
522
- model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
523
- let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
524
-
525
- if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
526
- let dense_config_path =
527
- model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
528
-
529
- let dense_config_str = std:: fs:: read_to_string ( & dense_config_path)
530
- . map_err ( |err| {
531
- BackendError :: Start ( format ! (
532
- "Unable to read `{dense_path}/config.json` file: {err:?}" ,
533
- ) )
534
- } ) ?;
535
- let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
536
- . map_err ( |err| {
537
- BackendError :: Start ( format ! (
538
- "Unable to parse `{dense_path}/config.json`: {err:?}" ,
539
- ) )
540
- } ) ?;
541
-
542
- let dense_vb = if dense_safetensors. exists ( ) {
543
- unsafe {
544
- VarBuilder :: from_mmaped_safetensors (
545
- & [ dense_safetensors] ,
546
- dtype,
547
- & device,
548
- )
549
- }
550
- . s ( ) ?
551
- } else {
552
- VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
553
- } ;
554
-
555
- let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
556
- as Box < dyn DenseLayer + Send > ;
557
- dense_layers. push ( dense_layer) ;
558
-
559
- tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
562
+
563
+ let paths_to_load = if let Some ( dense_paths) = & dense_paths {
564
+ // If dense_paths is explicitly provided (even if empty), respect that choice
565
+ dense_paths. clone ( )
566
+ } else {
567
+ // Try to parse modules.json only if dense_paths is None
568
+ parse_dense_paths_from_modules ( model_path) . unwrap_or_default ( )
569
+ } ;
570
+
571
+ if !paths_to_load. is_empty ( ) {
572
+ tracing:: info!( "Loading Dense module/s from path/s: {paths_to_load:?}" ) ;
573
+
574
+ for dense_path in paths_to_load. iter ( ) {
575
+ let dense_safetensors = model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
576
+ let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
577
+
578
+ if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
579
+ let dense_config_path = model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
580
+
581
+ let dense_config_str =
582
+ std:: fs:: read_to_string ( & dense_config_path) . map_err ( |err| {
583
+ BackendError :: Start ( format ! (
584
+ "Unable to read `{dense_path}/config.json` file: {err:?}" ,
585
+ ) )
586
+ } ) ?;
587
+ let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
588
+ . map_err ( |err| {
589
+ BackendError :: Start ( format ! (
590
+ "Unable to parse `{dense_path}/config.json`: {err:?}" ,
591
+ ) )
592
+ } ) ?;
593
+
594
+ let dense_vb = if dense_safetensors. exists ( ) {
595
+ unsafe {
596
+ VarBuilder :: from_mmaped_safetensors (
597
+ & [ dense_safetensors] ,
598
+ dtype,
599
+ & device,
600
+ )
601
+ }
602
+ . s ( ) ?
560
603
} else {
561
- tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
562
- }
604
+ VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
605
+ } ;
606
+
607
+ let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
608
+ as Box < dyn DenseLayer + Send > ;
609
+ dense_layers. push ( dense_layer) ;
610
+
611
+ tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
612
+ } else {
613
+ tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
563
614
}
564
615
}
565
616
}
0 commit comments