diff --git a/pkg/common/config/config.go b/pkg/common/config/config.go index 261e3d8c08..5309f9913d 100644 --- a/pkg/common/config/config.go +++ b/pkg/common/config/config.go @@ -78,10 +78,11 @@ type configStruct struct { } `yaml:"mongo"` Redis struct { - ClusterMode bool `yaml:"clusterMode"` - Address []string `yaml:"address"` - Username string `yaml:"username"` - Password string `yaml:"password"` + ClusterMode bool `yaml:"clusterMode"` + Address []string `yaml:"address"` + Username string `yaml:"username"` + Password string `yaml:"password"` + EnablePipeline bool `yaml:"enablePipeline"` } `yaml:"redis"` Kafka struct { diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index 50fb617aa9..c8346a1d48 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -21,6 +21,7 @@ import ( "time" "github.com/dtm-labs/rockscache" + "golang.org/x/sync/errgroup" unrelationtb "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation" @@ -62,6 +63,8 @@ const ( uidPidToken = "UID_PID_TOKEN_STATUS:" ) +var concurrentLimit = 3 + type SeqCache interface { SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) @@ -345,85 +348,165 @@ func (c *msgCache) allMessageCacheKey(conversationID string) string { } func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { + if config.Config.Redis.EnablePipeline { + return c.PipeGetMessagesBySeq(ctx, conversationID, seqs) + } + + return c.ParallelGetMessagesBySeq(ctx, conversationID, seqs) +} + +func (c *msgCache) PipeGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { + pipe := c.rdb.Pipeline() + + results := []*redis.StringCmd{} for _, seq := range seqs { - res, err := c.rdb.Get(ctx, c.getMessageCacheKey(conversationID, seq)).Result() - if err != nil { - log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq) + results = append(results, pipe.Get(ctx, c.getMessageCacheKey(conversationID, seq))) + } + + _, err = pipe.Exec(ctx) + if err != nil && err != redis.Nil { + return seqMsgs, failedSeqs, errs.Wrap(err, "pipe.get") + } + + for idx, res := range results { + seq := seqs[idx] + if res.Err() != nil { + log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq, "err", res.Err()) failedSeqs = append(failedSeqs, seq) continue } + msg := sdkws.MsgData{} - if err = msgprocessor.String2Pb(res, &msg); err != nil { + if err = msgprocessor.String2Pb(res.Val(), &msg); err != nil { log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq) failedSeqs = append(failedSeqs, seq) continue } + if msg.Status == constant.MsgDeleted { failedSeqs = append(failedSeqs, seq) continue } + seqMsgs = append(seqMsgs, &msg) } return - //pipe := c.rdb.Pipeline() - //for _, v := range seqs { - // // MESSAGE_CACHE:169.254.225.224_reliability1653387820_0_1 - // key := c.getMessageCacheKey(conversationID, v) - // if err := pipe.Get(ctx, key).Err(); err != nil && err != redis.Nil { - // return nil, nil, err - // } - //} - //result, err := pipe.Exec(ctx) - //for i, v := range result { - // cmd := v.(*redis.StringCmd) - // if cmd.Err() != nil { - // failedSeqs = append(failedSeqs, seqs[i]) - // } else { - // msg := sdkws.MsgData{} - // err = msgprocessor.String2Pb(cmd.Val(), &msg) - // if err == nil { - // if msg.Status != constant.MsgDeleted { - // seqMsgs = append(seqMsgs, &msg) - // continue - // } - // } else { - // log.ZWarn(ctx, "UnmarshalString failed", err, "conversationID", conversationID, "seq", seqs[i], "msg", cmd.Val()) - // } - // failedSeqs = append(failedSeqs, seqs[i]) - // } - //} - //return seqMsgs, failedSeqs, err +} + +func (c *msgCache) ParallelGetMessagesBySeq(ctx context.Context, conversationID string, seqs []int64) (seqMsgs []*sdkws.MsgData, failedSeqs []int64, err error) { + type entry struct { + err error + msg *sdkws.MsgData + } + + wg := errgroup.Group{} + wg.SetLimit(concurrentLimit) + + results := make([]entry, len(seqs)) // set slice len/cap to length of seqs. + for idx, seq := range seqs { + // closure safe var + idx := idx + seq := seq + + wg.Go(func() error { + res, err := c.rdb.Get(ctx, c.getMessageCacheKey(conversationID, seq)).Result() + if err != nil { + log.ZError(ctx, "GetMessagesBySeq failed", err, "conversationID", conversationID, "seq", seq) + results[idx] = entry{err: err} + return nil + } + + msg := sdkws.MsgData{} + if err = msgprocessor.String2Pb(res, &msg); err != nil { + log.ZError(ctx, "GetMessagesBySeq Unmarshal failed", err, "res", res, "conversationID", conversationID, "seq", seq) + results[idx] = entry{err: err} + return nil + } + + if msg.Status == constant.MsgDeleted { + results[idx] = entry{err: err} + return nil + } + + results[idx] = entry{msg: &msg} + return nil + }) + } + + _ = wg.Wait() + + for idx, res := range results { + if res.err != nil { + failedSeqs = append(failedSeqs, seqs[idx]) + continue + } + + seqMsgs = append(seqMsgs, res.msg) + } + + return } func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { + if config.Config.Redis.EnablePipeline { + return c.PipeSetMessageToCache(ctx, conversationID, msgs) + } + return c.ParallelSetMessageToCache(ctx, conversationID, msgs) +} + +func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { + pipe := c.rdb.Pipeline() for _, msg := range msgs { s, err := msgprocessor.Pb2String(msg) if err != nil { - return 0, errs.Wrap(err) + return 0, errs.Wrap(err, "pb.marshal") } + key := c.getMessageCacheKey(conversationID, msg.Seq) - if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + _ = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second) + } + + results, err := pipe.Exec(ctx) + if err != nil { + return 0, errs.Wrap(err, "pipe.set") + } + + for _, res := range results { + if res.Err() != nil { return 0, errs.Wrap(err) } } + + return len(msgs), nil +} + +func (c *msgCache) ParallelSetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { + wg := errgroup.Group{} + wg.SetLimit(concurrentLimit) + + for _, msg := range msgs { + msg := msg // closure safe var + wg.Go(func() error { + s, err := msgprocessor.Pb2String(msg) + if err != nil { + return errs.Wrap(err) + } + + key := c.getMessageCacheKey(conversationID, msg.Seq) + if err := c.rdb.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err(); err != nil { + return errs.Wrap(err) + } + return nil + }) + } + + err := wg.Wait() + if err != nil { + return 0, err + } + return len(msgs), nil - //pipe := c.rdb.Pipeline() - //var failedMsgs []*sdkws.MsgData - //for _, msg := range msgs { - // key := c.getMessageCacheKey(conversationID, msg.Seq) - // s, err := msgprocessor.Pb2String(msg) - // if err != nil { - // return 0, errs.Wrap(err) - // } - // err = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err() - // if err != nil { - // failedMsgs = append(failedMsgs, msg) - // log.ZWarn(ctx, "set msg 2 cache failed", err, "msg", failedMsgs) - // } - //} - //_, err := pipe.Exec(ctx) - //return len(failedMsgs), err } func (c *msgCache) getMessageDelUserListKey(conversationID string, seq int64) string { diff --git a/pkg/common/db/cache/msg_test.go b/pkg/common/db/cache/msg_test.go new file mode 100644 index 0000000000..c5a4fb870c --- /dev/null +++ b/pkg/common/db/cache/msg_test.go @@ -0,0 +1,251 @@ +package cache + +import ( + "context" + "fmt" + "math/rand" + "testing" + + "github.com/OpenIMSDK/protocol/sdkws" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +func TestParallelSetMessageToCache(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst = rand.Int63() + msgs = []*sdkws.MsgData{} + ) + + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + }) + } + + testParallelSetMessageToCache(t, cid, msgs) +} + +func testParallelSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + ret, err := cacher.ParallelSetMessageToCache(context.Background(), cid, msgs) + assert.Nil(t, err) + assert.Equal(t, len(msgs), ret) + + // validate + for _, msg := range msgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val, err := rdb.Exists(context.Background(), key).Result() + assert.Nil(t, err) + assert.EqualValues(t, 1, val) + } +} + +func TestPipeSetMessageToCache(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst = rand.Int63() + msgs = []*sdkws.MsgData{} + ) + + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + }) + } + + testPipeSetMessageToCache(t, cid, msgs) +} + +func testPipeSetMessageToCache(t *testing.T, cid string, msgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + ret, err := cacher.PipeSetMessageToCache(context.Background(), cid, msgs) + assert.Nil(t, err) + assert.Equal(t, len(msgs), ret) + + // validate + for _, msg := range msgs { + key := cacher.getMessageCacheKey(cid, msg.Seq) + val, err := rdb.Exists(context.Background(), key).Result() + assert.Nil(t, err) + assert.EqualValues(t, 1, val) + } +} + +func TestGetMessagesBySeq(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst = rand.Int63() + msgs = []*sdkws.MsgData{} + ) + + seqs := []int64{} + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + SendID: fmt.Sprintf("fake-sendid-%v", i), + }) + seqs = append(seqs, seqFirst+int64(i)) + } + + // set data to cache + testPipeSetMessageToCache(t, cid, msgs) + + // get data from cache with parallet mode + testParallelGetMessagesBySeq(t, cid, seqs, msgs) + + // get data from cache with pipeline mode + testPipeGetMessagesBySeq(t, cid, seqs, msgs) +} + +func testParallelGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) + assert.Nil(t, err) + assert.Equal(t, 0, len(failedSeqs)) + assert.Equal(t, len(respMsgs), len(seqs)) + + // validate + for idx, msg := range respMsgs { + assert.Equal(t, msg.Seq, inputMsgs[idx].Seq) + assert.Equal(t, msg.SendID, inputMsgs[idx].SendID) + } +} + +func testPipeGetMessagesBySeq(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) + assert.Nil(t, err) + assert.Equal(t, 0, len(failedSeqs)) + assert.Equal(t, len(respMsgs), len(seqs)) + + // validate + for idx, msg := range respMsgs { + assert.Equal(t, msg.Seq, inputMsgs[idx].Seq) + assert.Equal(t, msg.SendID, inputMsgs[idx].SendID) + } +} + +func TestGetMessagesBySeqWithEmptySeqs(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst int64 = 0 + msgs = []*sdkws.MsgData{} + ) + + seqs := []int64{} + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + SendID: fmt.Sprintf("fake-sendid-%v", i), + }) + seqs = append(seqs, seqFirst+int64(i)) + } + + // don't set cache, only get data from cache. + + // get data from cache with parallet mode + testParallelGetMessagesBySeqWithEmptry(t, cid, seqs, msgs) + + // get data from cache with pipeline mode + testPipeGetMessagesBySeqWithEmptry(t, cid, seqs, msgs) +} + +func testParallelGetMessagesBySeqWithEmptry(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) + assert.Nil(t, err) + assert.Equal(t, len(seqs), len(failedSeqs)) + assert.Equal(t, 0, len(respMsgs)) +} + +func testPipeGetMessagesBySeqWithEmptry(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) + assert.Equal(t, err, redis.Nil) + assert.Equal(t, len(seqs), len(failedSeqs)) + assert.Equal(t, 0, len(respMsgs)) +} + +func TestGetMessagesBySeqWithLostHalfSeqs(t *testing.T) { + var ( + cid = fmt.Sprintf("cid-%v", rand.Int63()) + seqFirst int64 = 0 + msgs = []*sdkws.MsgData{} + ) + + seqs := []int64{} + for i := 0; i < 100; i++ { + msgs = append(msgs, &sdkws.MsgData{ + Seq: seqFirst + int64(i), + SendID: fmt.Sprintf("fake-sendid-%v", i), + }) + seqs = append(seqs, seqFirst+int64(i)) + } + + // Only set half the number of messages. + testParallelSetMessageToCache(t, cid, msgs[:50]) + + // get data from cache with parallet mode + testParallelGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs) + + // get data from cache with pipeline mode + testPipeGetMessagesBySeqWithLostHalfSeqs(t, cid, seqs, msgs) +} + +func testParallelGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.ParallelGetMessagesBySeq(context.Background(), cid, seqs) + assert.Nil(t, err) + assert.Equal(t, len(seqs)/2, len(failedSeqs)) + assert.Equal(t, len(seqs)/2, len(respMsgs)) + + for idx, msg := range respMsgs { + assert.Equal(t, msg.Seq, seqs[idx]) + } +} + +func testPipeGetMessagesBySeqWithLostHalfSeqs(t *testing.T, cid string, seqs []int64, inputMsgs []*sdkws.MsgData) { + rdb := redis.NewClient(&redis.Options{}) + defer rdb.Close() + + cacher := msgCache{rdb: rdb} + + respMsgs, failedSeqs, err := cacher.PipeGetMessagesBySeq(context.Background(), cid, seqs) + assert.Nil(t, err) + assert.Equal(t, len(seqs)/2, len(failedSeqs)) + assert.Equal(t, len(seqs)/2, len(respMsgs)) + + for idx, msg := range respMsgs { + assert.Equal(t, msg.Seq, seqs[idx]) + } +}