diff --git a/tool/tctl/common/loadtest_command.go b/tool/tctl/common/loadtest_command.go
index 891568e200599..6155afa705b79 100644
--- a/tool/tctl/common/loadtest_command.go
+++ b/tool/tctl/common/loadtest_command.go
@@ -20,6 +20,8 @@ package common
 
 import (
 	"context"
+	"crypto/tls"
+	"crypto/x509"
 	"fmt"
 	"log/slog"
 	"os"
@@ -27,6 +29,7 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -34,15 +37,25 @@ import (
 	"github.com/google/uuid"
 	"github.com/gravitational/trace"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/gravitational/teleport"
+	apiclient "github.com/gravitational/teleport/api/client"
+	"github.com/gravitational/teleport/api/client/proto"
 	auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1"
 	"github.com/gravitational/teleport/api/types"
+	apiutils "github.com/gravitational/teleport/api/utils"
 	"github.com/gravitational/teleport/lib/auth/authclient"
+	"github.com/gravitational/teleport/lib/auth/keystore"
 	"github.com/gravitational/teleport/lib/cache"
+	"github.com/gravitational/teleport/lib/cryptosuites"
+	"github.com/gravitational/teleport/lib/defaults"
 	"github.com/gravitational/teleport/lib/events/export"
+	"github.com/gravitational/teleport/lib/inventory"
 	"github.com/gravitational/teleport/lib/service/servicecfg"
 	"github.com/gravitational/teleport/lib/services"
+	"github.com/gravitational/teleport/lib/srv"
+	"github.com/gravitational/teleport/lib/tlsca"
 	"github.com/gravitational/teleport/lib/utils"
 )
 
@@ -62,6 +75,8 @@ type LoadtestCommand struct {
 	ttl         time.Duration
 	concurrency int
 
+	inventory bool
+
 	kind   string
 	ops    string
 	format string
@@ -84,6 +99,7 @@ func (c *LoadtestCommand) Initialize(app *kingpin.Application, config *servicecf
 	c.nodeHeartbeats.Flag("concurrency", "Max concurrent requests").Default(
 		strconv.Itoa(runtime.NumCPU() * 16),
 	).IntVar(&c.concurrency)
+	c.nodeHeartbeats.Flag("inventory", "Use the inventory control stream to heartbeat").BoolVar(&c.inventory)
 
 	c.watch = loadtest.Command("watch", "Monitor event stream").Hidden()
 	c.watch.Flag("kind", "Resource kind(s) to watch, e.g. --kind=node,user,role").StringVar(&c.kind)
@@ -99,7 +115,11 @@ func (c *LoadtestCommand) Initialize(app *kingpin.Application, config *servicecf
 func (c *LoadtestCommand) TryRun(ctx context.Context, cmd string, client *authclient.Client) (match bool, err error) {
 	switch cmd {
 	case c.nodeHeartbeats.FullCommand():
-		err = c.NodeHeartbeats(ctx, client)
+		if c.inventory {
+			err = c.InventoryHeartbeats(ctx, client)
+		} else {
+			err = c.NodeHeartbeats(ctx, client)
+		}
 	case c.watch.FullCommand():
 		err = c.Watch(ctx, client)
 	case c.auditEvents.FullCommand():
@@ -110,6 +130,267 @@ func (c *LoadtestCommand) TryRun(ctx context.Context, cmd string, client *authcl
 	return true, trace.Wrap(err)
 }
 
+func (c *LoadtestCommand) InventoryHeartbeats(ctx context.Context, client *authclient.Client) error {
+	ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
+	defer stop()
+	defer context.AfterFunc(ctx, stop)()
+
+	slog.WarnContext(ctx, "setting up node heartbeats through the inventory",
+		"count", c.count,
+		"labels", c.labels,
+	)
+
+	typesClusterName, err := client.GetClusterName()
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	clusterName := typesClusterName.GetClusterName()
+
+	authPreference, err := client.GetAuthPreference(ctx)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	authPreferenceGetter := fixedAuthPreferenceGetter{authPreference: authPreference}
+
+	keystoreManager, err := keystore.NewManager(ctx, &servicecfg.KeystoreConfig{}, &keystore.Options{
+		HostUUID:             "",
+		ClusterName:          typesClusterName,
+		AuthPreferenceGetter: authPreferenceGetter,
+	})
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	const loadSigningKeysTrue = true
+	ca, err := client.GetCertAuthority(ctx, types.CertAuthID{
+		Type:       types.HostCA,
+		DomainName: clusterName,
+	}, loadSigningKeysTrue)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	caCert, caSigner, err := keystoreManager.GetTLSCertAndSigner(ctx, ca)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	tlscaCA, err := tlsca.FromCertAndSigner(caCert, caSigner)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+	caPool := x509.NewCertPool()
+	for _, keyPair := range ca.GetTrustedTLSKeyPairs() {
+		cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
+		if err != nil {
+			return trace.Wrap(err)
+		}
+		caPool.AddCert(cert)
+	}
+
+	key, err := cryptosuites.GenerateKey(ctx,
+		cryptosuites.GetCurrentSuiteFromAuthPreference(authPreferenceGetter),
+		cryptosuites.HostIdentity,
+	)
+	if err != nil {
+		return trace.Wrap(err)
+	}
+
+	hostIDs := make([]string, 0, c.count)
+	for range c.count {
+		hostIDs = append(hostIDs, uuid.NewString())
+	}
+
+	clients := make([]*apiclient.Client, len(hostIDs))
+	defer func() {
+		for _, clt := range clients {
+			if clt != nil {
+				_ = clt.Close()
+			}
+		}
+	}()
+
+	eg, egCtx := errgroup.WithContext(ctx)
+	eg.SetLimit(c.concurrency)
+
+	for i, hostID := range hostIDs {
+		ctx := egCtx
+		eg.Go(func() error {
+			identity := tlsca.Identity{
+				Username:        authclient.HostFQDN(hostID, clusterName),
+				Groups:          []string{string(types.RoleInstance)},
+				TeleportCluster: clusterName,
+				SystemRoles:     []string{string(types.RoleNode)},
+			}
+			subject, err := identity.Subject()
+			if err != nil {
+				return trace.Wrap(err)
+			}
+			hostCertPEM, err := tlscaCA.GenerateCertificate(tlsca.CertificateRequest{
+				PublicKey: key.Public(),
+				Subject:   subject,
+				NotAfter:  time.Now().Add(defaults.CATTL),
+			})
+			if err != nil {
+				return trace.Wrap(err)
+			}
+			hostCert, err := tlsca.ParseCertificatePEM(hostCertPEM)
+			if err != nil {
+				return trace.Wrap(err)
+			}
+			hostTLSCert := tls.Certificate{
+				Certificate: [][]byte{hostCert.Raw},
+				PrivateKey:  key,
+				Leaf:        hostCert,
+			}
+
+			clt, err := apiclient.New(ctx, apiclient.Config{
+				Context: ctx,
+
+				Addrs: []string{c.config.AuthServerAddresses()[0].String()},
+				Credentials: []apiclient.Credentials{apiclient.LoadTLS(&tls.Config{
+					Certificates: []tls.Certificate{hostTLSCert},
+					RootCAs:      caPool,
+					// NextProtos: []string{"h2"},
+					ServerName: apiutils.EncodeClusterName(clusterName),

+					CipherSuites: c.config.CipherSuites,
+
+					MinVersion: tls.VersionTLS12,
+				})},
+
+				ALPNSNIAuthDialClusterName: clusterName,
+			})
+			if err != nil {
+				return trace.Wrap(err)
+			}
+			if _, err := clt.Ping(ctx); err != nil {
+				return trace.Wrap(err)
+			}
+			clients[i] = clt
+			return nil
+		})
+	}
+	if err := eg.Wait(); err != nil {
+		return trace.Wrap(err)
+	}
+
+	c.config.Logger.WarnContext(ctx, "generated host certs and opened connections",
+		"count", c.count,
+		"labels", c.labels,
+	)
+
+	var wg sync.WaitGroup
+	defer wg.Wait()
+	var failing atomic.Uint64
+
+	for i, hostID := range hostIDs {
+		wg.Add(1)
+		client := clients[i]
+		go func() {
+			defer wg.Done()
+			c.runInventorySimulator(ctx, client, hostID, &failing)
+		}()
+	}
+
+	report := time.NewTicker(5 * time.Second)
+	defer report.Stop()
+
+	for {
+		select {
+		case <-report.C:
+			slog.InfoContext(ctx, "heartbeat status",
+				"running", c.count,
+				"failing", failing.Load(),
+			)
+		case <-ctx.Done():
+			return nil
+		}
+	}
+}
+
+func (c *LoadtestCommand) runInventorySimulator(ctx context.Context, client *apiclient.Client, hostID string, failing *atomic.Uint64) {
+	slog.InfoContext(ctx, "starting stream and heartbeat",
+		"host_id", hostID,
+	)
+	handle := inventory.NewDownstreamHandle(client.InventoryControlStream, proto.UpstreamInventoryHello{
+		ServerID: hostID,
+		Version:  teleport.Version,
+		Services: []types.SystemRole{types.RoleNode},
+		Hostname: "host-" + hostID,
+	})
+	handle.RegisterPingHandler(func(sender inventory.DownstreamSender, ping proto.DownstreamInventoryPing) {
+		_ = sender.Send(handle.CloseContext(), proto.UpstreamInventoryPong{
+			ID: ping.ID,
+		})
+	})
+	defer handle.Close()
+
+	server := &types.ServerV2{
+		Kind:    types.KindNode,
+		SubKind: types.SubKindTeleportNode,
+		Metadata: types.Metadata{
+			Name: hostID,
+		},
+		Spec: types.ServerSpecV2{
+			Hostname:  "host-" + hostID,
+			UseTunnel: true,
+			Version:   teleport.Version,
+		},
+	}
+	if err := server.CheckAndSetDefaults(); err != nil {
+		panic(err)
+	}
+
+	var failed atomic.Bool
+	hb, err := srv.NewSSHServerHeartbeat(srv.HeartbeatV2Config[*types.ServerV2]{
+		InventoryHandle:  handle,
+		AnnounceInterval: c.interval,
+		GetResource: func(context.Context) (*types.ServerV2, error) {
+			return server, nil
+		},
+		OnHeartbeat: func(err error) {
+			fail := err != nil
+			if failed.Swap(fail) {
+				if !fail {
+					failing.Add(^uint64(0))
+				}
+			} else {
+				if fail {
+					failing.Add(1)
+				}
+			}
+		},
+	})
+	if err != nil {
+		panic(err)
+	}
+
+	var hbWg sync.WaitGroup
+	defer hbWg.Wait()
+
+	hbWg.Add(1)
+	go func() {
+		defer hbWg.Done()
+		hb.Run()
+	}()
+	defer hb.Close()
+
+	<-ctx.Done()
+	slog.InfoContext(ctx, "stopping heartbeat",
+		"host_id", hostID,
+	)
+}
+
+type fixedAuthPreferenceGetter struct {
+	authPreference types.AuthPreference
+}
+
+var _ cryptosuites.AuthPreferenceGetter = fixedAuthPreferenceGetter{}
+
+// GetAuthPreference implements [cryptosuites.AuthPreferenceGetter].
+func (f fixedAuthPreferenceGetter) GetAuthPreference(ctx context.Context) (types.AuthPreference, error) {
+	return f.authPreference, nil
+}
+
 func (c *LoadtestCommand) NodeHeartbeats(ctx context.Context, client *authclient.Client) error {
 	infof := func(format string, args ...any) {
 		fmt.Fprintf(os.Stderr, "[i] "+format+"\n", args...)
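
A note on the `OnHeartbeat` bookkeeping in `runInventorySimulator` above: each simulated node keeps a private `atomic.Bool` recording its last observed state, and the shared `failing` gauge is touched only on state transitions, so it counts nodes that are currently failing rather than failure events. Since `atomic.Uint64` has no subtract method, the decrement is done by adding `^uint64(0)` (all bits set), which wraps around to minus one. A minimal standalone sketch of the same pattern, with illustrative names that are not part of the patch:

	package main

	import (
		"errors"
		"fmt"
		"sync/atomic"
	)

	// observe mirrors the patch's OnHeartbeat callback: failed holds one
	// node's last observed state; failing is the shared gauge of nodes
	// that are currently failing.
	func observe(failed *atomic.Bool, failing *atomic.Uint64, err error) {
		fail := err != nil
		if failed.Swap(fail) {
			// Node was failing; if it just recovered, decrement the gauge.
			if !fail {
				// atomic.Uint64 has no Sub: adding ^uint64(0) (math.MaxUint64)
				// wraps around to subtract exactly one.
				failing.Add(^uint64(0))
			}
		} else if fail {
			// Node was healthy and just started failing: increment.
			failing.Add(1)
		}
	}

	func main() {
		var failed atomic.Bool
		var failing atomic.Uint64
		observe(&failed, &failing, errors.New("lost connection")) // 0 -> 1
		observe(&failed, &failing, errors.New("still down"))      // unchanged
		observe(&failed, &failing, nil)                           // 1 -> 0
		fmt.Println(failing.Load())                               // prints 0
	}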
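Assuming the subcommand wiring earlier in this file is unchanged (the subcommand definition itself is outside these hunks), the new path is opt-in: passing `--inventory` to the node-heartbeats load test makes each simulated node mint its own host certificate, hold a live API connection, and heartbeat over the inventory control stream the way a real node agent would, instead of announcing server resources through the single admin client. The existing count, concurrency, and interval flags keep their current meaning.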