Skip to content

Commit

Permalink
Merge pull request #2823 from dolthub/aaron/session-command-beginend
Browse files Browse the repository at this point in the history
[no-release-notes] sql/session.go: Add the ability for integrators to receive some lifecycle callbacks.
  • Loading branch information
reltuk authored Jan 22, 2025
2 parents c5d0e52 + 63e1070 commit 0e32f1d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ test-server

# OSX Files
.DS_Store

.dir-locals.el
*~
14 changes: 14 additions & 0 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error

session.SetConnectionId(conn.ConnectionID)

if cur, ok := s.sessions[conn.ConnectionID]; ok {
sql.SessionEnd(cur)
}

s.sessions[conn.ConnectionID] = session

logger := session.GetLogger()
Expand All @@ -127,6 +131,13 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
return err
}

err = sql.SessionCommandBegin(sess)
if err != nil {
sql.SessionEnd(sess)
return err
}
defer sql.SessionCommandEnd(sess)

ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
var db sql.Database
if dbName != "" {
Expand Down Expand Up @@ -257,6 +268,9 @@ func (s *SessionManager) KillConnection(connID uint32) error {
func (s *SessionManager) RemoveConn(conn *mysql.Conn) {
s.mu.Lock()
defer s.mu.Unlock()
if cur, ok := s.sessions[conn.ConnectionID]; ok {
sql.SessionEnd(cur)
}
delete(s.sessions, conn.ConnectionID)
delete(s.connections, conn.ConnectionID)
s.processlist.RemoveConnection(conn.ConnectionID)
Expand Down
22 changes: 22 additions & 0 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
if err != nil {
return nil, err
}
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, err
}
defer sql.SessionCommandEnd(sqlCtx.Session)

var analyzed sql.Node
if analyzer.PreparedStmtDisabled {
analyzed, err = h.e.AnalyzeQuery(sqlCtx, query)
Expand Down Expand Up @@ -161,6 +167,12 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
return nil, nil, err
}

err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
}
defer sql.SessionCommandEnd(sqlCtx.Session)

analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed)
if err != nil {
logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
Expand Down Expand Up @@ -189,6 +201,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars
if err != nil {
return nil, nil, err
}
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
}
defer sql.SessionCommandEnd(sqlCtx.Session)

stmt, ok := parsedQuery.(sqlparser.Statement)
if !ok {
Expand Down Expand Up @@ -378,6 +395,11 @@ func (h *Handler) doQuery(
if err != nil {
return "", err
}
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return "", err
}
defer sql.SessionCommandEnd(sqlCtx.Session)

start := time.Now()

Expand Down
33 changes: 33 additions & 0 deletions sql/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,17 @@ type TransactionSession interface {
ReleaseSavepoint(ctx *Context, transaction Transaction, name string) error
}

// A LifecycleAwareSession is a a sql.Session that gets lifecycle callbacks
// from the handler when it begins and ends a command and when it itself ends.
//
// This is an optional interface which integrators can choose to implement
// if they want those callbacks.
type LifecycleAwareSession interface {
CommandBegin() error
CommandEnd()
SessionEnd()
}

type (
// TypedValue is a value along with its type.
TypedValue struct {
Expand Down Expand Up @@ -704,3 +715,25 @@ const (
VersionStable
VersionExperimental
)

// Helper function to call CommandBegin on a LifecycleAwareSession, or do nothing.
func SessionCommandBegin(s Session) error {
if cur, ok := s.(LifecycleAwareSession); ok {
return cur.CommandBegin()
}
return nil
}

// Helper function to call CommandEnd on a LifecycleAwareSession, or do nothing.
func SessionCommandEnd(s Session) {
if cur, ok := s.(LifecycleAwareSession); ok {
cur.CommandEnd()
}
}

// Helper function to call SessionEnd on a LifecycleAwareSession, or do nothing.
func SessionEnd(s Session) {
if cur, ok := s.(LifecycleAwareSession); ok {
cur.SessionEnd()
}
}

0 comments on commit 0e32f1d

Please sign in to comment.