Skip to content

Commit

Permalink
support tstzrange
Browse files Browse the repository at this point in the history
  • Loading branch information
Amogh-Bharadwaj committed May 7, 2024
1 parent 9967be5 commit 06297c9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
44 changes: 44 additions & 0 deletions flow/connectors/postgres/qvalue_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ func (c *PostgresConnector) postgresOIDToQValueKind(recvOID uint32) qvalue.QValu
return qvalue.QValueKindArrayString
case pgtype.IntervalOID:
return qvalue.QValueKindInterval
case pgtype.TstzrangeOID:
return qvalue.QValueKindTSTZRange
default:
typeName, ok := pgtype.NewMap().TypeForOID(recvOID)
if !ok {
Expand Down Expand Up @@ -273,6 +275,31 @@ func parseFieldFromQValueKind(qvalueKind qvalue.QValueKind, value interface{}) (
}

return qvalue.QValueString{Val: string(intervalJSON)}, nil
case qvalue.QValueKindTSTZRange:
tstzrangeObject := value.(pgtype.Range[interface{}])
lowerBoundType := tstzrangeObject.LowerType
upperBoundType := tstzrangeObject.UpperType
lowerTime, err := ConvertTimeRangeBounds(tstzrangeObject.Lower)
if err != nil {
return nil, fmt.Errorf("[tstzrange]error for lower time bound: %v", err)
}

upperTime, err := ConvertTimeRangeBounds(tstzrangeObject.Upper)
if err != nil {
return nil, fmt.Errorf("[tstzrange]error for upper time bound: %v", err)
}

lowerBracket := "["
if lowerBoundType == pgtype.Exclusive {
lowerBracket = "("
}
upperBracket := "]"
if upperBoundType == pgtype.Exclusive {
upperBracket = ")"
}
tstzrangeStr := fmt.Sprintf("%s%v,%v%s",
lowerBracket, lowerTime, upperTime, upperBracket)
return qvalue.QValueTSTZRange{Val: tstzrangeStr}, nil
case qvalue.QValueKindDate:
date := value.(time.Time)
return qvalue.QValueDate{Val: date}, nil
Expand Down Expand Up @@ -481,3 +508,20 @@ func customTypeToQKind(typeName string) qvalue.QValueKind {
return qvalue.QValueKindString
}
}

func ConvertTimeRangeBounds(timeBound interface{}) (string, error) {
layout := "2006-01-02 15:04:05 -0700 MST"
postgresFormat := "2006-01-02 15:04:05"
var convertedTime string
if timeBound != nil {
lowerParsed, err := time.Parse(layout, fmt.Sprint(timeBound))
if err != nil {
return "", fmt.Errorf("Unexpected lower bound value in tstzrange. Error: %v", err)
}
convertedTime = lowerParsed.Format(postgresFormat)
} else {
convertedTime = ""
}

return convertedTime, nil
}
2 changes: 2 additions & 0 deletions flow/model/qrecord_batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ func (src *QRecordBatchCopyFromSource) Values() ([]interface{}, error) {
values[i] = pgtype.Timestamp{Time: v.Val, Valid: true}
case qvalue.QValueTimestampTZ:
values[i] = pgtype.Timestamptz{Time: v.Val, Valid: true}
case qvalue.QValueTSTZRange:
values[i] = v.Val
case qvalue.QValueUUID:
values[i] = uuid.UUID(v.Val)
case qvalue.QValueNumeric:
Expand Down
1 change: 1 addition & 0 deletions flow/model/qvalue/kind.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ const (
QValueKindTime QValueKind = "time"
QValueKindTimeTZ QValueKind = "timetz"
QValueKindInterval QValueKind = "interval"
QValueKindTSTZRange QValueKind = "tstzrange"
QValueKindNumeric QValueKind = "numeric"
QValueKindBytes QValueKind = "bytes"
QValueKindUUID QValueKind = "uuid"
Expand Down
18 changes: 17 additions & 1 deletion flow/model/qvalue/qvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/yuin/gopher-lua"
lua "github.com/yuin/gopher-lua"

"github.com/PeerDB-io/glua64"
"github.com/PeerDB-io/peer-flow/shared"
Expand Down Expand Up @@ -294,6 +294,22 @@ func (v QValueInterval) LValue(ls *lua.LState) lua.LValue {
return lua.LString(v.Val)
}

type QValueTSTZRange struct {
Val string
}

func (QValueTSTZRange) Kind() QValueKind {
return QValueKindInterval
}

func (v QValueTSTZRange) Value() any {
return v.Val
}

func (v QValueTSTZRange) LValue(ls *lua.LState) lua.LValue {
return lua.LString(v.Val)
}

type QValueNumeric struct {
Val decimal.Decimal
}
Expand Down

0 comments on commit 06297c9

Please sign in to comment.