diff --git a/backend/controller/controller.go b/backend/controller/controller.go index 95d9898cb..5470ba26b 100644 --- a/backend/controller/controller.go +++ b/backend/controller/controller.go @@ -769,21 +769,10 @@ func (s *Service) GetModuleContext(ctx context.Context, req *connect.Request[ftl continue } dbTypes[db.Name] = dbType - // TODO: Move the DSN resolution to the runtime - if db.Runtime != nil { - var dsn string - switch dbType { - case modulecontext.DBTypePostgres: - // TODO: Get the port from config - dsn = "postgres://127.0.0.1:5678/" + db.Name - case modulecontext.DBTypeMySQL: - // TODO: Route MySQL through a proxy as well - dsn = db.Runtime.DSN - default: - return connect.NewError(connect.CodeInternal, fmt.Errorf("unknown DB type: %s", db.Type)) - } + // TODO: Move the DSN resolution to the runtime once MySQL proxy is working + if db.Runtime != nil && dbType == modulecontext.DBTypeMySQL { databases[db.Name] = modulecontext.Database{ - DSN: dsn, + DSN: db.Runtime.DSN, DBType: dbType, } } diff --git a/backend/runner/runner.go b/backend/runner/runner.go index eb4026d8e..576cef682 100644 --- a/backend/runner/runner.go +++ b/backend/runner/runner.go @@ -63,8 +63,6 @@ type Config struct { Registry artefacts.RegistryConfig `embed:"" prefix:"oci-"` ObservabilityConfig ftlobservability.Config `embed:"" prefix:"o11y-"` DevEndpoint optional.Option[url.URL] `help:"An existing endpoint to connect to in development mode" env:"FTL_DEV_ENDPOINT"` - - PgProxyConfig pgproxy.Config `embed:"" prefix:"pgproxy-"` } func Start(ctx context.Context, config Config) error { @@ -141,7 +139,7 @@ func Start(ctx context.Context, config Config) error { return connect.NewError(connect.CodeInvalidArgument, fmt.Errorf("invalid deployment key: %w", err)) } - module, err := svc.GetModule(ctx, deploymentKey) + module, err := svc.getModule(ctx, deploymentKey) if err != nil { return fmt.Errorf("failed to get module: %w", err) } @@ -160,12 +158,17 @@ func Start(ctx context.Context, config Config) error { go rpc.RetryStreamingClientStream(ctx, backoff.Backoff{}, controllerClient.StreamDeploymentLogs, svc.streamLogsLoop) }() + pgProxyStarted := make(chan pgproxy.Started) + g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - return svc.startPgProxy(ctx, module) + return svc.startPgProxy(ctx, module, pgProxyStarted) }) g.Go(func() error { - // TODO: Make sure pgproxy is ready before starting the runner + pgProxy := <-pgProxyStarted + os.Setenv("PG_PROXY_ADDRESS", fmt.Sprintf("127.0.0.1:%d", pgProxy.Address.Port)) + logger.Debugf("PG_PROXY_ADDRESS: %s", os.Getenv("PG_PROXY_ADDRESS")) + return rpc.Serve(ctx, config.Bind, rpc.GRPC(ftlv1connect.NewVerbServiceHandler, svc), rpc.HTTP("/", svc), @@ -319,7 +322,7 @@ func (s *Service) Ping(ctx context.Context, req *connect.Request[ftlv1.PingReque return connect.NewResponse(&ftlv1.PingResponse{}), nil } -func (s *Service) GetModule(ctx context.Context, key model.DeploymentKey) (*schema.Module, error) { +func (s *Service) getModule(ctx context.Context, key model.DeploymentKey) (*schema.Module, error) { gdResp, err := s.controllerClient.GetDeployment(ctx, connect.NewRequest(&ftlv1.GetDeploymentRequest{DeploymentKey: s.config.Deployment})) if err != nil { observability.Deployment.Failure(ctx, optional.Some(key.String())) @@ -593,7 +596,9 @@ func (s *Service) healthCheck(writer http.ResponseWriter, request *http.Request) writer.WriteHeader(http.StatusServiceUnavailable) } -func (s *Service) startPgProxy(ctx context.Context, module *schema.Module) error { +func (s *Service) startPgProxy(ctx context.Context, module *schema.Module, started chan<- pgproxy.Started) error { + logger := log.FromContext(ctx) + databases := map[string]*schema.Database{} for _, decl := range module.Decls { if db, ok := decl.(*schema.Database); ok { @@ -601,13 +606,15 @@ func (s *Service) startPgProxy(ctx context.Context, module *schema.Module) error } } - if err := pgproxy.New(s.config.PgProxyConfig, func(ctx context.Context, params map[string]string) (string, error) { + if err := pgproxy.New(":0", func(ctx context.Context, params map[string]string) (string, error) { db, ok := databases[params["database"]] if !ok { return "", fmt.Errorf("database %s not found", params["database"]) } + logger.Debugf("Resolved DSN (%s): %s", params["database"], db.Runtime.DSN) + return db.Runtime.DSN, nil - }).Start(ctx); err != nil { + }).Start(ctx, started); err != nil { return fmt.Errorf("failed to start pgproxy: %w", err) } return nil diff --git a/cmd/ftl-proxy-pg/main.go b/cmd/ftl-proxy-pg/main.go index 781fda75d..1bf84fbd5 100644 --- a/cmd/ftl-proxy-pg/main.go +++ b/cmd/ftl-proxy-pg/main.go @@ -38,10 +38,10 @@ func main() { err = observability.Init(ctx, false, "", "ftl-provisioner", ftl.Version, cli.ObservabilityConfig) kctx.FatalIfErrorf(err, "failed to initialize observability") - proxy := pgproxy.New(cli.Config, func(ctx context.Context, params map[string]string) (string, error) { + proxy := pgproxy.New(cli.Config.Listen, func(ctx context.Context, params map[string]string) (string, error) { return "postgres://localhost:5432/postgres?user=" + params["user"], nil }) - if err := proxy.Start(ctx); err != nil { + if err := proxy.Start(ctx, nil); err != nil { kctx.FatalIfErrorf(err, "failed to start proxy") } } diff --git a/internal/modulecontext/module_context.go b/internal/modulecontext/module_context.go index 05d063185..ef1e91144 100644 --- a/internal/modulecontext/module_context.go +++ b/internal/modulecontext/module_context.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "os" "strings" "sync" "time" @@ -159,6 +160,10 @@ func (m ModuleContext) GetSecret(name string, value any) error { func (m ModuleContext) GetDatabase(name string, dbType DBType) (string, bool, error) { db, ok := m.databases[name] if !ok { + if dbType == DBTypePostgres { + proxyAddress := os.Getenv("PG_PROXY_ADDRESS") + return "postgres://" + proxyAddress + "/" + name, false, nil + } return "", false, fmt.Errorf("missing DSN for database %s", name) } if db.DBType != dbType { diff --git a/internal/pgproxy/pgproxy.go b/internal/pgproxy/pgproxy.go index 916bdcaec..d16bdc529 100644 --- a/internal/pgproxy/pgproxy.go +++ b/internal/pgproxy/pgproxy.go @@ -30,15 +30,19 @@ type DSNConstructor func(ctx context.Context, params map[string]string) (string, // // address is the address to listen on for incoming connections. // connectionFn is a function that constructs a new connection string from parameters of the incoming connection. -func New(config Config, connectionFn DSNConstructor) *PgProxy { +func New(listenAddress string, connectionFn DSNConstructor) *PgProxy { return &PgProxy{ - listenAddress: config.Listen, + listenAddress: listenAddress, connectionStringFn: connectionFn, } } -// Start the proxy. -func (p *PgProxy) Start(ctx context.Context) error { +type Started struct { + Address *net.TCPAddr +} + +// Start the proxy +func (p *PgProxy) Start(ctx context.Context, started chan<- Started) error { logger := log.FromContext(ctx) listener, err := net.Listen("tcp", p.listenAddress) @@ -47,6 +51,14 @@ func (p *PgProxy) Start(ctx context.Context) error { } defer listener.Close() + if started != nil { + addr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + panic("failed to get TCP address") + } + started <- Started{Address: addr} + } + for { conn, err := listener.Accept() if err != nil {