diff --git a/pkg/common/db/mgo/msg.go b/pkg/common/db/mgo/msg.go index 6fe24536bd..17e493d336 100644 --- a/pkg/common/db/mgo/msg.go +++ b/pkg/common/db/mgo/msg.go @@ -267,58 +267,80 @@ func (m *MsgMgo) MarkSingleChatMsgsAsRead(ctx context.Context, userID string, do } func (m *MsgMgo) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) (int32, []*relation.MsgInfoModel, error) { - var pipe mongo.Pipeline - condition := bson.A{} - if req.SendTime != "" { - // Changed to keyed fields for bson.M to avoid govet errors - condition = append(condition, bson.M{"$eq": bson.A{bson.M{"$dateToString": bson.M{"format": "%Y-%m-%d", "date": bson.M{"$toDate": "$$item.msg.send_time"}}}, req.SendTime}}) + where := make(bson.A, 0, 6) + if req.RecvID != "" { + where = append(where, bson.M{"msgs.msg.recv_id": req.RecvID}) + } + if req.SendID != "" { + where = append(where, bson.M{"msgs.msg.send_id": req.SendID}) } if req.ContentType != 0 { - condition = append(condition, bson.M{"$eq": bson.A{"$$item.msg.content_type", req.ContentType}}) + where = append(where, bson.M{"msgs.msg.content_type": req.ContentType}) } if req.SessionType != 0 { - condition = append(condition, bson.M{"$eq": bson.A{"$$item.msg.session_type", req.SessionType}}) + where = append(where, bson.M{"msgs.msg.session_type": req.SessionType}) } - if req.RecvID != "" { - condition = append(condition, bson.M{"$regexFind": bson.M{"input": "$$item.msg.recv_id", "regex": req.RecvID}}) + if req.SendTime != "" { + sendTime, err := time.Parse(time.DateOnly, req.SendTime) + if err != nil { + return 0, nil, errs.ErrArgs.WrapMsg("invalid sendTime", "req", req.SendTime, "format", time.DateOnly, "cause", err.Error()) + } + where = append(where, + bson.M{ + "msgs.msg.send_time": bson.M{ + "$gte": sendTime.UnixMilli(), + }, + }, + bson.M{ + "msgs.msg.send_time": bson.M{ + "$lt": sendTime.Add(time.Hour * 24).UnixMilli(), + }, + }, + ) } - if req.SendID != "" { - condition = append(condition, bson.M{"$regexFind": bson.M{"input": "$$item.msg.send_id", "regex": req.SendID}}) + pipeline := bson.A{ + bson.M{ + "$unwind": "$msgs", + }, } - - or := bson.A{ - bson.M{"doc_id": bson.M{"$regex": "^si_", "$options": "i"}}, - bson.M{"doc_id": bson.M{"$regex": "^g_", "$options": "i"}}, - bson.M{"doc_id": bson.M{"$regex": "^sg_", "$options": "i"}}, + if len(where) > 0 { + pipeline = append(pipeline, bson.M{ + "$match": bson.M{"$and": where}, + }) } - - // Use bson.D with keyed fields to specify the order explicitly - pipe = mongo.Pipeline{ - {{"$match", bson.D{{Key: "$or", Value: or}}}}, - {{"$project", bson.D{ - {Key: "msgs", Value: bson.D{ - {Key: "$filter", Value: bson.D{ - {Key: "input", Value: "$msgs"}, - {Key: "as", Value: "item"}, - {Key: "cond", Value: bson.D{{Key: "$and", Value: condition}}}, - }}, - }}, - {Key: "doc_id", Value: 1}, - }}}, - {{"$unwind", bson.M{"path": "$msgs"}}}, - {{"$sort", bson.M{"msgs.msg.send_time": -1}}}, + pipeline = append(pipeline, + bson.M{ + "$project": bson.M{ + "_id": 0, + "msg": "$msgs.msg", + }, + }, + bson.M{ + "$count": "count", + }, + ) + count, err := mongoutil.Aggregate[int32](ctx, m.coll, pipeline) + if err != nil { + return 0, nil, err } - type docModel struct { - DocID string `bson:"doc_id"` - Msg *relation.MsgInfoModel `bson:"msgs"` + if len(count) == 0 || count[0] == 0 { + return 0, nil, nil } - msgsDocs, err := mongoutil.Aggregate[*docModel](ctx, m.coll, pipe) + pipeline = pipeline[:len(pipeline)-1] + pipeline = append(pipeline, + bson.M{ + "$skip": (req.Pagination.GetPageNumber() - 1) * req.Pagination.GetShowNumber(), + }, + bson.M{ + "$limit": req.Pagination.GetShowNumber(), + }, + ) + msgs, err := mongoutil.Aggregate[*relation.MsgInfoModel](ctx, m.coll, pipeline) if err != nil { return 0, nil, err } - msgs := make([]*relation.MsgInfoModel, 0) - for _, doc := range msgsDocs { - msgInfo := doc.Msg + for i := range msgs { + msgInfo := msgs[i] if msgInfo == nil || msgInfo.Msg == nil { continue } @@ -350,17 +372,17 @@ func (m *MsgMgo) SearchMessage(ctx context.Context, req *msg.SearchMessageReq) ( } msgs = append(msgs, msgInfo) } - start := (req.Pagination.PageNumber - 1) * req.Pagination.ShowNumber - n := int32(len(msgs)) - if start >= n { - return n, []*relation.MsgInfoModel{}, nil - } - if start+req.Pagination.ShowNumber < n { - msgs = msgs[start : start+req.Pagination.ShowNumber] - } else { - msgs = msgs[start:] - } - return n, msgs, nil + //start := (req.Pagination.PageNumber - 1) * req.Pagination.ShowNumber + //n := int32(len(msgs)) + //if start >= n { + // return n, []*relation.MsgInfoModel{}, nil + //} + //if start+req.Pagination.ShowNumber < n { + // msgs = msgs[start : start+req.Pagination.ShowNumber] + //} else { + // msgs = msgs[start:] + //} + return count[0], msgs, nil } func (m *MsgMgo) RangeUserSendCount(ctx context.Context, start time.Time, end time.Time, group bool, ase bool, pageNumber int32, showNumber int32) (msgCount int64, userCount int64, users []*relation.UserCount, dateCount map[string]int64, err error) {