diff --git a/integration_tests/2pc_test.go b/integration_tests/2pc_test.go index 9369bfb666..1628b6a055 100644 --- a/integration_tests/2pc_test.go +++ b/integration_tests/2pc_test.go @@ -2512,3 +2512,139 @@ func (s *testCommitterSuite) TestKillSignal() { err = txn.Commit(context.Background()) s.ErrorContains(err, "query interrupted") } +<<<<<<< HEAD +======= + +func (s *testCommitterSuite) TestUninterruptibleAction() { + s.Run("Cleanup", func() { + var killed uint32 = 0 + txn := s.begin() + txn.SetVars(kv.NewVariables(&killed)) + err := txn.Set([]byte("k1"), []byte("v1")) + s.NoError(err) + committer, err := txn.NewCommitter(0) + s.NoError(err) + err = committer.PrewriteAllMutations(context.Background()) + s.NoError(err) + atomic.StoreUint32(&killed, 2) + s.NoError(committer.CleanupMutations(context.Background())) + }) + s.Run("PessimisticRollback", func() { + var killed uint32 = 0 + txn := s.begin() + txn.SetVars(kv.NewVariables(&killed)) + txn.SetPessimistic(true) + err := txn.LockKeys(context.Background(), kv.NewLockCtx(txn.StartTS(), kv.LockNoWait, time.Now()), []byte("k2")) + s.NoError(err) + atomic.StoreUint32(&killed, 2) + committer, err := txn.NewCommitter(0) + s.NoError(err) + s.NoError(committer.PessimisticRollbackMutations(context.Background(), committer.GetMutations())) + }) + s.Run("Commit", func() { + var killed uint32 = 0 + txn := s.begin() + txn.SetVars(kv.NewVariables(&killed)) + err := txn.Set([]byte("k1"), []byte("v1")) + s.NoError(err) + committer, err := txn.NewCommitter(0) + s.NoError(err) + err = committer.PrewriteAllMutations(context.Background()) + s.NoError(err) + atomic.StoreUint32(&killed, 2) + commitTS, err := s.store.GetOracle().GetTimestamp(context.Background(), &oracle.Option{}) + s.NoError(err) + committer.SetCommitTS(commitTS) + s.NoError(committer.CommitMutations(context.Background())) + }) +} + +func (s *testCommitterSuite) Test2PCLifecycleHooks() { + reachedPre := atomic.Bool{} + reachedPost := atomic.Bool{} + + var wg 
sync.WaitGroup + + t1 := s.begin() + t1.SetBackgroundGoroutineLifecycleHooks(transaction.LifecycleHooks{ + Pre: func() { + wg.Add(1) + + reachedPre.Store(true) + }, + Post: func() { + s.Equal(reachedPre.Load(), true) + reachedPost.Store(true) + + wg.Done() + }, + }) + t1.Set([]byte("a"), []byte("a")) + t1.Set([]byte("z"), []byte("z")) + s.Nil(t1.Commit(context.Background())) + + s.Equal(reachedPre.Load(), true) + s.Equal(reachedPost.Load(), false) + wg.Wait() + s.Equal(reachedPost.Load(), true) +} + +func (s *testCommitterSuite) Test2PCCleanupLifecycleHooks() { + reachedPre := atomic.Bool{} + reachedPost := atomic.Bool{} + + var wg sync.WaitGroup + + t1 := s.begin() + t1.SetBackgroundGoroutineLifecycleHooks(transaction.LifecycleHooks{ + Pre: func() { + wg.Add(1) + + reachedPre.Store(true) + }, + Post: func() { + s.Equal(reachedPre.Load(), true) + reachedPost.Store(true) + + wg.Done() + }, + }) + t1.Set([]byte("a"), []byte("a")) + t1.Set([]byte("z"), []byte("z")) + committer, err := t1.NewCommitter(0) + s.Nil(err) + + committer.CleanupWithoutWait(context.Background()) + + s.Equal(reachedPre.Load(), true) + s.Equal(reachedPost.Load(), false) + wg.Wait() + s.Equal(reachedPost.Load(), true) +} + +func (s *testCommitterSuite) TestFailWithUndeterminedResult() { + txn := s.begin() + s.Nil(txn.Set([]byte("key"), []byte("value"))) + // prewrite fail for an undetermined result in commit should retry + s.Nil(failpoint.Enable( + "tikvclient/rpcPrewriteResult", + // prewrite fail, but retry success + `1*return("undeterminedResult")->return("")`, + )) + err := txn.Commit(context.Background()) + s.Nil(err) + + // commit primary fail for an undetermined result should return undetermined error + txn = s.begin() + s.Nil(txn.Set([]byte("key"), []byte("value"))) + // prewrite fail for an undetermined result in commit should retry + s.Nil(failpoint.Enable( + "tikvclient/rpcCommitResult", + // prewrite success, but the first commit fail + `1*return("undeterminedResult")->return("")`, + 
)) + err = txn.Commit(context.Background()) + s.NotNil(err) + s.True(tikverr.IsErrorUndetermined(err)) +} +>>>>>>> b7e019d3 (txnkv: prevent some actions from being interrupted by kill (#1665)) diff --git a/internal/locate/region_request.go b/internal/locate/region_request.go index 1bd97faa14..a053ea4d6b 100644 --- a/internal/locate/region_request.go +++ b/internal/locate/region_request.go @@ -1502,6 +1502,516 @@ func IsFakeRegionError(err *errorpb.Error) bool { const slowLogSendReqTime = 100 * time.Millisecond +<<<<<<< HEAD +======= +// sendReqArgs defines the input arguments of the send request. +type sendReqArgs struct { + bo *retry.Backoffer + req *tikvrpc.Request + regionID RegionVerID + timeout time.Duration + et tikvrpc.EndpointType + opts []StoreSelectorOption +} + +// sendReqState represents the state of sending request with retry, which allows us to construct a state and start to +// retry from that state. +type sendReqState struct { + *RegionRequestSender + + // args holds the input arguments of the send request. + args sendReqArgs + + // vars maintains the local variables used in the retry loop. + vars struct { + rpcCtx *RPCContext + resp *tikvrpc.Response + regionErr *errorpb.Error + err error + msg string + sendTimes int + } + + invariants reqInvariants +} + +// reqInvariants holds the input state of the request. +// If the tikvrpc.Request is changed during the retries or other operations. +// the reqInvariants can tell the initial state. +type reqInvariants struct { + staleRead bool +} + +// next encapsulates one iteration of the retry loop. calling `next` will handle send error (s.vars.err) or region error +// (s.vars.regionErr) if one of them exists. When the error is retriable, `next` then constructs a new RPCContext and +// sends the request again. `next` returns true if the retry loop should stop, either because the request is done or +// exhausted (cannot complete by retrying). 
+func (s *sendReqState) next() (done bool) { + bo, req := s.args.bo, s.args.req + + // check whether the session/query is killed during the Next() + if req.IsInterruptible() { + if err := bo.CheckKilled(); err != nil { + s.vars.resp, s.vars.err = nil, err + return true + } + } + + // handle send error + if s.vars.err != nil { + if e := s.onSendFail(bo, s.vars.rpcCtx, req, s.vars.err); e != nil { + s.vars.rpcCtx, s.vars.resp = nil, nil + s.vars.msg = fmt.Sprintf("failed to handle send error: %v", s.vars.err) + return true + } + s.vars.err = nil + } + + // handle region error + if s.vars.regionErr != nil { + retry, err := s.onRegionError(bo, s.vars.rpcCtx, req, s.vars.regionErr) + if err != nil { + s.vars.rpcCtx, s.vars.resp = nil, nil + s.vars.err = err + s.vars.msg = fmt.Sprintf("failed to handle region error: %v", err) + return true + } + if !retry { + s.vars.msg = fmt.Sprintf("met unretriable region error: %T", s.vars.regionErr) + return true + } + s.vars.regionErr = nil + } + + s.vars.rpcCtx, s.vars.resp = nil, nil + if !req.IsRetryRequest && s.vars.sendTimes > 0 { + req.IsRetryRequest = true + } + + s.vars.rpcCtx, s.vars.err = s.getRPCContext(bo, req, s.args.regionID, s.args.et, s.args.opts...) + if s.vars.err != nil { + return true + } + + if _, err := util.EvalFailpoint("invalidCacheAndRetry"); err == nil { + // cooperate with tikvclient/setGcResolveMaxBackoff + if c := bo.GetCtx().Value("injectedBackoff"); c != nil { + s.vars.regionErr = &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}} + s.vars.resp, s.vars.err = tikvrpc.GenRegionErrorResp(req, s.vars.regionErr) + return true + } + } + + if s.vars.rpcCtx == nil { + // TODO(youjiali1995): remove it when using the replica selector for all requests. + // If the region is not found in cache, it must be out + // of date and already be cleaned up. We can skip the + // RPC by returning RegionError directly. 
+ + // TODO: Change the returned error to something like "region missing in cache", + // and handle this error like EpochNotMatch, which means to re-split the request and retry. + if s.replicaSelector != nil { + if s.vars.err = s.replicaSelector.backoffOnNoCandidate(bo); s.vars.err != nil { + return true + } + } + s.vars.regionErr = &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}} + s.vars.resp, s.vars.err = tikvrpc.GenRegionErrorResp(req, s.vars.regionErr) + s.vars.msg = "throwing pseudo region error due to no replica available" + return true + } + + // should reset the access location after shifting to the next store. + s.setReqAccessLocation(req) + + logutil.Eventf(bo.GetCtx(), "send %s request to region %d at %s", req.Type, s.args.regionID.id, s.vars.rpcCtx.Addr) + s.storeAddr = s.vars.rpcCtx.Addr + + req.Context.ClusterId = s.vars.rpcCtx.ClusterID + if req.InputRequestSource != "" && s.replicaSelector != nil { + patchRequestSource(req, s.replicaSelector.replicaType()) + } + // RPCClient.SendRequest will attach `req.Context` thus skip attaching here to reduce overhead. + if s.vars.err = tikvrpc.SetContextNoAttach(req, s.vars.rpcCtx.Meta, s.vars.rpcCtx.Peer); s.vars.err != nil { + return true + } + if s.replicaSelector != nil { + if s.vars.err = s.replicaSelector.backoffOnRetry(s.vars.rpcCtx.Store, bo); s.vars.err != nil { + return true + } + } + + if _, err := util.EvalFailpoint("beforeSendReqToRegion"); err == nil { + if hook := bo.GetCtx().Value("sendReqToRegionHook"); hook != nil { + h := hook.(func(*tikvrpc.Request)) + h(req) + } + } + + // judge the store limit switch. + if limit := kv.StoreLimit.Load(); limit > 0 { + if s.vars.err = s.getStoreToken(s.vars.rpcCtx.Store, limit); s.vars.err != nil { + return true + } + defer s.releaseStoreToken(s.vars.rpcCtx.Store) + } + + canceled := s.send() + s.vars.sendTimes++ + + if s.vars.err != nil { + // Because in rpc logic, context.Cancel() will be transferred to rpcContext.Cancel error. 
For rpcContext cancel, + // we need to retry the request. But for context cancel active, for example, limitExec gets the required rows, + // we shouldn't retry the request, it will go to backoff and hang in retry logic. + if canceled { + return true + } + if val, e := util.EvalFailpoint("noRetryOnRpcError"); e == nil && val.(bool) { + return true + } + // need to handle send error + return false + } + + if val, err := util.EvalFailpoint("mockRetrySendReqToRegion"); err == nil && val.(bool) { + // force retry + return false + } + + s.vars.regionErr, s.vars.err = s.vars.resp.GetRegionError() + if s.vars.err != nil { + s.vars.rpcCtx, s.vars.resp = nil, nil + return true + } else if s.vars.regionErr != nil { + // need to handle region error + return false + } + + if s.replicaSelector != nil { + s.replicaSelector.onSendSuccess(req) + } + + return true +} + +func (s *sendReqState) send() (canceled bool) { + bo, req := s.args.bo, s.args.req + rpcCtx := s.vars.rpcCtx + ctx := bo.GetCtx() + if rawHook := ctx.Value(RPCCancellerCtxKey{}); rawHook != nil { + var cancel context.CancelFunc + ctx, cancel = rawHook.(*RPCCanceller).WithCancel(ctx) + defer cancel() + } + + // sendToAddr is the first target address that will receive the request. If proxy is used, sendToAddr will point to + // the proxy that will forward the request to the final target. + sendToAddr := rpcCtx.Addr + if rpcCtx.ProxyStore == nil { + req.ForwardedHost = "" + } else { + req.ForwardedHost = rpcCtx.Addr + sendToAddr = rpcCtx.ProxyAddr + } + + // Count the replica number as the RU cost factor. 
+ req.ReplicaNumber = 1 + if rpcCtx.Meta != nil && len(rpcCtx.Meta.GetPeers()) > 0 { + req.ReplicaNumber = 0 + for _, peer := range rpcCtx.Meta.GetPeers() { + role := peer.GetRole() + if role == metapb.PeerRole_Voter || role == metapb.PeerRole_Learner { + req.ReplicaNumber++ + } + } + } + + var sessionID uint64 + if v := bo.GetCtx().Value(util.SessionID); v != nil { + sessionID = v.(uint64) + } + + injectFailOnSend := false + if val, e := util.EvalFailpoint("rpcFailOnSend"); e == nil { + inject := true + // Optional filters + if s, ok := val.(string); ok { + if s == "greengc" && !req.IsGreenGCRequest() { + inject = false + } else if s == "write" && !req.IsTxnWriteRequest() { + inject = false + } + } else if sessionID == 0 { + inject = false + } + + if inject { + logutil.Logger(ctx).Info( + "[failpoint] injected RPC error on send", zap.Stringer("type", req.Type), + zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("ctx", &req.Context), + ) + injectFailOnSend = true + s.vars.err = errors.New("injected RPC error on send") + } + } + + if !injectFailOnSend { + start := time.Now() + s.vars.resp, s.vars.err = s.client.SendRequest(ctx, sendToAddr, req, s.args.timeout) + rpcDuration := time.Since(start) + if s.replicaSelector != nil { + recordAttemptedTime(s.replicaSelector, rpcDuration) + } + + var execDetails *util.ExecDetails + if stmtExec := ctx.Value(util.ExecDetailsKey); stmtExec != nil { + // Assign (not `:=`) so the collector calls below receive the statement's ExecDetails instead of nil. + execDetails = stmtExec.(*util.ExecDetails) + atomic.AddInt64(&execDetails.WaitKVRespDuration, int64(rpcDuration)) + } + collector := networkCollector{ + staleRead: s.invariants.staleRead, + } + collector.onReq(req, execDetails) + collector.onResp(req, s.vars.resp, execDetails) + + // Record timecost of external requests on related Store when `ReplicaReadMode == "PreferLeader"`. 
+ if rpcCtx.Store != nil && req.ReplicaReadType == kv.ReplicaReadPreferLeader && !util.IsInternalRequest(req.RequestSource) { + rpcCtx.Store.healthStatus.recordClientSideSlowScoreStat(rpcDuration) + } + if s.Stats != nil { + s.Stats.RecordRPCRuntimeStats(req.Type, rpcDuration) + if val, fpErr := util.EvalFailpoint("tikvStoreRespResult"); fpErr == nil { + if val.(bool) { + if req.Type == tikvrpc.CmdCop && bo.GetTotalSleep() == 0 { + s.vars.resp, s.vars.err = &tikvrpc.Response{ + Resp: &coprocessor.Response{RegionError: &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}}}, + }, nil + return + } + } + } + } + + if val, e := util.EvalFailpoint("rpcFailOnRecv"); e == nil { + inject := true + // Optional filters + if s, ok := val.(string); ok { + if s == "greengc" && !req.IsGreenGCRequest() { + inject = false + } else if s == "write" && !req.IsTxnWriteRequest() { + inject = false + } + } else if sessionID == 0 { + inject = false + } + + if inject { + logutil.Logger(ctx).Info( + "[failpoint] injected RPC error on recv", zap.Stringer("type", req.Type), + zap.Stringer("req", req.Req.(fmt.Stringer)), zap.Stringer("ctx", &req.Context), + zap.Error(s.vars.err), zap.String("extra response info", fetchRespInfo(s.vars.resp)), + ) + s.vars.resp, s.vars.err = nil, errors.New("injected RPC error on recv") + } + } + + if val, e := util.EvalFailpoint("rpcContextCancelErr"); e == nil { + if val.(bool) { + ctx1, cancel := context.WithCancel(context.Background()) + cancel() + <-ctx1.Done() + ctx = ctx1 + s.vars.resp, s.vars.err = nil, ctx.Err() + } + } + + if _, e := util.EvalFailpoint("onRPCFinishedHook"); e == nil { + if hook := bo.GetCtx().Value("onRPCFinishedHook"); hook != nil { + h := hook.(func(*tikvrpc.Request, *tikvrpc.Response, error) (*tikvrpc.Response, error)) + s.vars.resp, s.vars.err = h(req, s.vars.resp, s.vars.err) + } + } + } + + if rpcCtx.ProxyStore != nil { + fromStore := strconv.FormatUint(rpcCtx.ProxyStore.storeID, 10) + toStore := 
strconv.FormatUint(rpcCtx.Store.storeID, 10) + result := "ok" + if s.vars.err != nil { + result = "fail" + } + metrics.TiKVForwardRequestCounter.WithLabelValues(fromStore, toStore, req.Type.String(), result).Inc() + } + + if err := s.vars.err; err != nil { + if isRPCError(err) { + s.rpcError = err + } + if s.Stats != nil { + errStr := getErrMsg(err) + s.Stats.RecordRPCErrorStats(errStr) + s.recordRPCAccessInfo(req, s.vars.rpcCtx, errStr) + } + if canceled = ctx.Err() != nil && errors.Cause(ctx.Err()) == context.Canceled; canceled { + metrics.TiKVRPCErrorCounter.WithLabelValues("context-canceled", storeIDLabel(s.vars.rpcCtx)).Inc() + } + } + return +} + +// initForAsyncRequest initializes the state for an async request. It should be called once before the first `next`. +func (s *sendReqState) initForAsyncRequest() (ok bool) { + bo, req := s.args.bo, s.args.req + + s.vars.rpcCtx, s.vars.err = s.getRPCContext(bo, req, s.args.regionID, s.args.et, s.args.opts...) + if s.vars.err != nil { + return false + } + if s.vars.rpcCtx == nil { + s.vars.regionErr = &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}} + s.vars.resp, s.vars.err = tikvrpc.GenRegionErrorResp(req, s.vars.regionErr) + s.vars.msg = "throwing pseudo region error due to no replica available" + return false + } + + s.storeAddr = s.vars.rpcCtx.Addr + + // set access location based on source and target "zone" label. + s.setReqAccessLocation(req) + req.Context.ClusterId = s.vars.rpcCtx.ClusterID + if req.InputRequestSource != "" && s.replicaSelector != nil { + patchRequestSource(req, s.replicaSelector.replicaType()) + } + if s.vars.err = tikvrpc.SetContextNoAttach(req, s.vars.rpcCtx.Meta, s.vars.rpcCtx.Peer); s.vars.err != nil { + return false + } + + // Count the replica number as the RU cost factor. 
+ req.ReplicaNumber = 1 + if s.vars.rpcCtx.Meta != nil && len(s.vars.rpcCtx.Meta.GetPeers()) > 0 { + req.ReplicaNumber = 0 + for _, peer := range s.vars.rpcCtx.Meta.GetPeers() { + role := peer.GetRole() + if role == metapb.PeerRole_Voter || role == metapb.PeerRole_Learner { + req.ReplicaNumber++ + } + } + } + + return true +} + +// setReqAccessLocation sets the AccessLocation value of the kv request by comparing the +// globally configured zone label with the target store's "zone" label. It leaves the request +// untouched when there is no replica selector or the selector has no target. +func (s *sendReqState) setReqAccessLocation(req *tikvrpc.Request) { + // Set the access location based on the source and target "zone" labels. + if s.replicaSelector != nil && s.replicaSelector.target != nil { + selfZoneLabel := config.GetGlobalConfig().ZoneLabel + targetZoneLabel, _ := s.replicaSelector.target.store.GetLabelValue("zone") + // If either "zone" label is "", we don't know whether the request involves cross-AZ traffic. + if selfZoneLabel == "" || targetZoneLabel == "" { + req.AccessLocation = kv.AccessUnknown + } else if selfZoneLabel == targetZoneLabel { + req.AccessLocation = kv.AccessLocalZone + } else { + req.AccessLocation = kv.AccessCrossZone + } + } +} + +// handleAsyncResponse handles the response of an async request. 
+func (s *sendReqState) handleAsyncResponse(start time.Time, canceled bool, resp *tikvrpc.Response, err error, execDetails *util.ExecDetails, cancels ...context.CancelFunc) (done bool) { + if len(cancels) > 0 { + defer func() { + for i := len(cancels) - 1; i >= 0; i-- { + cancels[i]() + } + }() + } + s.vars.resp, s.vars.err = resp, err + req := s.args.req + rpcDuration := time.Since(start) + if s.replicaSelector != nil { + recordAttemptedTime(s.replicaSelector, rpcDuration) + } + if s.Stats != nil { + s.Stats.RecordRPCRuntimeStats(req.Type, rpcDuration) + } + if execDetails != nil { + atomic.AddInt64(&execDetails.WaitKVRespDuration, int64(rpcDuration)) + } + collector := networkCollector{ + staleRead: s.invariants.staleRead, + } + collector.onReq(req, execDetails) + collector.onResp(req, resp, execDetails) + + if s.vars.rpcCtx.Store != nil && req.ReplicaReadType == kv.ReplicaReadPreferLeader && !util.IsInternalRequest(req.RequestSource) { + s.vars.rpcCtx.Store.healthStatus.recordClientSideSlowScoreStat(rpcDuration) + } + if s.vars.rpcCtx.ProxyStore != nil { + fromStore := strconv.FormatUint(s.vars.rpcCtx.ProxyStore.storeID, 10) + toStore := strconv.FormatUint(s.vars.rpcCtx.Store.storeID, 10) + result := "ok" + if s.vars.err != nil { + result = "fail" + } + metrics.TiKVForwardRequestCounter.WithLabelValues(fromStore, toStore, req.Type.String(), result).Inc() + } + + if err := s.vars.err; err != nil { + if isRPCError(err) { + s.rpcError = err + metrics.AsyncSendReqCounterWithRPCError.Inc() + } else { + metrics.AsyncSendReqCounterWithSendError.Inc() + } + if s.Stats != nil { + errStr := getErrMsg(err) + s.Stats.RecordRPCErrorStats(errStr) + s.recordRPCAccessInfo(req, s.vars.rpcCtx, errStr) + } + if canceled { + metrics.TiKVRPCErrorCounter.WithLabelValues("context-canceled", storeIDLabel(s.vars.rpcCtx)).Inc() + } + return canceled + } + + s.vars.regionErr, s.vars.err = s.vars.resp.GetRegionError() + if s.vars.err != nil { + s.vars.rpcCtx, s.vars.resp = nil, nil + 
metrics.AsyncSendReqCounterWithOtherError.Inc() + return true + } else if s.vars.regionErr != nil { + // need to handle region error + metrics.AsyncSendReqCounterWithRegionError.Inc() + return false + } + + if s.replicaSelector != nil { + s.replicaSelector.onSendSuccess(req) + } + + metrics.AsyncSendReqCounterWithOK.Inc() + return true +} + +// toResponseExt converts the state to a ResponseExt . +func (s *sendReqState) toResponseExt() (*tikvrpc.ResponseExt, error) { + if s.vars.err != nil { + return nil, s.vars.err + } + if s.vars.resp == nil { + return nil, errors.New("invalid state: response is nil") + } + resp := &tikvrpc.ResponseExt{Response: *s.vars.resp} + if s.vars.rpcCtx != nil { + resp.Addr = s.vars.rpcCtx.Addr + } + return resp, nil +} + +>>>>>>> b7e019d3 (txnkv: prevent some actions from being interrupted by kill (#1665)) // SendReqCtx sends a request to tikv server and return response and RPCCtx of this RPC. func (s *RegionRequestSender) SendReqCtx( bo *retry.Backoffer, diff --git a/tikvrpc/tikvrpc.go b/tikvrpc/tikvrpc.go index 9d8cab64b3..881b329330 100644 --- a/tikvrpc/tikvrpc.go +++ b/tikvrpc/tikvrpc.go @@ -326,6 +326,16 @@ func (req *Request) IsDebugReq() bool { return false } +// IsInterruptible checks if the request can be interrupted when the query is killed. +func (req *Request) IsInterruptible() bool { + switch req.Type { + case CmdPessimisticRollback, CmdBatchRollback, CmdCommit: + return false + default: + return true + } +} + // Get returns GetRequest in request. 
func (req *Request) Get() *kvrpcpb.GetRequest { return req.Req.(*kvrpcpb.GetRequest) diff --git a/txnkv/transaction/2pc.go b/txnkv/transaction/2pc.go index b7471f87a2..04fd6297c7 100644 --- a/txnkv/transaction/2pc.go +++ b/txnkv/transaction/2pc.go @@ -75,6 +75,7 @@ const slowRequestThreshold = time.Minute type twoPhaseCommitAction interface { handleSingleBatch(*twoPhaseCommitter, *retry.Backoffer, batchMutations) error tiKVTxnRegionsNumHistogram() prometheus.Observer + isInterruptible() bool String() string } @@ -1055,9 +1056,10 @@ func (c *twoPhaseCommitter) doActionOnBatches( // killSignal should never be nil for TiDB if c.txn != nil && c.txn.vars != nil && c.txn.vars.Killed != nil { // Do not reset the killed flag here. Let the upper layer reset the flag. - // Before it resets, any request is considered valid to be killed. + // Before it resets, any request is considered valid to be killed if the + // corresponding action is interruptible. status := atomic.LoadUint32(c.txn.vars.Killed) - if status != 0 { + if status != 0 && action.isInterruptible() { logutil.BgLogger().Info( "query is killed", zap.Uint32( "signal", diff --git a/txnkv/transaction/cleanup.go b/txnkv/transaction/cleanup.go index b080319906..92e87f75cc 100644 --- a/txnkv/transaction/cleanup.go +++ b/txnkv/transaction/cleanup.go @@ -103,6 +103,10 @@ func (action actionCleanup) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Ba return nil } +func (actionCleanup) isInterruptible() bool { + return false +} + func (c *twoPhaseCommitter) cleanupMutations(bo *retry.Backoffer, mutations CommitterMutations) error { return c.doActionOnMutations(bo, actionCleanup{isInternal: c.txn.isInternal()}, mutations) } diff --git a/txnkv/transaction/commit.go b/txnkv/transaction/commit.go index 3818e39783..46a0e974e6 100644 --- a/txnkv/transaction/commit.go +++ b/txnkv/transaction/commit.go @@ -221,6 +221,10 @@ func (action actionCommit) handleSingleBatch(c *twoPhaseCommitter, bo *retry.Bac return nil } +func 
(actionCommit) isInterruptible() bool { + return false +} + func (c *twoPhaseCommitter) commitMutations(bo *retry.Backoffer, mutations CommitterMutations) error { if span := opentracing.SpanFromContext(bo.GetCtx()); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("twoPhaseCommitter.commitMutations", opentracing.ChildOf(span.Context())) diff --git a/txnkv/transaction/pessimistic.go b/txnkv/transaction/pessimistic.go index 5855a08df8..c621b83c4c 100644 --- a/txnkv/transaction/pessimistic.go +++ b/txnkv/transaction/pessimistic.go @@ -535,6 +535,10 @@ func (action actionPessimisticLock) handlePessimisticLockResponseForceLockMode( return true, nil } +func (actionPessimisticLock) isInterruptible() bool { + return true +} + func (actionPessimisticRollback) handleSingleBatch( c *twoPhaseCommitter, bo *retry.Backoffer, batch batchMutations, ) error { @@ -570,6 +574,10 @@ func (actionPessimisticRollback) handleSingleBatch( return nil } +func (actionPessimisticRollback) isInterruptible() bool { + return false +} + func (c *twoPhaseCommitter) pessimisticLockMutations( bo *retry.Backoffer, lockCtx *kv.LockCtx, lockWaitMode kvrpcpb.PessimisticLockWakeUpMode, mutations CommitterMutations, diff --git a/txnkv/transaction/pipelined_flush.go b/txnkv/transaction/pipelined_flush.go new file mode 100644 index 0000000000..e7234d5f4a --- /dev/null +++ b/txnkv/transaction/pipelined_flush.go @@ -0,0 +1,537 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package transaction + +import ( + "bytes" + "context" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/docker/go-units" + "github.com/golang/protobuf/proto" //nolint:staticcheck + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/client-go/v2/config/retry" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/client" + "github.com/tikv/client-go/v2/internal/locate" + "github.com/tikv/client-go/v2/internal/logutil" + "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/rangetask" + "github.com/tikv/client-go/v2/txnkv/txnlock" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// PipelinedRequestSource is the source of the Flush & ResolveLock requests in a txn with pipelined memdb. +// txn.GetRequestSource may cause data race because the upper layer may edit the source while the flush requests are built in background. +// So we use the fixed source from the upper layer to avoid the data race. +// This also distinguishes the resource usage between p-DML(pipelined DML) and other small DMLs. 
+const PipelinedRequestSource = "external_pdml" + +type actionPipelinedFlush struct { + generation uint64 +} + +var _ twoPhaseCommitAction = actionPipelinedFlush{} + +func (action actionPipelinedFlush) String() string { + return "pipelined_flush" +} + +func (action actionPipelinedFlush) tiKVTxnRegionsNumHistogram() prometheus.Observer { + return nil +} + +func (c *twoPhaseCommitter) buildPipelinedFlushRequest(batch batchMutations, generation uint64) *tikvrpc.Request { + m := batch.mutations + mutations := make([]*kvrpcpb.Mutation, m.Len()) + + for i := 0; i < m.Len(); i++ { + assertion := kvrpcpb.Assertion_None + if m.IsAssertExists(i) { + assertion = kvrpcpb.Assertion_Exist + } + if m.IsAssertNotExist(i) { + assertion = kvrpcpb.Assertion_NotExist + } + mutations[i] = &kvrpcpb.Mutation{ + Op: m.GetOp(i), + Key: m.GetKey(i), + Value: m.GetValue(i), + Assertion: assertion, + } + } + + minCommitTS := c.startTS + 1 + + req := &kvrpcpb.FlushRequest{ + Mutations: mutations, + PrimaryKey: c.primary(), + StartTs: c.startTS, + MinCommitTs: minCommitTS, + Generation: generation, + LockTtl: max(defaultLockTTL, ManagedLockTTL), + AssertionLevel: c.txn.assertionLevel, + } + + r := tikvrpc.NewRequest( + tikvrpc.CmdFlush, req, kvrpcpb.Context{ + Priority: c.priority, + SyncLog: c.syncLog, + ResourceGroupTag: c.resourceGroupTag, + DiskFullOpt: c.txn.diskFullOpt, + TxnSource: c.txn.txnSource, + MaxExecutionDurationMs: uint64(client.MaxWriteExecutionTime.Milliseconds()), + RequestSource: PipelinedRequestSource, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: c.resourceGroupName, + }, + }, + ) + if c.resourceGroupTag == nil && c.resourceGroupTagger != nil { + c.resourceGroupTagger(r) + } + return r +} + +func (action actionPipelinedFlush) handleSingleBatch( + c *twoPhaseCommitter, bo *retry.Backoffer, batch batchMutations, +) (err error) { + if len(c.primaryKey) == 0 { + logutil.Logger(bo.GetCtx()).Error( + "[pipelined dml] primary key should be set 
before pipelined flush", + zap.Uint64("startTS", c.startTS), + zap.Uint64("generation", action.generation), + ) + return errors.New("[pipelined dml] primary key should be set before pipelined flush") + } + + tBegin := time.Now() + attempts := 0 + + req := c.buildPipelinedFlushRequest(batch, action.generation) + sender := locate.NewRegionRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), c.store.GetOracle()) + var resolvingRecordToken *int + + for { + attempts++ + reqBegin := time.Now() + if reqBegin.Sub(tBegin) > slowRequestThreshold { + logutil.Logger(bo.GetCtx()).Warn( + "[pipelined dml] slow pipelined flush request", + zap.Uint64("startTS", c.startTS), + zap.Uint64("generation", action.generation), + zap.Stringer("region", &batch.region), + zap.Int("attempts", attempts), + ) + tBegin = time.Now() + } + resp, _, err := sender.SendReq(bo, req, batch.region, client.ReadTimeoutShort) + // Unexpected error occurs, return it + if err != nil { + return err + } + regionErr, err := resp.GetRegionError() + if err != nil { + return err + } + if regionErr != nil { + if err = retry.MayBackoffForRegionError(regionErr, bo); err != nil { + return err + } + if regionErr.GetDiskFull() != nil { + storeIds := regionErr.GetDiskFull().GetStoreId() + desc := " " + for _, i := range storeIds { + desc += strconv.FormatUint(i, 10) + " " + } + + logutil.Logger(bo.GetCtx()).Error( + "Request failed cause of TiKV disk full", + zap.String("store_id", desc), + zap.String("reason", regionErr.GetDiskFull().GetReason()), + ) + + return errors.New(regionErr.String()) + } + same, err := batch.relocate(bo, c.store.GetRegionCache()) + if err != nil { + return err + } + if same { + continue + } + err = c.doActionOnMutations(bo, actionPipelinedFlush{generation: action.generation}, batch.mutations) + return err + } + if resp.Resp == nil { + return errors.WithStack(tikverr.ErrBodyMissing) + } + flushResp := resp.Resp.(*kvrpcpb.FlushResponse) + keyErrs := flushResp.GetErrors() + if 
len(keyErrs) == 0 { + // Clear the RPC Error since the request is evaluated successfully. + sender.SetRPCError(nil) + + // Update CommitDetails + reqDuration := time.Since(reqBegin) + c.getDetail().MergeFlushReqDetails( + reqDuration, + batch.region.GetID(), + sender.GetStoreAddr(), + flushResp.ExecDetailsV2, + ) + + if batch.isPrimary { + // start keepalive after primary key is written. + c.run(c, nil, true) + } + return nil + } + locks := make([]*txnlock.Lock, 0, len(keyErrs)) + + logged := make(map[uint64]struct{}, 1) + for _, keyErr := range keyErrs { + // Check already exists error + if alreadyExist := keyErr.GetAlreadyExist(); alreadyExist != nil { + e := &tikverr.ErrKeyExist{AlreadyExist: alreadyExist} + return c.extractKeyExistsErr(e) + } + + // Extract lock from key error + lock, err1 := txnlock.ExtractLockFromKeyErr(keyErr) + if err1 != nil { + return err1 + } + if _, ok := logged[lock.TxnID]; !ok { + logutil.Logger(bo.GetCtx()).Info( + "[pipelined dml] flush encounters lock. "+ + "More locks belonging to the same transaction may be omitted", + zap.Uint64("txnID", c.startTS), + zap.Uint64("generation", action.generation), + zap.Stringer("lock", lock), + ) + logged[lock.TxnID] = struct{}{} + } + // If an optimistic transaction encounters a lock with larger TS, this transaction will certainly + // fail due to a WriteConflict error. So we can construct and return an error here early. + // Pessimistic transactions don't need such an optimization. If this key needs a pessimistic lock, + // TiKV will return a PessimisticLockNotFound error directly if it encounters a different lock. Otherwise, + // TiKV returns lock.TTL = 0, and we still need to resolve the lock. 
+ if lock.TxnID > c.startTS && !c.isPessimistic { + return tikverr.NewErrWriteConflictWithArgs( + c.startTS, + lock.TxnID, + 0, + lock.Key, + kvrpcpb.WriteConflict_Optimistic, + ) + } + locks = append(locks, lock) + } + if resolvingRecordToken == nil { + token := c.store.GetLockResolver().RecordResolvingLocks(locks, c.startTS) + resolvingRecordToken = &token + defer c.store.GetLockResolver().ResolveLocksDone(c.startTS, *resolvingRecordToken) + } else { + c.store.GetLockResolver().UpdateResolvingLocks(locks, c.startTS, *resolvingRecordToken) + } + resolveLockOpts := txnlock.ResolveLocksOptions{ + CallerStartTS: c.startTS, + Locks: locks, + Detail: &c.getDetail().ResolveLock, + } + resolveLockRes, err := c.store.GetLockResolver().ResolveLocksWithOpts(bo, resolveLockOpts) + if err != nil { + return err + } + msBeforeExpired := resolveLockRes.TTL + if msBeforeExpired > 0 { + err = bo.BackoffWithCfgAndMaxSleep( + retry.BoTxnLock, + int(msBeforeExpired), + errors.Errorf("[pipelined dml] flush lockedKeys: %d", len(locks)), + ) + if err != nil { + logutil.Logger(bo.GetCtx()).Warn( + "[pipelined dml] backoff failed during flush", + zap.Error(err), + zap.Uint64("startTS", c.startTS), + zap.Uint64("generation", action.generation), + ) + return err + } + } + } +} + +func (actionPipelinedFlush) isInterruptible() bool { + return true +} + +func (c *twoPhaseCommitter) pipelinedFlushMutations(bo *retry.Backoffer, mutations CommitterMutations, generation uint64) error { + if span := opentracing.SpanFromContext(bo.GetCtx()); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("twoPhaseCommitter.pipelinedFlushMutations", opentracing.ChildOf(span.Context())) + defer span1.Finish() + bo.SetCtx(opentracing.ContextWithSpan(bo.GetCtx(), span1)) + } + + return c.doActionOnMutations(bo, actionPipelinedFlush{generation}, mutations) +} + +func (c *twoPhaseCommitter) commitFlushedMutations(bo *retry.Backoffer) error { + logutil.Logger(bo.GetCtx()).Info( + "[pipelined dml] 
start to commit transaction", + zap.Int("keys", c.txn.GetMemBuffer().Len()), + zap.Duration("flush_wait_duration", c.txn.GetMemBuffer().GetMetrics().WaitDuration), + zap.Duration("total_duration", c.txn.GetMemBuffer().GetMetrics().TotalDuration), + zap.Uint64("memdb traversal cache hit", c.txn.GetMemBuffer().GetMetrics().MemDBHitCount), + zap.Uint64("memdb traversal cache miss", c.txn.GetMemBuffer().GetMetrics().MemDBMissCount), + zap.String("size", units.HumanSize(float64(c.txn.GetMemBuffer().Size()))), + zap.Uint64("startTS", c.startTS), + ) + commitTS, err := c.store.GetTimestampWithRetry(bo, c.txn.GetScope()) + if err != nil { + logutil.Logger(bo.GetCtx()).Warn("[pipelined dml] commit transaction get commitTS failed", + zap.Error(err), + zap.Uint64("txnStartTS", c.startTS), + ) + return err + } + atomic.StoreUint64(&c.commitTS, commitTS) + + if _, err := util.EvalFailpoint("pipelinedCommitFail"); err == nil { + return errors.New("pipelined DML commit failed") + } + + primaryMutation := NewPlainMutations(1) + primaryMutation.Push(c.pipelinedCommitInfo.primaryOp, c.primaryKey, nil, false, false, false, false) + if err = c.commitMutations(bo, &primaryMutation); err != nil { + return errors.Trace(err) + } + c.mu.Lock() + c.mu.committed = true + c.mu.Unlock() + logutil.Logger(bo.GetCtx()).Info( + "[pipelined dml] transaction is committed", + zap.Uint64("startTS", c.startTS), + zap.Uint64("commitTS", commitTS), + ) + broadcastToAllStores( + c.txn, + c.store, + retry.NewBackofferWithVars( + bo.GetCtx(), + broadcastMaxBackoff, + c.txn.vars, + ), + &kvrpcpb.TxnStatus{ + StartTs: c.startTS, + MinCommitTs: c.minCommitTSMgr.get(), + CommitTs: commitTS, + RolledBack: false, + IsCompleted: false, + }, + c.resourceGroupName, + c.resourceGroupTag, + ) + + if _, err := util.EvalFailpoint("pipelinedSkipResolveLock"); err == nil { + return nil + } + + // async resolve the rest locks. 
+ commitBo := retry.NewBackofferWithVars(c.store.Ctx(), CommitSecondaryMaxBackoff, c.txn.vars) + c.resolveFlushedLocks(commitBo, c.pipelinedCommitInfo.pipelinedStart, c.pipelinedCommitInfo.pipelinedEnd, true) + return nil +} + +// buildPipelinedResolveHandler returns a function which resolves all locks for the given region. +// If the region cache is stale, it reloads the region info and resolve the rest ranges. +// The function also count resolved regions. +func (c *twoPhaseCommitter) buildPipelinedResolveHandler(commit bool, resolved *atomic.Uint64) (rangetask.TaskHandler, error) { + commitVersion := uint64(0) + if commit { + commitVersion = atomic.LoadUint64(&c.commitTS) + if commitVersion == 0 { + return nil, errors.New("commitTS is 0") + } + } + maxBackOff := cleanupMaxBackoff + if commit { + maxBackOff = CommitSecondaryMaxBackoff + } + regionCache := c.store.GetRegionCache() + // the handler function runs in a different goroutine, should copy the required values before it to avoid race. 
+ kvContext := &kvrpcpb.Context{ + Priority: c.priority, + SyncLog: c.syncLog, + ResourceGroupTag: c.resourceGroupTag, + DiskFullOpt: c.diskFullOpt, + TxnSource: c.txnSource, + RequestSource: PipelinedRequestSource, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: c.resourceGroupName, + }, + } + return func(ctx context.Context, r kv.KeyRange) (rangetask.TaskStat, error) { + start := r.StartKey + res := rangetask.TaskStat{} + for { + lreq := &kvrpcpb.ResolveLockRequest{ + StartVersion: c.startTS, + CommitVersion: commitVersion, + } + req := tikvrpc.NewRequest(tikvrpc.CmdResolveLock, lreq, *proto.Clone(kvContext).(*kvrpcpb.Context)) + bo := retry.NewBackoffer(ctx, maxBackOff) + loc, err := regionCache.LocateKey(bo, start) + if err != nil { + return res, err + } + resp, err := c.store.SendReq(bo, req, loc.Region, client.MaxWriteExecutionTime) + if err != nil { + err = bo.Backoff(retry.BoRegionMiss, err) + if err != nil { + logutil.Logger(bo.GetCtx()).Error("send resolve lock request error", zap.Error(err)) + return res, err + } + continue + } + regionErr, err := resp.GetRegionError() + if err != nil { + logutil.Logger(bo.GetCtx()).Error("get region error failed", zap.Error(err)) + return res, err + } + if regionErr != nil { + err = bo.Backoff(retry.BoRegionMiss, errors.New(regionErr.String())) + if err != nil { + logutil.Logger(bo.GetCtx()).Error("send resolve lock get region error", zap.Error(err)) + return res, err + } + continue + } + if resp.Resp == nil { + logutil.Logger(bo.GetCtx()).Error("send resolve lock response body missing", zap.Error(errors.WithStack(tikverr.ErrBodyMissing))) + return res, err + } + cmdResp := resp.Resp.(*kvrpcpb.ResolveLockResponse) + if keyErr := cmdResp.GetError(); keyErr != nil { + err = errors.Errorf("unexpected resolve err: %s", keyErr) + logutil.BgLogger().Error( + "resolveLock error", + zap.Error(err), + zap.Uint64("startVer", lreq.StartVersion), + zap.Uint64("commitVer", lreq.CommitVersion), + 
zap.String("debugInfo", tikverr.ExtractDebugInfoStrFromKeyErr(keyErr)), + ) + return res, err + } + resolved.Add(1) + res.CompletedRegions++ + if loc.EndKey == nil || bytes.Compare(loc.EndKey, r.EndKey) >= 0 { + return res, nil + } + start = loc.EndKey + } + }, nil +} + +// resolveFlushedLocks resolves all locks in the given range [start, end) with the given status. +// The resolve process is running in another goroutine so this function won't block. +func (c *twoPhaseCommitter) resolveFlushedLocks(bo *retry.Backoffer, start, end []byte, commit bool) { + var resolved atomic.Uint64 + handler, err := c.buildPipelinedResolveHandler(commit, &resolved) + commitTs := uint64(0) + if commit { + commitTs = atomic.LoadUint64(&c.commitTS) + } + if err != nil { + logutil.Logger(bo.GetCtx()).Error( + "[pipelined dml] build buildPipelinedResolveHandler error", + zap.Error(err), + zap.Uint64("resolved regions", resolved.Load()), + zap.Uint64("startTS", c.startTS), + zap.Uint64("commitTS", commitTs), + ) + return + } + + status := "rollback" + if commit { + status = "commit" + } + + runner := rangetask.NewRangeTaskRunnerWithID( + fmt.Sprintf("pipelined-dml-%s", status), + fmt.Sprintf("pipelined-dml-%s-%d", status, c.startTS), + c.store, + c.txn.pipelinedResolveLockConcurrency, + handler, + ) + runner.SetStatLogInterval(30 * time.Second) + runner.SetRegionsPerTask(1) + + c.txn.spawnWithStorePool(func() { + if err = runner.RunOnRange(bo.GetCtx(), start, end); err != nil { + logutil.Logger(bo.GetCtx()).Error("[pipelined dml] resolve flushed locks failed", + zap.String("txn-status", status), + zap.Uint64("resolved regions", resolved.Load()), + zap.Uint64("startTS", c.startTS), + zap.Uint64("commitTS", commitTs), + zap.Uint64("session", c.sessionID), + zap.Error(err), + ) + } else { + logutil.Logger(bo.GetCtx()).Info("[pipelined dml] resolve flushed locks done", + zap.String("txn-status", status), + zap.Uint64("resolved regions", resolved.Load()), + zap.Uint64("startTS", c.startTS), + 
zap.Uint64("commitTS", commitTs), + zap.Uint64("session", c.sessionID), + ) + + // wait a while before notifying txn_status_cache to evict the txn, + // which tolerates slow followers and avoids the situation that the + // txn is evicted before the follower catches up. + time.Sleep(broadcastGracePeriod) + + broadcastToAllStores( + c.txn, + c.store, + retry.NewBackofferWithVars( + bo.GetCtx(), + broadcastMaxBackoff, + c.txn.vars, + ), + &kvrpcpb.TxnStatus{ + StartTs: c.startTS, + MinCommitTs: 0, + CommitTs: commitTs, + RolledBack: !commit, + IsCompleted: true, + }, + c.resourceGroupName, + c.resourceGroupTag, + ) + } + }) +} diff --git a/txnkv/transaction/prewrite.go b/txnkv/transaction/prewrite.go index d74b5fe6d1..7c1742de52 100644 --- a/txnkv/transaction/prewrite.go +++ b/txnkv/transaction/prewrite.go @@ -493,6 +493,10 @@ func (action actionPrewrite) handleSingleBatch( } } +func (actionPrewrite) isInterruptible() bool { + return true +} + func (c *twoPhaseCommitter) prewriteMutations(bo *retry.Backoffer, mutations CommitterMutations) error { if span := opentracing.SpanFromContext(bo.GetCtx()); span != nil && span.Tracer() != nil { span1 := span.Tracer().StartSpan("twoPhaseCommitter.prewriteMutations", opentracing.ChildOf(span.Context()))