@@ -105,6 +105,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
105105      int64_t  num_shards = 8 ,
106106      int64_t  num_threads = 32 ,
107107      int64_t  row_storage_bitwidth = 32 ,
108+       bool  backend_return_whole_row = false ,
108109      bool  enable_async_update = false ,
109110      std::optional<at::Tensor> table_dims = std::nullopt ,
110111      std::optional<at::Tensor> hash_size_cumsum = std::nullopt )
@@ -126,6 +127,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
126127            block_alignment_,
127128            /* blocks_per_chunk=*/ 8192 )),
128129        elem_size_(row_storage_bitwidth / 8 ),
130+         backend_return_whole_row_(backend_return_whole_row),
129131        feature_evict_config_(feature_evict_config) {
130132    executor_ = std::make_unique<folly::CPUThreadPoolExecutor>(std::max<size_t >(
131133        num_threads, facebook::Proc::getCpuInfo ().numCpuCores ));
@@ -608,11 +610,15 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
608610  void  set_range_to_storage (
609611      const  at::Tensor& weights,
610612      const  int64_t  start,
611-       const  int64_t  length) {
612-     const  auto  seq_indices =
613-         at::arange (start, start + length, at::TensorOptions ().dtype (at::kLong ));
614-     const  auto  count = at::tensor ({length}, at::ScalarType::Long);
615-     folly::coro::blockingWait (set_kv_db_async (seq_indices, weights, count));
613+       const  int64_t  length) override  {
614+     if  (backend_return_whole_row_) {
615+       set_kv_with_metaheader_to_storage (weights);
616+     } else  {
617+       const  auto  seq_indices = at::arange (
618+           start, start + length, at::TensorOptions ().dtype (at::kLong ));
619+       const  auto  count = at::tensor ({length}, at::ScalarType::Long);
620+       folly::coro::blockingWait (set_kv_db_async (seq_indices, weights, count));
621+     }
616622  }
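A note on the whole-row branch above: when backend_return_whole_row_ is set, each row of `weights` is expected to carry its own metaheader, so the destination keys are recovered from the rows themselves (see set_kv_with_metaheader_to_storage below) and `start`/`length` are not used for addressing. A minimal caller-side sketch, where `cache` and `rows` are hypothetical names and not part of this change:

    // Illustrative only: rows already contain [metaheader | embedding | ...],
    // so the keys travel with the data rather than with the index range.
    if (cache.get_backend_return_whole_row()) {
      cache.set_range_to_storage(rows, /*start=*/0, /*length=*/rows.size(0));
    }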
617623
618624  void  get_range_from_snapshot (
@@ -625,10 +631,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
625631    CHECK (snapshot_handle == nullptr );
626632    const  auto  seq_indices =
627633        at::arange (start, start + length, at::TensorOptions ().dtype (at::kLong ));
628-     const  auto  count = at::tensor ({length}, at::ScalarType::Long);
629-     get_kv_db_async_impl (
630-         seq_indices, weights, count, width_offset, width_length)
631-         .wait ();
634+ 
635+     if  (backend_return_whole_row_) {
636+       get_kv_with_metaheader_from_storage (seq_indices, weights);
637+     } else  {
638+       const  auto  count = at::tensor ({length}, at::ScalarType::Long);
639+       get_kv_db_async_impl (
640+           seq_indices, weights, count, width_offset, width_length)
641+           .wait ();
642+     }
643+ 
632644    //  this is called by checkpoint mostly, and checkpoint should wait until
633645    //  eviction finishes so that we could reach a consistent state before/after
634646    //  state_dict() calls
@@ -642,8 +654,41 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
642654      int64_t  width_offset = 0 ,
643655      std::optional<int64_t > width_length = std::nullopt ) override  {
644656    CHECK (snapshot_handle == nullptr );
657+ 
658+     if  (backend_return_whole_row_) {
659+       get_kv_with_metaheader_from_storage (
660+           ids, weights, width_offset, width_length);
661+     } else  {
662+       const  auto  count = at::tensor ({ids.size (0 )}, at::ScalarType::Long);
663+       get_kv_db_async_impl (ids, weights, count, width_offset, width_length)
664+           .wait ();
665+     }
666+   }
667+ 
668+   //  Used for checkpointing: read KV rows (with metaheader) from storage.
669+   void  get_kv_with_metaheader_from_storage (
670+       const  at::Tensor& ids,
671+       const  at::Tensor& weights_with_metaheader,
672+       int64_t  width_offset = 0 ,
673+       std::optional<int64_t > width_length = std::nullopt ) {
645674    const  auto  count = at::tensor ({ids.size (0 )}, at::ScalarType::Long);
646-     get_kv_db_async_impl (ids, weights, count, width_offset, width_length)
675+     get_kv_db_with_metaheader_async_impl (
676+         ids, weights_with_metaheader, count, width_offset, width_length)
677+         .wait ();
678+   }
679+ 
680+   void  set_kv_with_metaheader_to_storage (
681+       const  at::Tensor& weights_with_metaheader) {
682+     std::vector<int64_t > keys (weights_with_metaheader.size (0 ), 0 );
683+     for  (int64_t  i = 0 ; i < weights_with_metaheader.size (0 ); ++i) {
684+       keys[i] = FixedBlockPool::get_key (weights_with_metaheader[i].data_ptr ());
685+     }
686+     auto  indices =
687+         torch::from_blob (keys.data (), {int64_t (keys.size ())}, torch::kInt64 );
688+     const  auto  count =
689+         at::tensor ({weights_with_metaheader.size (0 )}, at::ScalarType::Long);
690+     set_kv_db_with_metaheader_async_impl (
691+         indices, weights_with_metaheader, count)
647692        .wait ();
648693    //  this is called by checkpoint mostly, and checkpoint should wait until
649694    //  eviction finishes so that we could reach a consistent state before/after
@@ -826,6 +871,16 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
826871
827872  void  flush_or_compact (const  int64_t  timestep) override  {}
828873
874+   bool  get_backend_return_whole_row () override  {
875+     return  backend_return_whole_row_;
876+   }
877+ 
878+   int64_t  get_metaheader_width_in_front () override  {
879+     return  backend_return_whole_row_
880+         ? FixedBlockPool::get_metaheader_dim<weight_type>()
881+         : 0 ;
882+   }
883+ 
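The two accessors above tell callers whether rows come back with a metaheader prefix and how wide that prefix is. A minimal sketch of slicing the embedding out of a whole row, assuming the layout [metaheader | embedding | optimizer state] implied by the width_offset comments below; `cache`, `row`, and `emb_dim` are hypothetical names:

    // Illustrative only; not part of this change.
    const int64_t meta_width = cache.get_metaheader_width_in_front();
    // `row` is one 1D whole row returned in whole-row mode; skip the
    // metaheader columns and keep the embedding that follows.
    const at::Tensor emb =
        row.narrow(/*dim=*/0, /*start=*/meta_width, /*length=*/emb_dim);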
829884  void  resume_ongoing_eviction () override  {
830885    if  (feature_evict_) {
831886      feature_evict_->resume ();
@@ -930,6 +985,192 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
930985    return  ret;
931986  }
932987
988+   // / Get embeddings together with their metaheader from the KV store.
989+   // /
990+   // / @param indices The 1D embedding index tensor; negative values are
991+   // / skipped
992+   // / @param weights_with_metaheader The 2D tensor whose rows (metaheader +
993+   // / embedding) correspond to the matching elements in <indices>; it is
994+   // / filled with the rows returned from the KV store
995+   // / @param count A single-element tensor containing the number of indices
996+   // / to be processed
997+   // /
998+   // / @return A SemiFuture that completes once all shards have been read
999+   folly::SemiFuture<std::vector<folly::Unit>>
1000+   get_kv_db_with_metaheader_async_impl (
1001+       const  at::Tensor& indices,
1002+       const  at::Tensor& weights_with_metaheader,
1003+       const  at::Tensor& count,
1004+       int64_t  width_offset = 0 ,
1005+       std::optional<int64_t > width_length = std::nullopt ) {
1006+     std::vector<folly::Future<folly::Unit>> futures;
1007+     auto  row_width = weights_with_metaheader.size (1 );
1008+     auto  copy_width = width_length.value_or (row_width);
1009+     CHECK_LE (row_width, block_size_);
1010+     CHECK_EQ (copy_width, row_width);
1011+     auto  shardid_to_indexes = shard_input (indices, count);
1012+ 
1013+     for  (auto  iter = shardid_to_indexes.begin ();
1014+          iter != shardid_to_indexes.end ();
1015+          iter++) {
1016+       const  auto  shard_id = iter->first ;
1017+       const  auto  indexes = iter->second ;
1018+       auto  f =
1019+           folly::via (executor_.get ())
1020+               .thenValue ([this ,
1021+                           shard_id,
1022+                           indexes,
1023+                           &indices,
1024+                           &weights_with_metaheader,
1025+                           width_offset,
1026+                           row_width](folly::Unit) {
1027+                 FBGEMM_DISPATCH_INTEGRAL_TYPES (
1028+                     indices.scalar_type (),
1029+                     " dram_kvstore_get_with_metaheader"  ,
1030+                     [this ,
1031+                      shard_id,
1032+                      indexes,
1033+                      &indices,
1034+                      &weights_with_metaheader,
1035+                      width_offset,
1036+                      row_width] {
1037+                       using  index_t  = scalar_t ;
1038+                       CHECK (indices.is_contiguous ());
1039+                       CHECK (weights_with_metaheader.is_contiguous ());
1040+                       CHECK_EQ (
1041+                           indices.size (0 ), weights_with_metaheader.size (0 ));
1042+                       auto  wlmap = kv_store_.by (shard_id).wlock ();
1043+                       auto  indices_data_ptr = indices.data_ptr <index_t >();
1044+                       auto  weights_data_ptr =
1045+                           weights_with_metaheader.data_ptr <weight_type>();
1046+                       {
1047+                         for  (auto  index_iter = indexes.begin ();
1048+                              index_iter != indexes.end ();
1049+                              index_iter++) {
1050+                           const  auto  weights_row_index = *index_iter;
1051+                           auto  weight_idx =
1052+                               int64_t (indices_data_ptr[weights_row_index]);
1053+                           const  auto  cached_iter = wlmap->find (weight_idx);
1054+                           //  Defensive programming: a missing key shouldn't
1055+                           //  occur normally; zero-fill the output row
1056+                           if  (cached_iter == wlmap->end ()) {
1057+                             std::memset (
1058+                                 &(weights_data_ptr
1059+                                       [weights_row_index * row_width]),
1060+                                 0 ,
1061+                                 row_width);
1062+                             continue ;
1063+                           }
1064+ 
1065+                           //  For the weight KVT, offset=0 and the whole row is
1066+                           //  read. For the optimizer, offset=dim(metaheader) +
1067+                           //  emb_dim, so only the optimizer part is read
1068+                           const  auto * ptr_offset_from_front =
1069+                               FixedBlockPool::ptr_offset_from_front<
1070+                                   weight_type>(
1071+                                   cached_iter->second , width_offset);
1072+                           std::copy (
1073+                               ptr_offset_from_front,
1074+                               ptr_offset_from_front + row_width,
1075+                               &(weights_data_ptr
1076+                                     [weights_row_index * row_width]));
1077+                         }
1078+                       }
1079+                     });
1080+               });
1081+       futures.push_back (std::move (f));
1082+     }
1083+     return  folly::collect (futures);
1084+   }
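To make the width_offset convention noted in the loop above concrete, here is a hedged sketch of the two read patterns, again assuming the layout [metaheader | embedding | optimizer state]; `ids`, `count`, `full_rows`, `opt_rows`, `emb_dim`, and `opt_dim` are hypothetical names:

    // Illustrative only; not part of this change.
    // Whole-row read (weight KVT): offset 0, output width == full block width.
    get_kv_db_with_metaheader_async_impl(ids, full_rows, count).wait();

    // Optimizer-only read: skip the metaheader and embedding columns. The
    // output tensor is opt_dim wide, and width_length must match that width
    // (see the CHECK_EQ(copy_width, row_width) above).
    const int64_t opt_offset =
        FixedBlockPool::get_metaheader_dim<weight_type>() + emb_dim;
    get_kv_db_with_metaheader_async_impl(
        ids, opt_rows, count, /*width_offset=*/opt_offset,
        /*width_length=*/opt_dim)
        .wait();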
1085+ 
1086+   // / Insert embeddings together with their metaheader into the KV store.
1087+   // / The underlying memory management is done through an F14FastMap;
1088+   // / key-value pairs are sharded into multiple shards to increase
1089+   // / parallelism.
1090+   // /
1091+   // / @param indices The 1D embedding index tensor; negative values are
1092+   // / skipped
1093+   // / @param weights_with_metaheader The 2D tensor whose rows (metaheader +
1094+   // / embedding) correspond to the matching elements in <indices>
1095+   // / @param count A single-element tensor containing the number of indices
1096+   // / to be processed
1097+   // /
1098+   // / @return A SemiFuture that completes once all shards have been written
1099+   folly::SemiFuture<std::vector<folly::Unit>>
1100+   set_kv_db_with_metaheader_async_impl (
1101+       const  at::Tensor& indices,
1102+       const  at::Tensor& weights_with_metaheader,
1103+       const  at::Tensor& count) {
1104+     std::vector<folly::Future<folly::Unit>> futures;
1105+     auto  shardid_to_indexes = shard_input (indices, count);
1106+     for  (auto  iter = shardid_to_indexes.begin ();
1107+          iter != shardid_to_indexes.end ();
1108+          iter++) {
1109+       const  auto  shard_id = iter->first ;
1110+       const  auto  indexes = iter->second ;
1111+       auto  f =
1112+           folly::via (executor_.get ())
1113+               .thenValue (
1114+                   [this , shard_id, indexes, &indices, &weights_with_metaheader](
1115+                       folly::Unit) {
1116+                     FBGEMM_DISPATCH_INTEGRAL_TYPES (
1117+                         indices.scalar_type (),
1118+                         " dram_kv_set_with_metaheader"  ,
1119+                         [this ,
1120+                          shard_id,
1121+                          indexes,
1122+                          &indices,
1123+                          &weights_with_metaheader] {
1124+                           using  index_t  = scalar_t ;
1125+                           CHECK (indices.is_contiguous ());
1126+                           CHECK (weights_with_metaheader.is_contiguous ());
1127+                           CHECK_EQ (
1128+                               indices.size (0 ), weights_with_metaheader.size (0 ));
1129+                           {
1130+                             auto  wlmap = kv_store_.by (shard_id).wlock ();
1131+                             auto * pool = kv_store_.pool_by (shard_id);
1132+                             int64_t  stride = weights_with_metaheader.size (1 );
1133+                             auto  indices_data_ptr = indices.data_ptr <index_t >();
1134+                             auto  weights_data_ptr =
1135+                                 weights_with_metaheader.data_ptr <weight_type>();
1136+                             for  (auto  index_iter = indexes.begin ();
1137+                                  index_iter != indexes.end ();
1138+                                  index_iter++) {
1139+                               const  auto & id_index = *index_iter;
1140+                               auto  id = int64_t (indices_data_ptr[id_index]);
1141+                               //  Defensive programming: an unused row
1142+                               //  shouldn't occur normally; skip it
1143+                               auto  used = FixedBlockPool::get_used (
1144+                                   weights_data_ptr + id_index * stride);
1145+                               if  (!used) {
1146+                                 continue ;
1147+                               }
1148+                               //  use mempool
1149+                               weight_type* block = nullptr ;
1150+                               //  First check if the key already exists
1151+                               auto  it = wlmap->find (id);
1152+                               if  (it != wlmap->end ()) {
1153+                                 block = it->second ;
1154+                               } else  {
1155+                                 //  Key doesn't exist, allocate new block and
1156+                                 //  insert.
1157+                                 block =
1158+                                     pool->template  allocate_t <weight_type>();
1159+                                 wlmap->insert ({id, block});
1160+                               }
1161+                               std::copy (
1162+                                   weights_data_ptr + id_index * stride,
1163+                                   weights_data_ptr + (id_index + 1 ) * stride,
1164+                                   block);
1165+                             }
1166+                           }
1167+                         });
1168+                   });
1169+       futures.push_back (std::move (f));
1170+     }
1171+     return  folly::collect (futures);
1172+   }
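The set path above relies on the metaheader that prefixes every block: FixedBlockPool::get_key() recovers the destination id and FixedBlockPool::get_used() filters out rows that were never populated. A rough sketch of the assumed block layout (field order and widths inside the metaheader are an assumption, not spelled out in this diff):

    //   [ metaheader (key, used, ...) | embedding | optimizer state ]
    //     <-- get_metaheader_dim() -->
    //
    // set_kv_db_with_metaheader_async_impl() copies each block verbatim,
    // while get_kv_db_with_metaheader_async_impl() may start at width_offset
    // to return only the optimizer tail.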
1173+ 
9331174  std::unique_ptr<folly::CPUThreadPoolExecutor> executor_;
9341175  //  background thread
9351176  folly::FunctionScheduler scheduler_;
@@ -942,6 +1183,7 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
9421183  std::atomic_bool is_eviction_ongoing_ = false ;
9431184  std::vector<std::unique_ptr<ssd::Initializer>> initializers_;
9441185  int64_t  elem_size_;
1186+   bool  backend_return_whole_row_;
9451187  std::vector<int64_t > sub_table_dims_;
9461188  std::vector<int64_t > sub_table_hash_cumsum_;
9471189  std::optional<c10::intrusive_ptr<FeatureEvictConfig>> feature_evict_config_;