From 9cb81ce7c7c17fbdb31da047a04b28025e9cae9e Mon Sep 17 00:00:00 2001 From: Rajesh S Date: Thu, 21 Mar 2024 17:28:48 +0530 Subject: [PATCH] adding context timeouts for management queries --- lib/racmaint.go | 27 ++++++++++++++---- lib/shardingcfg.go | 69 +++++++++++++++++++++++++--------------------- 2 files changed, 59 insertions(+), 37 deletions(-) diff --git a/lib/racmaint.go b/lib/racmaint.go index 27c3bbea..2aef6912 100644 --- a/lib/racmaint.go +++ b/lib/racmaint.go @@ -102,9 +102,20 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { binds[0], err = os.Hostname() binds[0] = strings.ToUpper(binds[0]) binds[1] = strings.ToUpper(cmdLineModuleName) // */ + waitTime := time.Second * time.Duration(interval) + //First time data loading + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + + timeTicker := time.NewTicker(waitTime) for { - racMaint(ctx, shard, db, racSQL, cmdLineModuleName, prev) - time.Sleep(time.Second * time.Duration(interval)) + select { + case <-ctx.Done(): + logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from racmaint data reload.") + return + case <-timeTicker.C: + //Periodic data loading + racMaint(&ctx, shard, db, racSQL, cmdLineModuleName, prev, waitTime/2) + } } } @@ -112,14 +123,18 @@ func racMaintMain(shard int, interval int, cmdLineModuleName string) { racMaint is the main function for RAC maintenance processing, being called regularly. When maintenance is planned, it calls workerpool.RacMaint to start the actuall processing */ -func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg) { +func racMaint(ctx *context.Context, shard int, db *sql.DB, racSQL string, cmdLineModuleName string, prev map[racCfgKey]racCfg, queryTimeout time.Duration) { // // print this log for unittesting // if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Rac maint check, shard =", shard) } - conn, err := db.Conn(ctx) + //create cancellable context + queryContext, cancel := context.WithTimeout(*ctx, queryTimeout) + defer cancel() // Always call cancel to release resources associated with the context + + conn, err := db.Conn(queryContext) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (conn) rac maint for shard =", shard, ",err :", err) @@ -127,7 +142,7 @@ func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLine return } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, racSQL) + stmt, err := conn.PrepareContext(queryContext, racSQL) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (stmt) rac maint for shard =", shard, ",err :", err) @@ -139,7 +154,7 @@ func racMaint(ctx context.Context, shard int, db *sql.DB, racSQL string, cmdLine hostname = strings.ToUpper(hostname) module := strings.ToUpper(cmdLineModuleName) module_taf := fmt.Sprintf("%s_TAF", module) - rows, err := stmt.QueryContext(ctx, hostname, module_taf, module) + rows, err := stmt.QueryContext(queryContext, hostname, module_taf, module) if err != nil { if logger.GetLogger().V(logger.Info) { logger.GetLogger().Log(logger.Info, "Error (query) rac maint for shard =", shard, ",err :", err) diff --git a/lib/shardingcfg.go b/lib/shardingcfg.go index c6dac50c..995f53e0 100644 --- a/lib/shardingcfg.go +++ b/lib/shardingcfg.go @@ -100,7 +100,7 @@ func getSQL() string { /* load the physical to logical maping */ -func loadMap(ctx context.Context, db *sql.DB) error { +func loadMap(ctx *context.Context, db *sql.DB, queryTimeoutInterval time.Duration) error { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading shard map") } @@ -109,17 +109,18 @@ func loadMap(ctx context.Context, db *sql.DB) error { logger.GetLogger().Log(logger.Verbose, "Done loading shard map") }() } - - conn, err := db.Conn(ctx) + queryContext, cancel := context.WithTimeout(*ctx, queryTimeoutInterval) + defer cancel() + conn, err := db.Conn(queryContext) if err != nil { return fmt.Errorf("Error (conn) loading shard map: %s", err.Error()) } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, getSQL()) + stmt, err := conn.PrepareContext(queryContext, getSQL()) if err != nil { return fmt.Errorf("Error (stmt) loading shard map: %s", err.Error()) } - rows, err := stmt.QueryContext(ctx) + rows, err := stmt.QueryContext(queryContext) if err != nil { return fmt.Errorf("Error (query) loading shard map: %s", err.Error()) } @@ -216,7 +217,7 @@ func getWLSQL() string { /* load the whitelist mapping */ -func loadWhitelist(ctx context.Context, db *sql.DB) { +func loadWhitelist(ctx *context.Context, db *sql.DB, timeout time.Duration) { if logger.GetLogger().V(logger.Verbose) { logger.GetLogger().Log(logger.Verbose, "Begin loading whitelist") } @@ -225,19 +226,20 @@ func loadWhitelist(ctx context.Context, db *sql.DB) { logger.GetLogger().Log(logger.Verbose, "Done loading whitelist") }() } - - conn, err := db.Conn(ctx) + queryContext, cancel := context.WithTimeout(*ctx, timeout) + defer cancel() + conn, err := db.Conn(queryContext) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (conn) loading whitelist:", err) return } defer conn.Close() - stmt, err := conn.PrepareContext(ctx, getWLSQL()) + stmt, err := conn.PrepareContext(queryContext, getWLSQL()) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (stmt) loading whitelist:", err) return } - rows, err := stmt.QueryContext(ctx) + rows, err := stmt.QueryContext(queryContext) if err != nil { logger.GetLogger().Log(logger.Alert, "Error (query) loading whitelist:", err) return @@ -291,7 +293,10 @@ func InitShardingCfg() error { ctx := context.Background() var db *sql.DB var err error - + reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval) + if reloadInterval < 100*time.Millisecond { + reloadInterval = 100 * time.Millisecond + } i := 0 for ; i < 60; i++ { for shard := 0; shard < GetConfig().NumOfShards; shard++ { @@ -300,7 +305,7 @@ func InitShardingCfg() error { } db, err = openDb(shard) if err == nil { - err = loadMap(ctx, db) + err = loadMap(&ctx, db, reloadInterval/2) if err == nil { break } @@ -319,32 +324,34 @@ func InitShardingCfg() error { return errors.New("Failed to load shard map, no more retry") } if GetConfig().EnableWhitelistTest { - loadWhitelist(ctx, db) + loadWhitelist(&ctx, db, reloadInterval/2) } go func() { + reloadTimer := time.NewTimer(reloadInterval) //Periodic reload timer for { - reloadInterval := time.Second * time.Duration(GetConfig().ShardingCfgReloadInterval) - if reloadInterval < 100 * time.Millisecond { - reloadInterval = 100 * time.Millisecond - } - time.Sleep(reloadInterval) - for shard := 0; shard < GetConfig().NumOfShards; shard++ { - if db != nil { - db.Close() - } - db, err = openDb(shard) - if err == nil { - err = loadMap(ctx, db) + select { + case <-ctx.Done(): + logger.GetLogger().Log(logger.Alert, "Application main context has been closed, so exiting from shard-config data reload.") + return + case <-reloadTimer.C: + for shard := 0; shard < GetConfig().NumOfShards; shard++ { + if db != nil { + db.Close() + } + db, err = openDb(shard) if err == nil { - if shard == 0 && GetConfig().EnableWhitelistTest { - loadWhitelist(ctx, db) + err = loadMap(&ctx, db, reloadInterval/2) + if err == nil { + if shard == 0 && GetConfig().EnableWhitelistTest { + loadWhitelist(&ctx, db, reloadInterval/2) + } + break } - break } + logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) + evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") + evt.Completed() } - logger.GetLogger().Log(logger.Warning, "Error <", err, "> loading the shard map from shard", shard) - evt := cal.NewCalEvent(cal.EventTypeError, "no_shard_map", cal.TransOK, "Error loading shard map") - evt.Completed() } } }()