|
64 | 64 | #include <iomanip> |
65 | 65 |
|
66 | 66 | namespace { |
| 67 | + |
| 68 | +using dims_map = std::unordered_map<std::string, std::vector<std::size_t>>; |
| 69 | + |
67 | 70 | std::vector<std::string> |
68 | 71 | get_unrecognized_migraphx_envs(const char* envp[], |
69 | 72 | const std::map<std::string, std::string>& used_env) |
@@ -213,7 +216,7 @@ struct loader |
213 | 216 |
|
214 | 217 | static auto parse_param_dims(const std::vector<std::string>& param_dims_info) |
215 | 218 | { |
216 | | - std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; |
| 219 | + dims_map map_input_dims; |
217 | 220 | std::string name = ""; |
218 | 221 | for(auto&& x : param_dims_info) |
219 | 222 | { |
@@ -502,16 +505,24 @@ struct program_params |
502 | 505 | return map_load_args; |
503 | 506 | } |
504 | 507 |
|
505 | | - auto generate(const program& p, const target& t, bool offload, unsigned batch) |
| 508 | + auto generate(const program& p, |
| 509 | + const target& t, |
| 510 | + bool offload, |
| 511 | + unsigned batch, |
| 512 | + dims_map map_input_dims = {}) |
506 | 513 | { |
507 | 514 | parameter_map m; |
508 | 515 | auto param_shapes = p.get_parameter_shapes(); |
509 | 516 | std::unordered_map<std::string, shape> static_param_shapes; |
510 | | - std::transform( |
511 | | - param_shapes.cbegin(), |
512 | | - param_shapes.cend(), |
513 | | - std::inserter(static_param_shapes, static_param_shapes.end()), |
514 | | - [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); }); |
| 517 | + for(auto&& param : param_shapes) |
| 518 | + { |
| 519 | + if(contains(map_input_dims, param.first)) |
| 520 | + static_param_shapes[param.first] = {param.second.type(), |
| 521 | + map_input_dims[param.first]}; |
| 522 | + else |
| 523 | + static_param_shapes[param.first] = param.second.to_static(batch); |
| 524 | + } |
| 525 | + |
515 | 526 | for(auto&& s : fill0) |
516 | 527 | m[s] = fill_argument(static_param_shapes.at(s), 0); |
517 | 528 | for(auto&& s : fill1) |
@@ -591,7 +602,8 @@ struct compiler |
591 | 602 |
|
592 | 603 | auto params(const program& p) |
593 | 604 | { |
594 | | - return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); |
| 605 | + return parameters.generate( |
| 606 | + p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims)); |
595 | 607 | } |
596 | 608 |
|
597 | 609 | auto host_params(const program& p) |
@@ -730,7 +742,8 @@ struct verify : command<verify> |
730 | 742 | std::cout << p << std::endl; |
731 | 743 |
|
732 | 744 | auto t = c.ct.get_target(); |
733 | | - auto m = c.parameters.generate(p, t, true, c.l.batch); |
| 745 | + auto m = |
| 746 | + c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims)); |
734 | 747 |
|
735 | 748 | if(c.to_fp16) |
736 | 749 | { |
|
0 commit comments