From 06297c9c48fc63b980eb17ec8d7214d2387c54f4 Mon Sep 17 00:00:00 2001 From: Amogh-Bharadwaj Date: Wed, 8 May 2024 01:06:55 +0530 Subject: [PATCH] support tstzrange --- flow/connectors/postgres/qvalue_convert.go | 44 ++++++++++++++++++++++ flow/model/qrecord_batch.go | 2 + flow/model/qvalue/kind.go | 1 + flow/model/qvalue/qvalue.go | 18 ++++++++- 4 files changed, 64 insertions(+), 1 deletion(-) diff --git a/flow/connectors/postgres/qvalue_convert.go b/flow/connectors/postgres/qvalue_convert.go index ef12fd61c4..05285297c1 100644 --- a/flow/connectors/postgres/qvalue_convert.go +++ b/flow/connectors/postgres/qvalue_convert.go @@ -108,6 +108,8 @@ func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValu return qvalue.QValueKindArrayString case pgtype.IntervalOID: return qvalue.QValueKindInterval + case pgtype.TstzrangeOID: + return qvalue.QValueKindTSTZRange default: typeName, ok := pgtype.NewMap().TypeForOID(recvOID) if !ok { @@ -273,6 +275,31 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) ( } return qvalue.QValueString{Val: string(intervalJSON)}, nil + case qvalue.QValueKindTSTZRange: + tstzrangeObject := value.(pgtype.Range[interface{}]) + lowerBoundType := tstzrangeObject.LowerType + upperBoundType := tstzrangeObject.UpperType + lowerTime, err := ConvertTimeRangeBounds(tstzrangeObject.Lower) + if err != nil { + return nil, fmt.Errorf("[tstzrange]error for lower time bound: %v", err) + } + + upperTime, err := ConvertTimeRangeBounds(tstzrangeObject.Upper) + if err != nil { + return nil, fmt.Errorf("[tstzrange]error for upper time bound: %v", err) + } + + lowerBracket := "[" + if lowerBoundType == pgtype.Exclusive { + lowerBracket = "(" + } + upperBracket := "]" + if upperBoundType == pgtype.Exclusive { + upperBracket = ")" + } + tstzrangeStr := fmt.Sprintf("%s%v,%v%s", + lowerBracket, lowerTime, upperTime, upperBracket) + return qvalue.QValueTSTZRange{Val: tstzrangeStr}, nil case qvalue.QValueKindDate: date := value.(time.Time) return qvalue.QValueDate{Val: date}, nil @@ -481,3 +508,20 @@ func customTypeToQKind(typeName string) qvalue.QValueKind { return qvalue.QValueKindString } } + +func ConvertTimeRangeBounds(timeBound interface{}) (string, error) { + layout := "2006-01-02 15:04:05 -0700 MST" + postgresFormat := "2006-01-02 15:04:05" + var convertedTime string + if timeBound != nil { + lowerParsed, err := time.Parse(layout, fmt.Sprint(timeBound)) + if err != nil { + return "", fmt.Errorf("Unexpected lower bound value in tstzrange. Error: %v", err) + } + convertedTime = lowerParsed.Format(postgresFormat) + } else { + convertedTime = "" + } + + return convertedTime, nil +} diff --git a/flow/model/qrecord_batch.go b/flow/model/qrecord_batch.go index 1f787b5c1c..3b76c363d0 100644 --- a/flow/model/qrecord_batch.go +++ b/flow/model/qrecord_batch.go @@ -118,6 +118,8 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) { values[i] = pgtype.Timestamp{Time: v.Val, Valid: true} case qvalue.QValueTimestampTZ: values[i] = pgtype.Timestamptz{Time: v.Val, Valid: true} + case qvalue.QValueTSTZRange: + values[i] = v.Val case qvalue.QValueUUID: values[i] = uuid.UUID(v.Val) case qvalue.QValueNumeric: diff --git a/flow/model/qvalue/kind.go b/flow/model/qvalue/kind.go index fa0a3c2235..14233338f5 100644 --- a/flow/model/qvalue/kind.go +++ b/flow/model/qvalue/kind.go @@ -26,6 +26,7 @@ const ( QValueKindTime QValueKind = "time" QValueKindTimeTZ QValueKind = "timetz" QValueKindInterval QValueKind = "interval" + QValueKindTSTZRange QValueKind = "tstzrange" QValueKindNumeric QValueKind = "numeric" QValueKindBytes QValueKind = "bytes" QValueKindUUID QValueKind = "uuid" diff --git a/flow/model/qvalue/qvalue.go b/flow/model/qvalue/qvalue.go index 91b9e3fe31..961028ceb9 100644 --- a/flow/model/qvalue/qvalue.go +++ b/flow/model/qvalue/qvalue.go @@ -6,7 +6,7 @@ import ( "github.com/google/uuid" "github.com/shopspring/decimal" - "github.com/yuin/gopher-lua" + lua "github.com/yuin/gopher-lua" "github.com/PeerDB-io/glua64" "github.com/PeerDB-io/peer-flow/shared" @@ -294,6 +294,22 @@ func (v QValueInterval) LValue(ls *lua.LState) lua.LValue { return lua.LString(v.Val) } +type QValueTSTZRange struct { + Val string +} + +func (QValueTSTZRange) Kind() QValueKind { + return QValueKindInterval +} + +func (v QValueTSTZRange) Value() any { + return v.Val +} + +func (v QValueTSTZRange) LValue(ls *lua.LState) lua.LValue { + return lua.LString(v.Val) +} + type QValueNumeric struct { Val decimal.Decimal }