@@ -1368,7 +1368,38 @@ XGB_DLL int XGBoosterPredictFromCUDAColumnar(BoosterHandle handle, char const *,
1368
1368
}
1369
1369
#endif // !defined(XGBOOST_USE_CUDA)
1370
1370
1371
- XGB_DLL int XGBoosterLoadModel (BoosterHandle handle, const char * fname) {
1371
+ namespace {
1372
+ template <typename Buffer, typename Iter = typename Buffer::const_iterator>
1373
+ Json DispatchModelType (Buffer const &buffer, StringView ext, bool warn) {
1374
+ auto first_non_space = [&](Iter beg, Iter end) {
1375
+ for (auto i = beg; i != end; ++i) {
1376
+ if (!std::isspace (*i)) {
1377
+ return i;
1378
+ }
1379
+ }
1380
+ return end;
1381
+ };
1382
+
1383
+ Json model;
1384
+ auto it = first_non_space (buffer.cbegin () + 1 , buffer.cend ());
1385
+ if (it != buffer.cend () && *it == ' "' ) {
1386
+ if (warn) {
1387
+ LOG (WARNING) << " Unknown file format: `" << ext << " `. Using JSON as a guess." ;
1388
+ }
1389
+ model = Json::Load (StringView{buffer.data (), buffer.size ()});
1390
+ } else if (it != buffer.cend () && std::isalpha (*it)) {
1391
+ if (warn) {
1392
+ LOG (WARNING) << " Unknown file format: `" << ext << " `. Using UBJ as a guess." ;
1393
+ }
1394
+ model = Json::Load (StringView{buffer.data (), buffer.size ()}, std::ios::binary);
1395
+ } else {
1396
+ LOG (FATAL) << " Invalid model format" ;
1397
+ }
1398
+ return model;
1399
+ }
1400
+ } // namespace
1401
+
1402
+ XGB_DLL int XGBoosterLoadModel (BoosterHandle handle, const char *fname) {
1372
1403
API_BEGIN ();
1373
1404
CHECK_HANDLE ();
1374
1405
xgboost_CHECK_C_ARG_PTR (fname);
@@ -1378,28 +1409,23 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
1378
1409
CHECK_EQ (str[0 ], ' {' );
1379
1410
return str;
1380
1411
};
1381
- if (common::FileExtension (fname) == " json" ) {
1412
+ auto ext = common::FileExtension (fname);
1413
+ if (ext == " json" ) {
1382
1414
auto buffer = read_file ();
1383
1415
Json in{Json::Load (StringView{buffer.data (), buffer.size ()})};
1384
- static_cast <Learner*>(handle)->LoadModel (in);
1385
- } else if (common::FileExtension (fname) == " ubj" ) {
1416
+ static_cast <Learner *>(handle)->LoadModel (in);
1417
+ } else if (ext == " ubj" ) {
1386
1418
auto buffer = read_file ();
1387
1419
Json in = Json::Load (StringView{buffer.data (), buffer.size ()}, std::ios::binary);
1388
1420
static_cast <Learner *>(handle)->LoadModel (in);
1389
1421
} else {
1390
- std::unique_ptr<dmlc::Stream> fi (dmlc::Stream::Create (fname, " r" ));
1391
- static_cast <Learner*>(handle)->LoadModel (fi.get ());
1422
+ auto buffer = read_file ();
1423
+ auto in = DispatchModelType (buffer, ext, true );
1424
+ static_cast <Learner *>(handle)->LoadModel (in);
1392
1425
}
1393
1426
API_END ();
1394
1427
}
1395
1428
1396
- namespace {
1397
- void WarnOldModel () {
1398
- LOG (WARNING) << " Saving into deprecated binary model format, please consider using `json` or "
1399
- " `ubj`. Model format is default to UBJSON in XGBoost 2.1 if not specified." ;
1400
- }
1401
- } // anonymous namespace
1402
-
1403
1429
XGB_DLL int XGBoosterSaveModel (BoosterHandle handle, const char *fname) {
1404
1430
API_BEGIN ();
1405
1431
CHECK_HANDLE ();
@@ -1419,13 +1445,9 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *fname) {
1419
1445
save_json (std::ios::out);
1420
1446
} else if (common::FileExtension (fname) == " ubj" ) {
1421
1447
save_json (std::ios::binary);
1422
- } else if (common::FileExtension (fname) == " deprecated" ) {
1423
- WarnOldModel ();
1424
- auto *bst = static_cast <Learner *>(handle);
1425
- bst->SaveModel (fo.get ());
1426
1448
} else {
1427
1449
LOG (WARNING) << " Saving model in the UBJSON format as default. You can use file extension:"
1428
- " `json`, `ubj` or `deprecated ` to choose between formats." ;
1450
+ " `json` or `ubj ` to choose between formats." ;
1429
1451
save_json (std::ios::binary);
1430
1452
}
1431
1453
API_END ();
@@ -1436,9 +1458,11 @@ XGB_DLL int XGBoosterLoadModelFromBuffer(BoosterHandle handle, const void *buf,
1436
1458
API_BEGIN ();
1437
1459
CHECK_HANDLE ();
1438
1460
xgboost_CHECK_C_ARG_PTR (buf);
1439
-
1461
+ auto buffer = common::Span<char const >{static_cast <char const *>(buf), len};
1462
+ // Don't warn, we have to guess the format with buffer input.
1463
+ auto in = DispatchModelType (buffer, " " , false );
1440
1464
common::MemoryFixSizeBuffer fs ((void *)buf, len); // NOLINT(*)
1441
- static_cast <Learner *>(handle)->LoadModel (&fs );
1465
+ static_cast <Learner *>(handle)->LoadModel (in );
1442
1466
API_END ();
1443
1467
}
1444
1468
@@ -1471,15 +1495,6 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
1471
1495
save_json (std::ios::out);
1472
1496
} else if (format == " ubj" ) {
1473
1497
save_json (std::ios::binary);
1474
- } else if (format == " deprecated" ) {
1475
- WarnOldModel ();
1476
- auto &raw_str = learner->GetThreadLocal ().ret_str ;
1477
- raw_str.clear ();
1478
- common::MemoryBufferStream fo (&raw_str);
1479
- learner->SaveModel (&fo);
1480
-
1481
- *out_dptr = dmlc::BeginPtr (raw_str);
1482
- *out_len = static_cast <xgboost::bst_ulong>(raw_str.size ());
1483
1498
} else {
1484
1499
LOG (FATAL) << " Unknown format: `" << format << " `" ;
1485
1500
}
0 commit comments