diff --git a/main.go b/main.go index 65f6d65..e3c4fed 100644 --- a/main.go +++ b/main.go @@ -10,6 +10,8 @@ import ( "net/http/httputil" "net/url" "os" + "os/signal" + "syscall" "time" ) @@ -53,14 +55,13 @@ func lb(w http.ResponseWriter, r *http.Request) { http.Error(w, "Service not available", http.StatusServiceUnavailable) } -// main function with arguement for configfile func main() { // get file name from argument arg := os.Args if len(arg) != 2 { log.Fatal("usage go run main.go '") } - + // declare slice for backend server backendservers := []string{} @@ -86,7 +87,7 @@ func main() { if err != nil { log.Println("Error parsing URL") } - proxy := httputil.NewSingleHostReverseProxy(be) //this is one of other backend servers need to pust in server pool + proxy := httputil.NewSingleHostReverseProxy(be) proxy.Director = func(r *http.Request) { if r.Body != nil { bodyBytes, _ := io.ReadAll(r.Body) @@ -95,7 +96,6 @@ func main() { r.Header.Set("User-Agent", "Your-User-Agent") r.Header.Set("Accept", "application/json") r.Header.Set("X-Custom-Header", "CustomValue") - // Adjust URL and Host r.URL.Scheme = be.Scheme r.URL.Host = be.Host r.Host = be.Host @@ -105,25 +105,22 @@ func main() { log.Printf("[%s] Request Canceled: %v\n", be.Host, r.Context().Err() == context.Canceled) log.Printf("[%s] %s\n", be.Host, e.Error()) - retries := GetRetryFromContext(r) // by default the retry count is 0 + retries := GetRetryFromContext(r) log.Println("This is the retry count", retries, "of the server", serverPool.Current) if retries < 3 { time.Sleep(10 * time.Millisecond) - ctx := context.WithValue(r.Context(), Retry, retries+1) // increment the retry count and set it in context + ctx := context.WithValue(r.Context(), Retry, retries+1) log.Println("check") proxy.ServeHTTP(w, r.WithContext(ctx)) return } - // if the retry count is more than 3 then mark the server as down log.Printf("[%s] Marking server as down\n", be.Host) serverPool.MarkDownTheServer(be, false) - // attempts := GetAttemptsFromContext(r) - // ctx := context.WithValue(r.Context(), Attempts, attempts+1) - lb(w, r) // this function will find the next alive server and redirect the request + lb(w, r) } serverPool.Backends = append(serverPool.Backends, &lib.ServerNode{ @@ -132,7 +129,7 @@ func main() { ReverseProxy: proxy, }) } - // http.HandleFunc("/", testHandler) + server := &http.Server{ Addr: ":8000", WriteTimeout: 15 * time.Second, @@ -140,9 +137,30 @@ func main() { IdleTimeout: 60 * time.Second, Handler: http.HandlerFunc(lb), } - log.Println("Server is starting on port 8000") - if err := server.ListenAndServe(); err != nil { - log.Fatalf("Server failed: %v", err) + + // Channel to listen for interrupt or termination signals + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + // Start the server in a goroutine + go func() { + log.Println("Server is starting on port 8000") + if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Server failed: %v", err) + } + }() + + // Wait for termination signal + <-shutdown + + // Graceful shutdown with timeout + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + log.Println("Shutting down gracefully...") + if err := server.Shutdown(ctx); err != nil { + log.Fatalf("Server forced to shutdown: %v", err) } + log.Println("Server exited properly") }