feat: watch <-context.Done #2455

Closed · wants to merge 4 commits

40 changes: 26 additions & 14 deletions cluster.go
@@ -1279,28 +1279,34 @@ func (c *ClusterClient) cmdsAreReadOnly(ctx context.Context, cmds []Cmder) bool
 func (c *ClusterClient) processPipelineNode(
     ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
 ) {
-    _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
+    _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) (retErr error) {
         cn, err := node.Client.getConn(ctx)
         if err != nil {
             _ = c.mapCmdsByNode(ctx, failedCmds, cmds)
             setCmdsErr(cmds, err)
             return err
         }

-        var processErr error
+        if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil {
+            return retErr
+        }
+
         defer func() {
-            node.Client.releaseConn(ctx, cn, processErr)
+            if err = cn.WatchFinish(); err != nil {
+                retErr = err
+            }
+            node.Client.releaseConn(ctx, cn, retErr)
         }()
-        processErr = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds)
+        retErr = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds)

-        return processErr
+        return retErr
     })
 }

 func (c *ClusterClient) processPipelineNodeConn(
     ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
 ) error {
-    if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+    if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
         return writeCmds(wr, cmds)
     }); err != nil {
         if shouldRetry(err, true) {
@@ -1310,7 +1316,7 @@ func (c *ClusterClient) processPipelineNodeConn(
         return err
     }

-    return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+    return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
         return c.pipelineReadCmds(ctx, node, rd, cmds, failedCmds)
     })
 }
@@ -1460,28 +1466,34 @@ func (c *ClusterClient) processTxPipelineNode(
     ctx context.Context, node *clusterNode, cmds []Cmder, failedCmds *cmdsMap,
 ) {
     cmds = wrapMultiExec(ctx, cmds)
-    _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
+    _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) (retErr error) {
         cn, err := node.Client.getConn(ctx)
         if err != nil {
             _ = c.mapCmdsByNode(ctx, failedCmds, cmds)
             setCmdsErr(cmds, err)
             return err
         }

-        var processErr error
+        if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil {
+            return retErr
+        }
+
         defer func() {
-            node.Client.releaseConn(ctx, cn, processErr)
+            if err = cn.WatchFinish(); err != nil {
+                retErr = err
+            }
+            node.Client.releaseConn(ctx, cn, retErr)
         }()
-        processErr = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds)
+        retErr = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds)

-        return processErr
+        return retErr
     })
 }

 func (c *ClusterClient) processTxPipelineNodeConn(
     ctx context.Context, node *clusterNode, cn *pool.Conn, cmds []Cmder, failedCmds *cmdsMap,
 ) error {
-    if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+    if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
         return writeCmds(wr, cmds)
     }); err != nil {
         if shouldRetry(err, true) {
@@ -1491,7 +1503,7 @@ func (c *ClusterClient) processTxPipelineNodeConn(
         return err
     }

-    return cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+    return cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
         statusCmd := cmds[0].(*StatusCmd)
         // Trim multi and exec.
         trimmedCmds := cmds[1 : len(cmds)-1]
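
Why the hook bodies switch from a local processErr to a named retErr return: the deferred WatchFinish call must be able to surface a context error that fired while the pipeline was in flight, replacing whatever the body returned. A minimal sketch of the arm/disarm pattern this diff wraps around every connection use (the watcher interface and withWatch helper are illustrative, not part of the PR):

    package sketch

    import "context"

    // watcher abstracts the two calls this diff adds around connection use.
    type watcher interface {
        WatchCancel(ctx context.Context) error // arm: start racing ctx.Done()
        WatchFinish() error                    // disarm: report a cancellation that fired mid-call
    }

    // withWatch mirrors the rewritten hook bodies: the named return lets
    // the deferred WatchFinish error replace the body's result.
    func withWatch(ctx context.Context, cn watcher, body func() error) (retErr error) {
        if retErr = cn.WatchCancel(ctx); retErr != nil {
            return retErr // already canceled; don't touch the connection
        }
        defer func() {
            if err := cn.WatchFinish(); err != nil {
                retErr = err
            }
        }()
        return body()
    }
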
107 changes: 86 additions & 21 deletions internal/pool/conn.go
@@ -10,7 +10,11 @@ import (
     "github.com/redis/go-redis/v9/internal/proto"
 )

-var noDeadline = time.Time{}
+var (
+    // aLongTimeAgo is a non-zero time, used to immediately unblock the network.
+    aLongTimeAgo = time.Unix(1, 0)
+    noDeadline   = time.Time{}
+)

 type Conn struct {
     usedAt int64 // atomic
@@ -23,6 +27,13 @@ type Conn struct {
     Inited    bool
     pooled    bool
     createdAt time.Time
+
+    _closed       int32 // atomic
+    closeChan     chan struct{}
+    watchChan     chan context.Context
+    finishChan    chan struct{}
+    interruptChan chan error
+    watching      bool
 }

 func NewConn(netConn net.Conn) *Conn {
@@ -34,9 +45,72 @@ func NewConn(netConn net.Conn) *Conn {
     cn.bw = bufio.NewWriter(netConn)
     cn.wr = proto.NewWriter(cn.bw)
     cn.SetUsedAt(time.Now())
+
+    cn.closeChan = make(chan struct{})
+    cn.interruptChan = make(chan error)
+    cn.finishChan = make(chan struct{})
+    cn.watchChan = make(chan context.Context, 1)
+
+    go cn.loopWatcher()
+
     return cn
 }

+func (cn *Conn) loopWatcher() {
+    var ctx context.Context
+    for {
+        select {
+        case ctx = <-cn.watchChan:
+        case <-cn.closeChan:
+            return
+        }
+
+        select {
+        case <-ctx.Done():
+            _ = cn.netConn.SetDeadline(aLongTimeAgo)
+            cn.interruptChan <- ctx.Err()
+        case <-cn.finishChan:
+        case <-cn.closeChan:
+            return
+        }
+    }
+}
+
+func (cn *Conn) WatchFinish() error {
+    if !cn.watching {
+        return nil
+    }
+
+    var err error
+    select {
+    case cn.finishChan <- struct{}{}:
+        cn.watching = false
+    case err = <-cn.interruptChan:
+        cn.watching = false
+    case <-cn.closeChan:
+    }
+    return err
+}
+
+func (cn *Conn) WatchCancel(ctx context.Context) error {
+    if cn.watching {
+        panic("repeat watchCancel")
+    }
+    if err := ctx.Err(); err != nil {
+        return err
+    }
+    if ctx.Done() == nil {
+        return nil
+    }
+    if cn.closed() {
+        return nil
+    }
+
+    cn.watching = true
+    cn.watchChan <- ctx
+    return nil
+}
+
 func (cn *Conn) UsedAt() time.Time {
     unix := atomic.LoadInt64(&cn.usedAt)
     return time.Unix(unix, 0)
@@ -94,33 +168,24 @@ func (cn *Conn) WithWriter(
     return cn.bw.Flush()
 }

+func (cn *Conn) closed() bool {
+    return atomic.LoadInt32(&cn._closed) == 1
+}
+
 func (cn *Conn) Close() error {
-    return cn.netConn.Close()
+    if atomic.CompareAndSwapInt32(&cn._closed, 0, 1) {
+        close(cn.closeChan)
+        return cn.netConn.Close()
+    }
+    return nil
 }

-func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time {
+func (cn *Conn) deadline(_ context.Context, timeout time.Duration) time.Time {
     tm := time.Now()
     cn.SetUsedAt(tm)

-    if timeout > 0 {
-        tm = tm.Add(timeout)
-    }
-
-    if ctx != nil {
-        deadline, ok := ctx.Deadline()
-        if ok {
-            if timeout == 0 {
-                return deadline
-            }
-            if deadline.Before(tm) {
-                return deadline
-            }
-            return tm
-        }
-    }
-
     if timeout > 0 {
-        return tm
+        return tm.Add(timeout)
     }

     return noDeadline
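
The heart of the change is loopWatcher: every connection owns a goroutine that, once armed by WatchCancel, races ctx.Done() against finishChan. On cancellation it sets an already-expired deadline (aLongTimeAgo), which forces any Read or Write blocked on the socket to return immediately with a timeout error, and hands ctx.Err() to WatchFinish via interruptChan. A connection interrupted this way has a poisoned deadline, which is why the pool.go hunks below refuse to pool or reuse a conn whose watching flag is still set. A self-contained sketch of the deadline trick, using net.Pipe in place of a real Redis connection (illustrative, not part of the PR):

    package main

    import (
        "fmt"
        "net"
        "time"
    )

    func main() {
        client, server := net.Pipe()
        defer server.Close()
        defer client.Close()

        // Stand-in for the watcher goroutine reacting to <-ctx.Done():
        // a deadline in the past unblocks the pending Read below.
        aLongTimeAgo := time.Unix(1, 0)
        go func() {
            time.Sleep(100 * time.Millisecond)
            _ = client.SetDeadline(aLongTimeAgo)
        }()

        buf := make([]byte, 1)
        _, err := client.Read(buf) // blocks: the server never writes
        fmt.Println(err)           // i/o timeout, delivered after ~100ms
    }
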
6 changes: 5 additions & 1 deletion internal/pool/pool.go
@@ -354,7 +354,7 @@ func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
         return
     }

-    if !cn.pooled {
+    if !cn.pooled || cn.watching {
         p.Remove(ctx, cn, nil)
         return
     }
@@ -486,6 +486,10 @@ func (p *ConnPool) Close() error {
 func (p *ConnPool) isHealthyConn(cn *Conn) bool {
     now := time.Now()

+    if cn.watching {
+        return false
+    }
+
     if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime {
         return false
     }
28 changes: 17 additions & 11 deletions redis.go
@@ -336,20 +336,26 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error)

 func (c *baseClient) withConn(
     ctx context.Context, fn func(context.Context, *pool.Conn) error,
-) error {
+) (retErr error) {
     cn, err := c.getConn(ctx)
     if err != nil {
         return err
     }

-    var fnErr error
+    if retErr = cn.WatchCancel(c.context(ctx)); retErr != nil {
+        return retErr
+    }
+
     defer func() {
-        c.releaseConn(ctx, cn, fnErr)
+        if err = cn.WatchFinish(); err != nil {
+            retErr = err
+        }
+        c.releaseConn(ctx, cn, retErr)
     }()

-    fnErr = fn(ctx, cn)
+    retErr = fn(ctx, cn)

-    return fnErr
+    return retErr
 }

 func (c *baseClient) dial(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -380,14 +386,14 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool

     retryTimeout := uint32(0)
     if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error {
-        if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+        if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
             return writeCmd(wr, cmd)
         }); err != nil {
             atomic.StoreUint32(&retryTimeout, 1)
             return err
         }

-        if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), cmd.readReply); err != nil {
+        if err := cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply); err != nil {
             if cmd.readTimeout() == nil {
                 atomic.StoreUint32(&retryTimeout, 1)
             } else {
@@ -486,14 +492,14 @@ func (c *baseClient) generalProcessPipeline(
 func (c *baseClient) pipelineProcessCmds(
     ctx context.Context, cn *pool.Conn, cmds []Cmder,
 ) (bool, error) {
-    if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+    if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
         return writeCmds(wr, cmds)
     }); err != nil {
         setCmdsErr(cmds, err)
         return true, err
     }

-    if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+    if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
         return pipelineReadCmds(rd, cmds)
     }); err != nil {
         return true, err
@@ -518,14 +524,14 @@ func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error {
 func (c *baseClient) txPipelineProcessCmds(
     ctx context.Context, cn *pool.Conn, cmds []Cmder,
 ) (bool, error) {
-    if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error {
+    if err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
         return writeCmds(wr, cmds)
     }); err != nil {
         setCmdsErr(cmds, err)
         return true, err
     }

-    if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error {
+    if err := cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error {
         statusCmd := cmds[0].(*StatusCmd)
         // Trim multi and exec.
         trimmedCmds := cmds[1 : len(cmds)-1]
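
Note the asymmetry these hunks leave in place: WatchCancel is still fed c.context(ctx), while WithWriter and WithReader now receive the raw ctx. c.context is defined elsewhere in redis.go; from memory it is approximately the following (a hedged sketch, not part of this diff):

    // With ContextTimeoutEnabled off, the watcher is armed with a
    // context that can never fire, preserving the old opt-in behavior.
    func (c *baseClient) context(ctx context.Context) context.Context {
        if c.opt.ContextTimeoutEnabled {
            return ctx
        }
        return context.Background()
    }

Passing the raw ctx to WithWriter/WithReader is safe because deadline() in conn.go no longer consults the context at all; interruption now arrives only through the watcher goroutine.
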
31 changes: 31 additions & 0 deletions redis_test.go
@@ -531,3 +531,34 @@ var _ = Describe("Hook", func() {
         }))
     })
 })
+
+var _ = Describe("Watch Context.Done", func() {
+    var client *redis.Client
+
+    BeforeEach(func() {
+        opt := redisOptions()
+        opt.ReadTimeout = 10 * time.Second
+        opt.WriteTimeout = 10 * time.Second
+        opt.PoolTimeout = 10 * time.Second
+        opt.ContextTimeoutEnabled = true
+
+        client = redis.NewClient(opt)
+        Expect(client.FlushDB(ctx).Err()).NotTo(HaveOccurred())
+    })
+
+    AfterEach(func() {
+        err := client.Close()
+        Expect(err).NotTo(HaveOccurred())
+    })
+
+    It("should cancel", func() {
+        ctx, cancel := context.WithCancel(context.Background())
+        go func() {
+            time.Sleep(1 * time.Second)
+            cancel()
+        }()
+
+        err := client.BLPop(ctx, 10*time.Second, "key1").Err()
+        Expect(errors.Is(err, context.Canceled)).To(BeTrue())
+    })
+})
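
The test exercises explicit cancellation, but the same path fires for deadlines, since loopWatcher only selects on ctx.Done(). A hedged usage sketch of what a caller observes with this change (the address and key are placeholders; illustrative, not part of the PR):

    package main

    import (
        "context"
        "errors"
        "fmt"
        "time"

        "github.com/redis/go-redis/v9"
    )

    func main() {
        rdb := redis.NewClient(&redis.Options{
            Addr:                  "localhost:6379",
            ContextTimeoutEnabled: true,
        })
        defer rdb.Close()

        ctx, cancel := context.WithTimeout(context.Background(), time.Second)
        defer cancel()

        // The blocked read is interrupted after ~1s instead of waiting
        // out the 10-second server-side block.
        err := rdb.BLPop(ctx, 10*time.Second, "queue").Err()
        fmt.Println(errors.Is(err, context.DeadlineExceeded)) // expected: true
    }
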