From 63e1070a717054a706c9182c3854cbd00a5948d1 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Tue, 21 Jan 2025 16:29:00 -0800 Subject: [PATCH] sql/session.go: Add the ability for integrators to receive some lifecycle callbacks. --- .gitignore | 3 +++ server/context.go | 14 ++++++++++++++ server/handler.go | 22 ++++++++++++++++++++++ sql/session.go | 33 +++++++++++++++++++++++++++++++++ 4 files changed, 72 insertions(+) diff --git a/.gitignore b/.gitignore index 3b34d04130..4dc1365a0c 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,6 @@ test-server # OSX Files .DS_Store + +.dir-locals.el +*~ \ No newline at end of file diff --git a/server/context.go b/server/context.go index f56e205a47..13931f519f 100644 --- a/server/context.go +++ b/server/context.go @@ -103,6 +103,10 @@ func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error session.SetConnectionId(conn.ConnectionID) + if cur, ok := s.sessions[conn.ConnectionID]; ok { + sql.SessionEnd(cur) + } + s.sessions[conn.ConnectionID] = session logger := session.GetLogger() @@ -127,6 +131,13 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error { return err } + err = sql.SessionCommandBegin(sess) + if err != nil { + sql.SessionEnd(sess) + return err + } + defer sql.SessionCommandEnd(sess) + ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) var db sql.Database if dbName != "" { @@ -257,6 +268,9 @@ func (s *SessionManager) KillConnection(connID uint32) error { func (s *SessionManager) RemoveConn(conn *mysql.Conn) { s.mu.Lock() defer s.mu.Unlock() + if cur, ok := s.sessions[conn.ConnectionID]; ok { + sql.SessionEnd(cur) + } delete(s.sessions, conn.ConnectionID) delete(s.connections, conn.ConnectionID) s.processlist.RemoveConnection(conn.ConnectionID) diff --git a/server/handler.go b/server/handler.go index 6dc7625f4f..a09ea2772f 100644 --- a/server/handler.go +++ b/server/handler.go @@ -121,6 +121,12 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p if err != nil { return nil, err } + err = sql.SessionCommandBegin(sqlCtx.Session) + if err != nil { + return nil, err + } + defer sql.SessionCommandEnd(sqlCtx.Session) + var analyzed sql.Node if analyzer.PreparedStmtDisabled { analyzed, err = h.e.AnalyzeQuery(sqlCtx, query) @@ -161,6 +167,12 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str return nil, nil, err } + err = sql.SessionCommandBegin(sqlCtx.Session) + if err != nil { + return nil, nil, err + } + defer sql.SessionCommandEnd(sqlCtx.Session) + analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed) if err != nil { logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error()) @@ -189,6 +201,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars if err != nil { return nil, nil, err } + err = sql.SessionCommandBegin(sqlCtx.Session) + if err != nil { + return nil, nil, err + } + defer sql.SessionCommandEnd(sqlCtx.Session) stmt, ok := parsedQuery.(sqlparser.Statement) if !ok { @@ -378,6 +395,11 @@ func (h *Handler) doQuery( if err != nil { return "", err } + err = sql.SessionCommandBegin(sqlCtx.Session) + if err != nil { + return "", err + } + defer sql.SessionCommandEnd(sqlCtx.Session) start := time.Now() diff --git a/sql/session.go b/sql/session.go index ec73bea4c7..9a73c91930 100644 --- a/sql/session.go +++ b/sql/session.go @@ -207,6 +207,17 @@ type TransactionSession interface { ReleaseSavepoint(ctx *Context, transaction Transaction, name string) error } +// A LifecycleAwareSession is a a sql.Session that gets lifecycle callbacks +// from the handler when it begins and ends a command and when it itself ends. +// +// This is an optional interface which integrators can choose to implement +// if they want those callbacks. +type LifecycleAwareSession interface { + CommandBegin() error + CommandEnd() + SessionEnd() +} + type ( // TypedValue is a value along with its type. TypedValue struct { @@ -704,3 +715,25 @@ const ( VersionStable VersionExperimental ) + +// Helper function to call CommandBegin on a LifecycleAwareSession, or do nothing. +func SessionCommandBegin(s Session) error { + if cur, ok := s.(LifecycleAwareSession); ok { + return cur.CommandBegin() + } + return nil +} + +// Helper function to call CommandEnd on a LifecycleAwareSession, or do nothing. +func SessionCommandEnd(s Session) { + if cur, ok := s.(LifecycleAwareSession); ok { + cur.CommandEnd() + } +} + +// Helper function to call SessionEnd on a LifecycleAwareSession, or do nothing. +func SessionEnd(s Session) { + if cur, ok := s.(LifecycleAwareSession); ok { + cur.SessionEnd() + } +}