diff --git a/src/raw/requests.rs b/src/raw/requests.rs
index f6f32f3d..e7825b36 100644
--- a/src/raw/requests.rs
+++ b/src/raw/requests.rs
@@ -10,7 +10,6 @@ use crate::proto::tikvpb::tikv_client::TikvClient;
 use crate::range_request;
 use crate::region::RegionWithLeader;
 use crate::request::plan::ResponseWithShard;
-use crate::request::Collect;
 use crate::request::CollectSingle;
 use crate::request::DefaultProcessor;
 use crate::request::KvRequest;
@@ -19,6 +18,7 @@ use crate::request::Process;
 use crate::request::RangeRequest;
 use crate::request::Shardable;
 use crate::request::SingleKey;
+use crate::request::{Batchable, Collect};
 use crate::shardable_key;
 use crate::shardable_keys;
 use crate::shardable_range;
@@ -35,12 +35,15 @@ use crate::Result;
 use crate::Value;
 use async_trait::async_trait;
 use futures::stream::BoxStream;
+use futures::{stream, StreamExt};
 use std::any::Any;
 use std::ops::Range;
 use std::sync::Arc;
 use std::time::Duration;
 use tonic::transport::Channel;
 
+const RAW_KV_REQUEST_BATCH_SIZE: u64 = 16 * 1024; // 16 KB
+
 pub fn new_raw_get_request(key: Vec<u8>, cf: Option<ColumnFamily>) -> kvrpcpb::RawGetRequest {
     let mut req = kvrpcpb::RawGetRequest::default();
     req.key = key;
@@ -188,6 +191,14 @@ impl KvRequest for kvrpcpb::RawBatchPutRequest {
     type Response = kvrpcpb::RawBatchPutResponse;
 }
 
+impl Batchable for kvrpcpb::RawBatchPutRequest {
+    type Item = (kvrpcpb::KvPair, u64);
+
+    fn item_size(item: &Self::Item) -> u64 {
+        (item.0.key.len() + item.0.value.len()) as u64
+    }
+}
+
 impl Shardable for kvrpcpb::RawBatchPutRequest {
     type Shard = Vec<(kvrpcpb::KvPair, u64)>;
 
@@ -204,6 +215,16 @@ impl Shardable for kvrpcpb::RawBatchPutRequest {
             .collect();
         kv_ttl.sort_by(|a, b| a.0.key.cmp(&b.0.key));
         region_stream_for_keys(kv_ttl.into_iter(), pd_client.clone())
+            .flat_map(|result| match result {
+                Ok((keys, region)) => stream::iter(kvrpcpb::RawBatchPutRequest::batches(
+                    keys,
+                    RAW_KV_REQUEST_BATCH_SIZE,
+                ))
+                .map(move |batch| Ok((batch, region.clone())))
+                .boxed(),
+                Err(e) => stream::iter(Err(e)).boxed(),
+            })
+            .boxed()
     }
 
     fn apply_shard(&mut self, shard: Self::Shard) {
@@ -212,6 +233,18 @@ impl Shardable for kvrpcpb::RawBatchPutRequest {
         self.ttls = ttls;
     }
 
+    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self
+    where
+        Self: Sized + Clone,
+    {
+        let mut cloned = Self::default();
+        cloned.context = self.context.clone();
+        cloned.cf = self.cf.clone();
+        cloned.for_cas = self.for_cas;
+        cloned.apply_shard(shard);
+        cloned
+    }
+
     fn apply_store(&mut self, store: &RegionStore) -> Result<()> {
         self.set_leader(&store.region_with_leader)
     }
@@ -257,7 +290,56 @@ impl KvRequest for kvrpcpb::RawBatchDeleteRequest {
     type Response = kvrpcpb::RawBatchDeleteResponse;
 }
 
-shardable_keys!(kvrpcpb::RawBatchDeleteRequest);
+impl Batchable for kvrpcpb::RawBatchDeleteRequest {
+    type Item = Vec<u8>;
+
+    fn item_size(item: &Self::Item) -> u64 {
+        item.len() as u64
+    }
+}
+
+impl Shardable for kvrpcpb::RawBatchDeleteRequest {
+    type Shard = Vec<Vec<u8>>;
+
+    fn shards(
+        &self,
+        pd_client: &Arc<impl PdClient>,
+    ) -> BoxStream<'static, Result<(Self::Shard, RegionWithLeader)>> {
+        let mut keys = self.keys.clone();
+        keys.sort();
+        region_stream_for_keys(keys.into_iter(), pd_client.clone())
+            .flat_map(|result| match result {
+                Ok((keys, region)) => stream::iter(kvrpcpb::RawBatchDeleteRequest::batches(
+                    keys,
+                    RAW_KV_REQUEST_BATCH_SIZE,
+                ))
+                .map(move |batch| Ok((batch, region.clone())))
+                .boxed(),
+                Err(e) => stream::iter(Err(e)).boxed(),
+            })
+            .boxed()
+    }
+
+    fn apply_shard(&mut self, shard: Self::Shard) {
+        self.keys = shard;
+    }
+
+    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self
+    where
+        Self: Sized + Clone,
+    {
+        let mut cloned = Self::default();
+        cloned.context = self.context.clone();
+        cloned.cf = self.cf.clone();
+        cloned.for_cas = self.for_cas;
+        cloned.apply_shard(shard);
+        cloned
+    }
+
+    fn apply_store(&mut self, store: &RegionStore) -> Result<()> {
+        self.set_leader(&store.region_with_leader)
+    }
+}
 
 pub fn new_raw_delete_range_request(
     start_key: Vec<u8>,
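Both `shards` implementations above delegate the actual grouping to `Batchable::batches`, which this diff only imports. As a reading aid, here is a minimal sketch of the size-bounded grouping such a helper performs; the free function, its name, and the closure parameter are invented for this note and it is not the crate's exact default method.

// Illustrative only: greedy, size-bounded batching in the spirit of `Batchable::batches`.
// The real trait derives the per-item size from `Self::item_size`; here it is a closure.
fn batches_by_size<T>(
    items: Vec<T>,
    batch_size: u64,
    item_size: impl Fn(&T) -> u64,
) -> Vec<Vec<T>> {
    let mut batches = Vec::new();
    let mut batch = Vec::new();
    let mut size = 0u64;
    for item in items {
        let next = item_size(&item);
        // Close the current batch once adding another item would reach the budget.
        if size + next >= batch_size && !batch.is_empty() {
            batches.push(std::mem::take(&mut batch));
            size = 0;
        }
        size += next;
        batch.push(item);
    }
    if !batch.is_empty() {
        batches.push(batch);
    }
    batches
}

With a 16 KiB budget and the roughly 1 KiB pairs used by the new integration test below, each `RawBatchPutRequest` sent to a region therefore carries on the order of 15 pairs per RPC instead of the region's entire payload.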
diff --git a/src/request/plan.rs b/src/request/plan.rs
index be915f77..c4a1a368 100644
--- a/src/request/plan.rs
+++ b/src/request/plan.rs
@@ -117,11 +117,10 @@ where
     ) -> Result<<Self as Plan>::Result> {
         let shards = current_plan.shards(&pd_client).collect::<Vec<_>>().await;
         debug!("single_plan_handler, shards: {}", shards.len());
-        let mut handles = Vec::new();
+        let mut handles = Vec::with_capacity(shards.len());
         for shard in shards {
            let (shard, region) = shard?;
-            let mut clone = current_plan.clone();
-            clone.apply_shard(shard);
+            let clone = current_plan.clone_then_apply_shard(shard);
            let handle = tokio::spawn(Self::single_shard_handler(
                pd_client.clone(),
                clone,
diff --git a/src/request/shard.rs b/src/request/shard.rs
index a68f446a..1bac69e4 100644
--- a/src/request/shard.rs
+++ b/src/request/shard.rs
@@ -48,6 +48,16 @@ pub trait Shardable {
 
     fn apply_shard(&mut self, shard: Self::Shard);
 
+    /// Implementations can skip cloning fields that will be overwritten by `apply_shard`.
+    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self
+    where
+        Self: Sized + Clone,
+    {
+        let mut cloned = self.clone();
+        cloned.apply_shard(shard);
+        cloned
+    }
+
     fn apply_store(&mut self, store: &RegionStore) -> Result<()>;
 }
 
@@ -103,6 +113,16 @@ impl<Req: KvRequest + Shardable> Shardable for Dispatch<Req> {
         self.request.apply_shard(shard);
     }
 
+    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self
+    where
+        Self: Sized + Clone,
+    {
+        Dispatch {
+            request: self.request.clone_then_apply_shard(shard),
+            kv_client: self.kv_client.clone(),
+        }
+    }
+
     fn apply_store(&mut self, store: &RegionStore) -> Result<()> {
         self.kv_client = Some(store.client.clone());
         self.request.apply_store(store)
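The default `clone_then_apply_shard` added above clones the whole request, payload included, and then immediately overwrites that payload with the shard; the overrides in `src/raw/requests.rs` avoid the wasted copy by starting from `Self::default()` and cloning only the cheap metadata fields. A simplified, self-contained sketch of that pattern (the trait and types here are stand-ins invented for this note, not the crate's real `Shardable`):

// Stand-in types for illustration; the real trait lives in src/request/shard.rs.
#[derive(Clone, Default)]
struct FakeBatchPut {
    cf: Option<String>,             // small metadata, cheap to copy
    pairs: Vec<(Vec<u8>, Vec<u8>)>, // large payload, replaced by `apply_shard`
}

trait ShardableLike: Clone + Default {
    type Shard;

    fn apply_shard(&mut self, shard: Self::Shard);

    // Default behaviour: clone everything, then overwrite the payload.
    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self {
        let mut cloned = self.clone();
        cloned.apply_shard(shard);
        cloned
    }
}

impl ShardableLike for FakeBatchPut {
    type Shard = Vec<(Vec<u8>, Vec<u8>)>;

    fn apply_shard(&mut self, shard: Self::Shard) {
        self.pairs = shard;
    }

    // Override: never clone the original `pairs`, which would be discarded anyway;
    // copy only the fields that survive sharding.
    fn clone_then_apply_shard(&self, shard: Self::Shard) -> Self {
        let mut cloned = Self::default();
        cloned.cf = self.cf.clone();
        cloned.apply_shard(shard);
        cloned
    }
}

This is also why `single_plan_handler` in plan.rs now calls `clone_then_apply_shard` instead of `clone()` followed by `apply_shard`: for the raw batch requests, the original payload is no longer copied once per shard.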
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 73b43459..411e8810 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -871,6 +871,53 @@ async fn raw_write_million() -> Result<()> {
     Ok(())
 }
 
+/// Tests raw batch put with a large payload.
+#[tokio::test]
+#[serial]
+async fn raw_large_batch_put() -> Result<()> {
+    const TARGET_SIZE_MB: usize = 100;
+    const KEY_SIZE: usize = 32;
+    const VALUE_SIZE: usize = 1024;
+
+    let pair_size = KEY_SIZE + VALUE_SIZE;
+    let target_size_bytes = TARGET_SIZE_MB * 1024 * 1024;
+    let num_pairs = target_size_bytes / pair_size;
+    let mut pairs = Vec::with_capacity(num_pairs);
+    for i in 0..num_pairs {
+        // Generate key: "bench_key_" + zero-padded number
+        let key = format!("bench_key_{:010}", i);
+
+        // Generate value: repeat pattern to reach VALUE_SIZE
+        let pattern = format!("value_{}", i % 1000);
+        let repeat_count = VALUE_SIZE.div_ceil(pattern.len());
+        let value = pattern.repeat(repeat_count);
+
+        pairs.push(KvPair::from((key, value)));
+    }
+
+    init().await?;
+    let client =
+        RawClient::new_with_config(pd_addrs(), Config::default().with_default_keyspace()).await?;
+
+    client.batch_put(pairs.clone()).await?;
+
+    let keys = pairs.iter().map(|pair| pair.0.clone()).collect::<Vec<_>>();
+    // split into multiple batch_get to avoid response too large error
+    const BATCH_SIZE: usize = 1000;
+    let mut got = Vec::with_capacity(num_pairs);
+    for chunk in keys.chunks(BATCH_SIZE) {
+        let mut partial = client.batch_get(chunk.to_vec()).await?;
+        got.append(&mut partial);
+    }
+    assert_eq!(got, pairs);
+
+    client.batch_delete(keys.clone()).await?;
+    let res = client.batch_get(keys).await?;
+    assert!(res.is_empty());
+
+    Ok(())
+}
+
 /// Tests raw ttl API.
 #[tokio::test]
 #[serial]