Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests for connection pooling #395

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 165 additions & 48 deletions client/gosqldriver/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package gosqldriver

import (
"context"
"database/sql/driver"
"errors"
"fmt"
Expand All @@ -32,15 +33,46 @@ import (

var corrIDUnsetCmd = netstring.NewNetstringFrom(common.CmdClientCalCorrelationID, []byte("CorrId=NotSet"))

type heraConnectionInterface interface {
Prepare(query string) (driver.Stmt, error)
Close() error
Begin() (driver.Tx, error)
exec(cmd int, payload []byte) error
execNs(ns *netstring.Netstring) error
getResponse() (*netstring.Netstring, error)
SetShardID(shard int) error
ResetShardID() error
GetNumShards() (int, error)
SetShardKeyPayload(payload string)
ResetShardKeyPayload()
SetCalCorrID(corrID string)
SetClientInfo(poolName string, host string) error
SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error
getID() string
getCorrID() *netstring.Netstring
getShardKeyPayload() []byte
setCorrID(*netstring.Netstring)
startWatcher()
finish()
cancel(err error)
watchCancel(ctx context.Context) error
}

type heraConnection struct {
id string // used for logging
conn net.Conn
reader *netstring.Reader
// for the sharding extension
shardKeyPayload []byte
// correlation id
corrID *netstring.Netstring
corrID *netstring.Netstring
clientinfo *netstring.Netstring

// Context support
watching bool
watcher chan<- context.Context
finished chan<- struct{}
closech chan struct{}
}

// NewHeraConnection creates a structure implementing a driver.Con interface
Expand All @@ -49,9 +81,78 @@ func NewHeraConnection(conn net.Conn) driver.Conn {
if logger.GetLogger().V(logger.Info) {
logger.GetLogger().Log(logger.Info, hera.id, "create driver connection")
}

hera.startWatcher()

return hera
}

func (c *heraConnection) startWatcher() {
watcher := make(chan context.Context, 1)
c.watcher = watcher
finished := make(chan struct{})
c.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-c.closech:
return
}

select {
case <-ctx.Done():
c.cancel(ctx.Err())
case <-finished:
case <-c.closech:
return
}
}
}()
}

func (c *heraConnection) finish() {
if !c.watching || c.finished == nil {
return
}
select {
case c.finished <- struct{}{}:
c.watching = false
case <-c.closech:
}
}

func (c *heraConnection) cancel(err error) {
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, c.id, "ctx error:", err)
}
c.Close()
}

func (c *heraConnection) watchCancel(ctx context.Context) error {
if c.watching {
// Reach here if canceled, the connection is already invalid
c.Close()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}

if c.watcher == nil {
return nil
}

c.watching = true
c.watcher <- ctx
return nil
}

// Prepare returns a prepared statement, bound to this connection.
func (c *heraConnection) Prepare(query string) (driver.Stmt, error) {
Expand Down Expand Up @@ -177,66 +278,82 @@ func (c *heraConnection) SetCalCorrID(corrID string) {
}

// SetClientInfo actually sends it over to Hera server
func (c *heraConnection) SetClientInfo(poolName string, host string)(error){
func (c *heraConnection) SetClientInfo(poolName string, host string) error {
if len(poolName) <= 0 && len(host) <= 0 {
return nil
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, Command: SetClientInfo,", pid, host, poolName)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}

func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string)(error){
func (c *heraConnection) SetClientInfoWithPoolStack(poolName string, host string, poolStack string) error {
if len(poolName) <= 0 && len(host) <= 0 && len(poolStack) <= 0 {
return nil
}

pid := os.Getpid()
data := fmt.Sprintf("PID: %d, HOST: %s, Poolname: %s, PoolStack: %s, Command: SetClientInfo,", pid, host, poolName, poolStack)
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
c.clientinfo = netstring.NewNetstringFrom(common.CmdClientInfo, []byte(string(data)))
if logger.GetLogger().V(logger.Verbose) {
logger.GetLogger().Log(logger.Verbose, "SetClientInfo", c.clientinfo.Serialized)
}

_, err := c.conn.Write(c.clientinfo.Serialized)
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to send client info")
}
return errors.New("Failed custom auth, failed to send client info")
}
ns, err := c.reader.ReadNext()
if err != nil {
if logger.GetLogger().V(logger.Warning) {
logger.GetLogger().Log(logger.Warning, "Failed to read server info")
}
return errors.New("Failed to read server info")
}
if logger.GetLogger().V(logger.Debug) {
logger.GetLogger().Log(logger.Debug, "Server info:", string(ns.Payload))
}
return nil
}
}

func (c *heraConnection) getID() string {
return c.id
}

func (c *heraConnection) getCorrID() *netstring.Netstring {
return c.corrID
}

func (c *heraConnection) getShardKeyPayload() []byte {
return c.shardKeyPayload
}

func (c *heraConnection) setCorrID(newCorrID *netstring.Netstring) {
c.corrID = newCorrID
}
Loading