@@ -329,29 +329,45 @@ class CLI {
329
329
}
330
330
331
331
void LoadModel (std::string const & path, Learner* learner) const {
332
- if (common::FileExtension (path) == " json" ) {
333
- auto buffer = common::LoadSequentialFile (path);
334
- CHECK_GT (buffer.size (), 2 );
335
- CHECK_EQ (buffer[0 ], ' {' );
336
- Json in{Json::Load ({buffer.data (), buffer.size ()})};
332
+ auto ext = common::FileExtension (path);
333
+ auto read_file = [&]() {
334
+ auto str = common::LoadSequentialFile (path);
335
+ CHECK_GE (str.size (), 3 ); // "{}\0"
336
+ CHECK_EQ (str[0 ], ' {' );
337
+ return str;
338
+ };
339
+
340
+ if (ext == " json" ) {
341
+ auto buffer = read_file ();
342
+ Json in{Json::Load (StringView{buffer.data (), buffer.size ()})};
343
+ learner->LoadModel (in);
344
+ } else if (ext == " ubj" ) {
345
+ auto buffer = read_file ();
346
+ Json in = Json::Load (StringView{buffer.data (), buffer.size ()}, std::ios::binary);
337
347
learner->LoadModel (in);
338
348
} else {
339
- std::unique_ptr<dmlc::Stream> fi (dmlc::Stream::Create (path.c_str (), " r" ));
340
- learner->LoadModel (fi.get ());
349
+ LOG (FATAL) << " Unknown model format:" << path << " , expecting either json or ubj." ;
341
350
}
342
351
}
343
352
344
353
void SaveModel (std::string const & path, Learner* learner) const {
345
354
learner->Configure ();
346
355
std::unique_ptr<dmlc::Stream> fo (dmlc::Stream::Create (path.c_str (), " w" ));
347
- if (common::FileExtension (path) == " json" ) {
356
+ auto ext = common::FileExtension (path);
357
+ auto save_json = [&](std::ios::openmode mode) {
348
358
Json out{Object ()};
349
359
learner->SaveModel (&out);
350
- std::string str;
351
- Json::Dump (out, &str);
352
- fo->Write (str.c_str (), str.size ());
360
+ std::vector<char > str;
361
+ Json::Dump (out, &str, mode);
362
+ fo->Write (str.data (), str.size ());
363
+ };
364
+
365
+ if (ext == " json" ) {
366
+ save_json (std::ios::out);
367
+ } else if (ext == " ubj" ) {
368
+ save_json (std::ios::binary);
353
369
} else {
354
- learner-> SaveModel (fo. get ()) ;
370
+ LOG (FATAL) << " Unknown model format: " << path << " , expecting either json or ubj. " ;
355
371
}
356
372
}
357
373
0 commit comments