diff --git a/cluster.go b/cluster.go index 12417d1f9..2e9d0bc70 100644 --- a/cluster.go +++ b/cluster.go @@ -1291,7 +1291,19 @@ func (c *ClusterClient) processPipelineNode( defer func() { node.Client.releaseConn(ctx, cn, processErr) }() - processErr = c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + + errCh := make(chan error, 1) + + go func() { + errCh <- c.processPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + }() + + select { + case processErr = <-errCh: + case <-ctx.Done(): + _ = cn.Close() + processErr = ctx.Err() + } return processErr }) @@ -1472,7 +1484,19 @@ func (c *ClusterClient) processTxPipelineNode( defer func() { node.Client.releaseConn(ctx, cn, processErr) }() - processErr = c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + + errCh := make(chan error, 1) + + go func() { + errCh <- c.processTxPipelineNodeConn(ctx, node, cn, cmds, failedCmds) + }() + + select { + case processErr = <-errCh: + case <-ctx.Done(): + _ = cn.Close() + processErr = ctx.Err() + } return processErr }) diff --git a/commands_test.go b/commands_test.go index fdc3452ca..821d335a3 100644 --- a/commands_test.go +++ b/commands_test.go @@ -4843,6 +4843,24 @@ var _ = Describe("Commands", func() { Expect(err).To(Equal(redis.Nil)) }) + Describe("canceled context", func() { + It("should unblock XRead", func() { + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + errCh <- client.XRead(ctx2, &redis.XReadArgs{ + Streams: []string{"stream", "$"}, + }).Err() + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) + Describe("group", func() { BeforeEach(func() { err := client.XGroupCreate(ctx, "stream", "group", "0").Err() @@ -5023,6 +5041,26 @@ var _ = Describe("Commands", func() { Expect(err).NotTo(HaveOccurred()) Expect(n).To(Equal(int64(2))) }) + + Describe("canceled context", func() { + It("should unblock XReadGroup", func() { + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + errCh <- client.XReadGroup(ctx2, &redis.XReadGroupArgs{ + Group: "group", + Consumer: "consumer", + Streams: []string{"stream", ">"}, + }).Err() + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) }) Describe("xinfo", func() { diff --git a/internal_test.go b/internal_test.go index a6317196a..494cb96ec 100644 --- a/internal_test.go +++ b/internal_test.go @@ -351,4 +351,21 @@ var _ = Describe("withConn", func() { Expect(newConn).NotTo(Equal(conn)) Expect(client.connPool.Len()).To(Equal(1)) }) + + It("should remove the connection from the pool if the context is canceled", func() { + var conn *pool.Conn + + ctx2, cancel := context.WithCancel(ctx) + cancel() + + client.withConn(ctx2, func(ctx context.Context, c *pool.Conn) error { + conn = c + return nil + }) + + newConn, err := client.connPool.Get(ctx) + Expect(err).To(BeNil()) + Expect(newConn).NotTo(Equal(conn)) + Expect(client.connPool.Len()).To(Equal(1)) + }) }) diff --git a/pubsub.go b/pubsub.go index 16c0f5672..9657ffe54 100644 --- a/pubsub.go +++ b/pubsub.go @@ -432,9 +432,19 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int return nil, err } - err = cn.WithReader(context.Background(), timeout, func(rd *proto.Reader) error { - return c.cmd.readReply(rd) - }) + errCh := make(chan error, 1) + + go func() { + errCh <- cn.WithReader(context.Background(), timeout, func(rd *proto.Reader) error { + return c.cmd.readReply(rd) + }) + }() + + select { + case err = <-errCh: + case <-ctx.Done(): + err = ctx.Err() + } c.releaseConnWithLock(ctx, cn, err, timeout > 0) diff --git a/pubsub_test.go b/pubsub_test.go index a76100659..43b60f0ae 100644 --- a/pubsub_test.go +++ b/pubsub_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "io" "net" "sync" @@ -567,4 +568,24 @@ var _ = Describe("PubSub", func() { Expect(msg.Channel).To(Equal("mychannel")) Expect(msg.Payload).To(Equal(text)) }) + + Describe("canceled context", func() { + It("should unblock ReceiveMessage", func() { + pubsub := client.Subscribe(ctx, "mychannel") + defer pubsub.Close() + + ctx2, cancel := context.WithCancel(ctx) + errCh := make(chan error, 1) + go func() { + _, err := pubsub.ReceiveMessage(ctx2) + errCh <- err + }() + + var gotErr error + Consistently(errCh).ShouldNot(Receive(&gotErr), "Received %v", gotErr) + cancel() + Eventually(errCh).Should(Receive(&gotErr)) + Expect(gotErr).To(HaveOccurred()) + }) + }) }) diff --git a/redis.go b/redis.go index 6eed8424c..4dfbe9024 100644 --- a/redis.go +++ b/redis.go @@ -347,7 +347,18 @@ func (c *baseClient) withConn( c.releaseConn(ctx, cn, fnErr) }() - fnErr = fn(ctx, cn) + errCh := make(chan error, 1) + + go func() { + errCh <- fn(ctx, cn) + }() + + select { + case fnErr = <-errCh: + case <-ctx.Done(): + _ = c.connPool.CloseConn(cn) + fnErr = ctx.Err() + } return fnErr }