@@ -145,7 +145,12 @@ ComputeGraph::ComputeGraph(GraphConfig config)
   execute_descriptor_counts_.descriptor_combined_sampler_count = 0;
   execute_descriptor_counts_.descriptor_storage_image_count = 0;
 
-  context_->set_cmd(/*reusable = */ true);
+  // If certain graph config variables are not specified, then set them
+  // automatically.
+  if (config_.prepack_threshold_nbytes == 0) {
+    config_.prepack_threshold_nbytes = 10 * MB;
+    config_.prepack_initial_threshold_nbytes = 10 * MB;
+  }
 }
 
 ComputeGraph::~ComputeGraph() {
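A minimal caller-side sketch of how these defaults interact with an explicit override. Only the two threshold fields, MB, and the ComputeGraph(GraphConfig) constructor come from this change; it assumes GraphConfig is default-constructible:

    GraphConfig config;
    // Leaving prepack_threshold_nbytes at 0 picks up the 10 * MB defaults above;
    // setting it explicitly tunes how much staging data is batched per submit.
    config.prepack_threshold_nbytes = 32 * MB;
    config.prepack_initial_threshold_nbytes = 16 * MB;
    ComputeGraph graph(config);
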
@@ -431,6 +436,7 @@ ValueRef ComputeGraph::add_tensorref(
   ValueRef idx(static_cast<int>(values_.size()));
   check_no_active_value_ptrs();
   values_.emplace_back(TensorRef(sizes, dtype, data));
+  total_constant_nbytes_ += values_.back().toConstTensorRef().nbytes();
   return idx;
 }
 
@@ -750,6 +756,19 @@ void ComputeGraph::prepare_pipelines() {
       vkapi::ComputePipelineCache::Hasher>();
 }
 
+void ComputeGraph::submit_current_cmd(const bool final_use) {
+  context_->submit_cmd_to_gpu(VK_NULL_HANDLE, final_use);
+}
+
+void ComputeGraph::submit_current_cmd_and_wait(const bool final_use) {
+  vkapi::VulkanFence fence = context_->fences().get_fence();
+  context_->submit_cmd_to_gpu(fence.get_submit_handle(), final_use);
+  fence.wait();
+  context_->fences().return_fence(fence);
+
+  context_->flush();
+}
+
 void ComputeGraph::encode_prepack() {
   for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
     node->encode(this);
@@ -766,6 +785,37 @@ void ComputeGraph::prepack() const {
   context_->flush();
 }
 
+void ComputeGraph::run_prepack() {
+  int i = 0;
+  bool submitted = false;
+  const bool reduce_peak_memory = total_constant_nbytes_ > 500 * MB;
+  // int count = 0;
+  context_->set_cmd();
+  for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
+    // Do not trigger on the first or last prepack node.
+    const bool not_terminal = i != 0 && i != (prepack_nodes_.size() - 1);
+    size_t threshold = submitted ? config_.prepack_threshold_nbytes
+                                 : config_.prepack_initial_threshold_nbytes;
+    if (not_terminal && staging_nbytes_in_cmd_ > threshold) {
+      // If reducing peak memory usage, wait for the current command buffer to
+      // finish executing and flush to recycle the staging memory. This will
+      // reduce peak memory usage, but will slightly increase load latency.
+      // Otherwise, just submit the current command buffer for execution and
+      // proceed. This results in lower load latency at the cost of higher peak
+      // memory usage.
+      reduce_peak_memory ? submit_current_cmd_and_wait() : submit_current_cmd();
+      staging_nbytes_in_cmd_ = 0;
+      context_->set_cmd();
+      submitted = true;
+    }
+
+    node->encode(this);
+    i++;
+  }
+  submit_current_cmd_and_wait(/*final_use=*/true);
+  staging_nbytes_in_cmd_ = 0;
+}
+
 void ComputeGraph::encode_execute() {
   context_->flush();
   context_->set_cmd(/*reusable = */ true);
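
For orientation, a hypothetical load-time sequence using the new batched prepack path. Only run_prepack(), prepare_pipelines(), and encode_execute() appear in this diff; the remaining calls are assumptions about the surrounding ComputeGraph API, not taken from this commit:

    ComputeGraph graph(config);
    // ... register constant tensors, prepack nodes, and execute nodes ...
    graph.prepare();            // assumed: performs memory planning
    graph.prepare_pipelines();
    graph.run_prepack();        // encodes and submits prepack work in batches,
                                // resubmitting the in-flight command buffer
                                // whenever staging_nbytes_in_cmd_ exceeds the
                                // configured threshold
    graph.encode_execute();
    graph.execute();            // assumed: submits the execute command buffer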