@@ -115,46 +115,6 @@ enum Config {
115
115
XlmRoberta ( BertConfig ) ,
116
116
}
117
117
118
- #[ derive( Debug , Clone , Deserialize , PartialEq ) ]
119
- enum ModuleType {
120
- #[ serde( rename = "sentence_transformers.models.Dense" ) ]
121
- Dense ,
122
- #[ serde( rename = "sentence_transformers.models.Normalize" ) ]
123
- Normalize ,
124
- #[ serde( rename = "sentence_transformers.models.Pooling" ) ]
125
- Pooling ,
126
- #[ serde( rename = "sentence_transformers.models.Transformer" ) ]
127
- Transformer ,
128
- }
129
-
130
- #[ derive( Debug , Clone , Deserialize ) ]
131
- struct ModuleConfig {
132
- #[ allow( dead_code) ]
133
- idx : usize ,
134
- #[ allow( dead_code) ]
135
- name : String ,
136
- path : String ,
137
- #[ serde( rename = "type" ) ]
138
- module_type : ModuleType ,
139
- }
140
-
141
- fn parse_dense_paths_from_modules ( model_path : & Path ) -> Result < Vec < String > , std:: io:: Error > {
142
- let modules_path = model_path. join ( "modules.json" ) ;
143
- if !modules_path. exists ( ) {
144
- return Ok ( vec ! [ ] ) ;
145
- }
146
-
147
- let content = std:: fs:: read_to_string ( & modules_path) ?;
148
- let modules: Vec < ModuleConfig > = serde_json:: from_str ( & content)
149
- . map_err ( |err| std:: io:: Error :: new ( std:: io:: ErrorKind :: InvalidData , err) ) ?;
150
-
151
- Ok ( modules
152
- . into_iter ( )
153
- . filter ( |module| module. module_type == ModuleType :: Dense )
154
- . map ( |module| module. path )
155
- . collect :: < Vec < String > > ( ) )
156
- }
157
-
158
118
pub struct CandleBackend {
159
119
device : Device ,
160
120
model : Box < dyn Model + Send > ,
@@ -551,66 +511,54 @@ impl CandleBackend {
551
511
}
552
512
} ;
553
513
554
- // Load modules.json and read the Dense paths from there, unless `dense_paths` is provided
555
- // in such case simply use the `dense_paths`
556
- // 1. If `dense_paths` is None then try to read the `modules.json` file and parse the
557
- // content to read the paths of the default Dense paths, useful when the model directory
558
- // is provided as the `model-id` rather than the ID from the Hugging Face Hub
559
- // 2. If `dense_paths` is Some (even if empty), respect that explicit choice and do not
560
- // read from modules.json, this allows users to explicitly disable dense layers
561
514
let mut dense_layers = Vec :: new ( ) ;
562
-
563
- let paths_to_load = if let Some ( dense_paths) = & dense_paths {
564
- // If dense_paths is explicitly provided (even if empty), respect that choice
565
- dense_paths. clone ( )
566
- } else {
567
- // Try to parse modules.json only if dense_paths is None
568
- parse_dense_paths_from_modules ( model_path) . unwrap_or_default ( )
569
- } ;
570
-
571
- if !paths_to_load. is_empty ( ) {
572
- tracing:: info!( "Loading Dense module/s from path/s: {paths_to_load:?}" ) ;
573
-
574
- for dense_path in paths_to_load. iter ( ) {
575
- let dense_safetensors = model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
576
- let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
577
-
578
- if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
579
- let dense_config_path = model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
580
-
581
- let dense_config_str =
582
- std:: fs:: read_to_string ( & dense_config_path) . map_err ( |err| {
583
- BackendError :: Start ( format ! (
584
- "Unable to read `{dense_path}/config.json` file: {err:?}" ,
585
- ) )
586
- } ) ?;
587
- let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
588
- . map_err ( |err| {
589
- BackendError :: Start ( format ! (
590
- "Unable to parse `{dense_path}/config.json`: {err:?}" ,
591
- ) )
592
- } ) ?;
593
-
594
- let dense_vb = if dense_safetensors. exists ( ) {
595
- unsafe {
596
- VarBuilder :: from_mmaped_safetensors (
597
- & [ dense_safetensors] ,
598
- dtype,
599
- & device,
600
- )
601
- }
602
- . s ( ) ?
515
+ if let Some ( dense_paths) = dense_paths {
516
+ if !dense_paths. is_empty ( ) {
517
+ tracing:: info!( "Loading Dense module/s from path/s: {dense_paths:?}" ) ;
518
+
519
+ for dense_path in dense_paths. iter ( ) {
520
+ let dense_safetensors =
521
+ model_path. join ( format ! ( "{dense_path}/model.safetensors" ) ) ;
522
+ let dense_pytorch = model_path. join ( format ! ( "{dense_path}/pytorch_model.bin" ) ) ;
523
+
524
+ if dense_safetensors. exists ( ) || dense_pytorch. exists ( ) {
525
+ let dense_config_path =
526
+ model_path. join ( format ! ( "{dense_path}/config.json" ) ) ;
527
+
528
+ let dense_config_str = std:: fs:: read_to_string ( & dense_config_path)
529
+ . map_err ( |err| {
530
+ BackendError :: Start ( format ! (
531
+ "Unable to read `{dense_path}/config.json` file: {err:?}" ,
532
+ ) )
533
+ } ) ?;
534
+ let dense_config: DenseConfig = serde_json:: from_str ( & dense_config_str)
535
+ . map_err ( |err| {
536
+ BackendError :: Start ( format ! (
537
+ "Unable to parse `{dense_path}/config.json`: {err:?}" ,
538
+ ) )
539
+ } ) ?;
540
+
541
+ let dense_vb = if dense_safetensors. exists ( ) {
542
+ unsafe {
543
+ VarBuilder :: from_mmaped_safetensors (
544
+ & [ dense_safetensors] ,
545
+ dtype,
546
+ & device,
547
+ )
548
+ }
549
+ . s ( ) ?
550
+ } else {
551
+ VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
552
+ } ;
553
+
554
+ let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
555
+ as Box < dyn DenseLayer + Send > ;
556
+ dense_layers. push ( dense_layer) ;
557
+
558
+ tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
603
559
} else {
604
- VarBuilder :: from_pth ( & dense_pytorch, dtype, & device) . s ( ) ?
605
- } ;
606
-
607
- let dense_layer = Box :: new ( Dense :: load ( dense_vb, & dense_config) . s ( ) ?)
608
- as Box < dyn DenseLayer + Send > ;
609
- dense_layers. push ( dense_layer) ;
610
-
611
- tracing:: info!( "Loaded Dense module from path: {dense_path}" ) ;
612
- } else {
613
- tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
560
+ tracing:: warn!( "Dense module files not found for path: {dense_path}" , ) ;
561
+ }
614
562
}
615
563
}
616
564
}
0 commit comments