Skip to content

Commit 5b3da6c

Browse files
committed
cli.
1 parent 68d4017 commit 5b3da6c

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

src/cli_main.cc

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -329,29 +329,45 @@ class CLI {
329329
}
330330

331331
void LoadModel(std::string const& path, Learner* learner) const {
332-
if (common::FileExtension(path) == "json") {
333-
auto buffer = common::LoadSequentialFile(path);
334-
CHECK_GT(buffer.size(), 2);
335-
CHECK_EQ(buffer[0], '{');
336-
Json in{Json::Load({buffer.data(), buffer.size()})};
332+
auto ext = common::FileExtension(path);
333+
auto read_file = [&]() {
334+
auto str = common::LoadSequentialFile(path);
335+
CHECK_GE(str.size(), 3); // "{}\0"
336+
CHECK_EQ(str[0], '{');
337+
return str;
338+
};
339+
340+
if (ext == "json") {
341+
auto buffer = read_file();
342+
Json in{Json::Load(StringView{buffer.data(), buffer.size()})};
343+
learner->LoadModel(in);
344+
} else if (ext == "ubj") {
345+
auto buffer = read_file();
346+
Json in = Json::Load(StringView{buffer.data(), buffer.size()}, std::ios::binary);
337347
learner->LoadModel(in);
338348
} else {
339-
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(path.c_str(), "r"));
340-
learner->LoadModel(fi.get());
349+
LOG(FATAL) << "Unknown model format:" << path << ", expecting either json or ubj.";
341350
}
342351
}
343352

344353
void SaveModel(std::string const& path, Learner* learner) const {
345354
learner->Configure();
346355
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(path.c_str(), "w"));
347-
if (common::FileExtension(path) == "json") {
356+
auto ext = common::FileExtension(path);
357+
auto save_json = [&](std::ios::openmode mode) {
348358
Json out{Object()};
349359
learner->SaveModel(&out);
350-
std::string str;
351-
Json::Dump(out, &str);
352-
fo->Write(str.c_str(), str.size());
360+
std::vector<char> str;
361+
Json::Dump(out, &str, mode);
362+
fo->Write(str.data(), str.size());
363+
};
364+
365+
if (ext == "json") {
366+
save_json(std::ios::out);
367+
} else if (ext == "ubj") {
368+
save_json(std::ios::binary);
353369
} else {
354-
learner->SaveModel(fo.get());
370+
LOG(FATAL) << "Unknown model format:" << path << ", expecting either json or ubj.";
355371
}
356372
}
357373

0 commit comments

Comments
 (0)