diff --git a/pkg/spark/spark.go b/pkg/spark/spark.go index 2581e61..65e1e7d 100644 --- a/pkg/spark/spark.go +++ b/pkg/spark/spark.go @@ -21,6 +21,7 @@ import ( "path" "path/filepath" "strings" + "time" "go.uber.org/zap" "go.uber.org/zap/zapio" @@ -162,7 +163,31 @@ func (s *Spark) exec(args []string) { cmd.Stdout = writer defer writer.Close() } - if err := cmd.Run(); err != nil { - zap.L().Error("spark-submit failed", zap.Error(err)) + + if err := retry(10, 1*time.Second, 2, 5*time.Minute, func() error { + return cmd.Run() + }); err != nil { + zap.L().Error("spark submit failed with retries", zap.Error(err)) + } +} + +func retry(retries int, initialDelay time.Duration, mult int, maxWait time.Duration, fn func() error) error { + delay := initialDelay + for try := 0; try < retries; try++ { + if err := fn(); err == nil { + return nil + } else { + zap.L().Warn( + "retry failed", + zap.Int("try", try), + zap.String("waitDuration", delay.String()), + ) + } + time.Sleep(delay) + delay = delay * time.Duration(mult) + if delay >= maxWait { + delay = maxWait + } } + return fmt.Errorf("retries exceeded") } diff --git a/pkg/spark/spark_test.go b/pkg/spark/spark_test.go index ad33bc8..4731572 100644 --- a/pkg/spark/spark_test.go +++ b/pkg/spark/spark_test.go @@ -14,7 +14,9 @@ limitations under the License. package spark import ( + "fmt" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -75,4 +77,17 @@ func TestSpark(t *testing.T) { "--status=namespace:name", }, args) }) + + t.Run("retry works", func(t *testing.T) { + try := 0 + fn := func() error { + if try >= 2 { + return nil + } + try++ + return fmt.Errorf("error in fn") + } + require.NoError(t, retry(3, 1*time.Nanosecond, 2, 1*time.Second, fn)) + require.Greater(t, try, 1) + }) }