From 9d39024c6c9490fe7775ca39e7f857e390fc91ab Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Wed, 18 Dec 2024 16:35:03 -0300 Subject: [PATCH 1/5] refactor: determine session type on stream based on watcher --- api/types/session_tracker.go | 1 + lib/auth/auth_with_roles.go | 106 +++++++++++++++++++---------------- lib/events/auditlog.go | 56 ++++++++++++++++++ 3 files changed, 116 insertions(+), 47 deletions(-) diff --git a/api/types/session_tracker.go b/api/types/session_tracker.go index db07ea2578db5..2892db170085c 100644 --- a/api/types/session_tracker.go +++ b/api/types/session_tracker.go @@ -39,6 +39,7 @@ const ( DatabaseSessionKind SessionKind = "db" AppSessionKind SessionKind = "app" WindowsDesktopSessionKind SessionKind = "desktop" + UnknownSessionKind SessionKind = "" ) // SessionParticipantMode is the mode that determines what you can do when you join a session. diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index e5fdaa9dab8ce..c0574d54c016a 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -196,30 +196,14 @@ func (a *ServerWithRoles) actionWithExtendedContext(kind, verb string, extendCon // actionForKindSession is a special checker that grants access to session // recordings. It can allow access to a specific recording based on the // `where` section of the user's access rule for kind `session`. -func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) (types.SessionKind, error) { - sessionEnd, err := a.findSessionEndEvent(ctx, sid) - - extendContext := func(ctx *services.Context) error { - ctx.Session = sessionEnd +func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) error { + extendContext := func(servicesCtx *services.Context) error { + sessionEnd, err := a.findSessionEndEvent(ctx, sid) + servicesCtx.Session = sessionEnd return trace.Wrap(err) } - var sessionKind types.SessionKind - switch e := sessionEnd.(type) { - case *apievents.SessionEnd: - sessionKind = types.SSHSessionKind - if e.KubernetesCluster != "" { - sessionKind = types.KubernetesSessionKind - } - case *apievents.DatabaseSessionEnd: - sessionKind = types.DatabaseSessionKind - case *apievents.AppSessionEnd: - sessionKind = types.AppSessionKind - case *apievents.WindowsDesktopSessionEnd: - sessionKind = types.WindowsDesktopSessionKind - } - - return sessionKind, trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext)) + return trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext)) } // localServerAction returns an access denied error if the role is not one of the builtin server roles. @@ -6034,44 +6018,72 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st // channel if one is encountered. Otherwise the event channel is closed when the stream ends. // The event channel is not closed on error to prevent race conditions in downstream select statements. func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { - createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) { - e := make(chan error, 1) - e <- trace.Wrap(err) - return nil, e - } - err := a.localServerAction() isTeleportServer := err == nil - var sessionType types.SessionKind if !isTeleportServer { - var err error - sessionType, err = a.actionForKindSession(ctx, sessionID) - if err != nil { + if err := a.actionForKindSession(ctx, sessionID); err != nil { c, e := make(chan apievents.AuditEvent), make(chan error, 1) e <- trace.Wrap(err) return c, e } } - // StreamSessionEvents can be called internally, and when that happens we don't want to emit an event. - shouldEmitAuditEvent := !isTeleportServer - if shouldEmitAuditEvent { - if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ - Metadata: apievents.Metadata{ - Type: events.SessionRecordingAccessEvent, - Code: events.SessionRecordingAccessCode, - }, - SessionID: sessionID.String(), - UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), - SessionType: string(sessionType), - Format: metadata.SessionRecordingFormatFromContext(ctx), - }); err != nil { - return createErrorChannel(err) - } + // We can only determine the session type after the streaming started. For + // this reason, we delay the emit audit event until the first event or if + // the streaming returns an error. + watcher := &eventsWatcher{ + onSessionStart: func(evt apievents.AuditEvent, _ error) { + // StreamSessionEvents can be called internally, and when that + // happens we don't want to emit an event. + if isTeleportServer { + return + } + + if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ + Metadata: apievents.Metadata{ + Type: events.SessionRecordingAccessEvent, + Code: events.SessionRecordingAccessCode, + }, + SessionID: sessionID.String(), + UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), + SessionType: string(sessionTypeFromStartEvent(evt)), + Format: metadata.SessionRecordingFormatFromContext(ctx), + }); err != nil { + log.WithError(err).Errorf("Failed to emit stream session event audit event") + } + }, } - return a.alog.StreamSessionEvents(ctx, sessionID, startIndex) + return a.alog.StreamSessionEvents(events.ContextWithEventWatcher(ctx, watcher), sessionID, startIndex) +} + +type eventsWatcher struct { + onSessionStart func(apievents.AuditEvent, error) +} + +func (e *eventsWatcher) OnSessionStart(evt apievents.AuditEvent, err error) { + e.onSessionStart(evt, err) +} + +// sessionTypeFromStartEvent determines the session type given the session start +// event. +func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind { + switch e := sessionStart.(type) { + case *apievents.SessionStart: + if e.KubernetesCluster != "" { + return types.KubernetesSessionKind + } + return types.SSHSessionKind + case *apievents.DatabaseSessionStart: + return types.DatabaseSessionKind + case *apievents.AppSessionStart: + return types.AppSessionKind + case *apievents.WindowsDesktopSessionStart: + return types.WindowsDesktopSessionKind + default: + return types.UnknownSessionKind + } } // CreateApp creates a new application resource. diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 51180746cbe7f..3539a6cd19519 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -545,8 +545,25 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID "session_id", string(sessionID), ) + sessionStartCh := make(chan apievents.AuditEvent, 1) + if watcher, err := EventsWatcherFromContext(ctx); err == nil { + go func() { + select { + case evt, ok := <-sessionStartCh: + if !ok { + watcher.OnSessionStart(nil, trace.NotFound("session start event not found")) + return + } + + watcher.OnSessionStart(evt, nil) + } + }() + } + go func() { defer rawSession.Close() + defer close(sessionStartCh) + // this shouldn't be necessary as the position should be already 0 (Download // takes an io.WriterAt), but it's better to be safe than sorry if _, err := rawSession.Seek(0, io.SeekStart); err != nil { @@ -556,6 +573,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID protoReader := NewProtoReader(rawSession) + firstEvent := true for { if ctx.Err() != nil { e <- trace.Wrap(ctx.Err()) @@ -572,6 +590,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID return } + if firstEvent { + sessionStartCh <- event + firstEvent = false + } + if event.GetIndex() >= startIndex { select { case c <- event: @@ -666,3 +689,36 @@ func (l *AuditLog) periodicSpaceMonitor() { } } } + +// TODO +type streamSessionEventsContextKey string + +const ( + // TODO + eventWatcherContextKey streamSessionEventsContextKey = "watcher" +) + +// TODO +type EventsWatcher interface { + // TODO + OnSessionStart(evt apievents.AuditEvent, err error) +} + +// TODO +func ContextWithEventWatcher(ctx context.Context, watcher EventsWatcher) context.Context { + return context.WithValue(ctx, eventWatcherContextKey, watcher) +} + +// TODO +func EventsWatcherFromContext(ctx context.Context) (EventsWatcher, error) { + if ctx == nil { + return nil, trace.BadParameter("context is nil") + } + + watcher, ok := ctx.Value(eventWatcherContextKey).(EventsWatcher) + if !ok { + return nil, trace.BadParameter("events watcher was not found in the context") + } + + return watcher, nil +} From f3315a1bdca345af5840c886521ba702ac0e0438 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 23 Dec 2024 10:32:37 -0300 Subject: [PATCH 2/5] refactor(auth): early return when is teleport server --- lib/auth/auth_with_roles.go | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index c0574d54c016a..b6d7731c00d7f 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -6021,12 +6021,16 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses err := a.localServerAction() isTeleportServer := err == nil - if !isTeleportServer { - if err := a.actionForKindSession(ctx, sessionID); err != nil { - c, e := make(chan apievents.AuditEvent), make(chan error, 1) - e <- trace.Wrap(err) - return c, e - } + // StreamSessionEvents can be called internally, and when that + // happens we don't want to emit an event or check for permissions. + if isTeleportServer { + return a.alog.StreamSessionEvents(ctx, sessionID, startIndex) + } + + if err := a.actionForKindSession(ctx, sessionID); err != nil { + c, e := make(chan apievents.AuditEvent), make(chan error, 1) + e <- trace.Wrap(err) + return c, e } // We can only determine the session type after the streaming started. For @@ -6034,12 +6038,6 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses // the streaming returns an error. watcher := &eventsWatcher{ onSessionStart: func(evt apievents.AuditEvent, _ error) { - // StreamSessionEvents can be called internally, and when that - // happens we don't want to emit an event. - if isTeleportServer { - return - } - if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ Metadata: apievents.Metadata{ Type: events.SessionRecordingAccessEvent, From cd3b3bcd2f5169a8284dc8f6e2c467219502dddb Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 23 Dec 2024 11:58:21 -0300 Subject: [PATCH 3/5] refactor: move to function and add tests --- lib/auth/auth_with_roles.go | 38 ++++----- lib/auth/auth_with_roles_test.go | 47 ++++++++--- lib/events/auditlog.go | 68 ++++++++-------- lib/events/auditlog_test.go | 131 +++++++++++++++++++++++++++++++ 4 files changed, 220 insertions(+), 64 deletions(-) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index b6d7731c00d7f..c78626ae7ac1b 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -6036,32 +6036,22 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses // We can only determine the session type after the streaming started. For // this reason, we delay the emit audit event until the first event or if // the streaming returns an error. - watcher := &eventsWatcher{ - onSessionStart: func(evt apievents.AuditEvent, _ error) { - if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ - Metadata: apievents.Metadata{ - Type: events.SessionRecordingAccessEvent, - Code: events.SessionRecordingAccessCode, - }, - SessionID: sessionID.String(), - UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), - SessionType: string(sessionTypeFromStartEvent(evt)), - Format: metadata.SessionRecordingFormatFromContext(ctx), - }); err != nil { - log.WithError(err).Errorf("Failed to emit stream session event audit event") - } - }, + cb := func(evt apievents.AuditEvent, _ error) { + if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{ + Metadata: apievents.Metadata{ + Type: events.SessionRecordingAccessEvent, + Code: events.SessionRecordingAccessCode, + }, + SessionID: sessionID.String(), + UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(), + SessionType: string(sessionTypeFromStartEvent(evt)), + Format: metadata.SessionRecordingFormatFromContext(ctx), + }); err != nil { + log.WithError(err).Errorf("Failed to emit stream session event audit event") + } } - return a.alog.StreamSessionEvents(events.ContextWithEventWatcher(ctx, watcher), sessionID, startIndex) -} - -type eventsWatcher struct { - onSessionStart func(apievents.AuditEvent, error) -} - -func (e *eventsWatcher) OnSessionStart(evt apievents.AuditEvent, err error) { - e.onSessionStart(evt, err) + return a.alog.StreamSessionEvents(events.ContextWithSessionStartCallback(ctx, cb), sessionID, startIndex) } // sessionTypeFromStartEvent determines the session type given the session start diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 1576986648191..68dd311cad2c4 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -2267,7 +2267,29 @@ func TestStreamSessionEvents(t *testing.T) { func TestStreamSessionEvents_SessionType(t *testing.T) { t.Parallel() - srv := newTestTLSServer(t) + authServerConfig := TestAuthServerConfig{ + Dir: t.TempDir(), + Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()), + } + require.NoError(t, authServerConfig.CheckAndSetDefaults()) + + uploader := eventstest.NewMemoryUploader() + localLog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: authServerConfig.Dir, + ServerID: authServerConfig.ClusterName, + Clock: authServerConfig.Clock, + UploadHandler: uploader, + }) + require.NoError(t, err) + authServerConfig.AuditLog = localLog + + as, err := NewTestAuthServer(authServerConfig) + require.NoError(t, err) + + srv, err := as.NewTestTLSServer() + require.NoError(t, err) + t.Cleanup(func() { srv.Close() }) + ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -2278,19 +2300,26 @@ func TestStreamSessionEvents_SessionType(t *testing.T) { identity := TestUser(user.GetName()) clt, err := srv.NewClient(identity) require.NoError(t, err) - sessionID := "44c6cea8-362f-11ea-83aa-125400432324" + sessionID := session.NewID() - // Emitting a session end event will cause the listing to correctly locate - // the recording (even if there might not be a recording file to stream). - require.NoError(t, srv.Auth().EmitAuditEvent(ctx, &apievents.DatabaseSessionEnd{ + streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + Uploader: uploader, + }) + require.NoError(t, err) + stream, err := streamer.CreateAuditStream(ctx, sessionID) + require.NoError(t, err) + // The event is not required to pass through the auth server, we only need + // the upload to be present. + require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(&apievents.DatabaseSessionStart{ Metadata: apievents.Metadata{ - Type: events.DatabaseSessionEndEvent, - Code: events.DatabaseSessionEndCode, + Type: events.DatabaseSessionStartEvent, + Code: events.DatabaseSessionStartCode, }, SessionMetadata: apievents.SessionMetadata{ - SessionID: sessionID, + SessionID: sessionID.String(), }, - })) + }))) + require.NoError(t, stream.Complete(ctx)) accessedFormat := teleport.PTY clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), session.ID(sessionID), 0) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 3539a6cd19519..7ab26d6541a80 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -509,9 +509,25 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID e := make(chan error, 1) c := make(chan apievents.AuditEvent) + sessionStartCh := make(chan apievents.AuditEvent, 1) + if startCb, err := SessionStartCallbackFromContext(ctx); err == nil { + go func() { + select { + case evt, ok := <-sessionStartCh: + if !ok { + startCb(nil, trace.NotFound("session start event not found")) + return + } + + startCb(evt, nil) + } + }() + } + rawSession, err := os.CreateTemp(l.playbackDir, string(sessionID)+".stream.tar.*") if err != nil { e <- trace.Wrap(trace.ConvertSystemError(err), "creating temporary stream file") + close(sessionStartCh) return c, e } // The file is still perfectly usable after unlinking it, and the space it's @@ -528,6 +544,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID if err := os.Remove(rawSession.Name()); err != nil { _ = rawSession.Close() e <- trace.Wrap(trace.ConvertSystemError(err), "removing temporary stream file") + close(sessionStartCh) return c, e } @@ -538,6 +555,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID err = trace.NotFound("a recording for session %v was not found", sessionID) } e <- trace.Wrap(err) + close(sessionStartCh) return c, e } l.log.DebugContext(ctx, "Downloaded session to a temporary file for streaming.", @@ -545,21 +563,6 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID "session_id", string(sessionID), ) - sessionStartCh := make(chan apievents.AuditEvent, 1) - if watcher, err := EventsWatcherFromContext(ctx); err == nil { - go func() { - select { - case evt, ok := <-sessionStartCh: - if !ok { - watcher.OnSessionStart(nil, trace.NotFound("session start event not found")) - return - } - - watcher.OnSessionStart(evt, nil) - } - }() - } - go func() { defer rawSession.Close() defer close(sessionStartCh) @@ -690,35 +693,38 @@ func (l *AuditLog) periodicSpaceMonitor() { } } -// TODO +// streamSessionEventsContextKey represent context keys used by +// StreamSessionEvents function. type streamSessionEventsContextKey string const ( - // TODO - eventWatcherContextKey streamSessionEventsContextKey = "watcher" + // sessionStartCallbackContextKey is the context key used to store the + // session start callback function. + sessionStartCallbackContextKey streamSessionEventsContextKey = "session-start" ) -// TODO -type EventsWatcher interface { - // TODO - OnSessionStart(evt apievents.AuditEvent, err error) -} +// SessionStartCallback is the function used when streaming reaches the start +// event. If any error, such as session not found, the event will be nil, and +// the error will be set. +type SessionStartCallback func(startEvent apievents.AuditEvent, err error) -// TODO -func ContextWithEventWatcher(ctx context.Context, watcher EventsWatcher) context.Context { - return context.WithValue(ctx, eventWatcherContextKey, watcher) +// ContextWithSessionStartCallback returns a context.Context containing a +// session start event callback. +func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallback) context.Context { + return context.WithValue(ctx, sessionStartCallbackContextKey, cb) } -// TODO -func EventsWatcherFromContext(ctx context.Context) (EventsWatcher, error) { +// SessionStartCallbackFromContext returns the session start callback from +// context.Context. +func SessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) { if ctx == nil { return nil, trace.BadParameter("context is nil") } - watcher, ok := ctx.Value(eventWatcherContextKey).(EventsWatcher) + cb, ok := ctx.Value(sessionStartCallbackContextKey).(SessionStartCallback) if !ok { - return nil, trace.BadParameter("events watcher was not found in the context") + return nil, trace.BadParameter("session start callback function was not found in the context") } - return watcher, nil + return cb, nil } diff --git a/lib/events/auditlog_test.go b/lib/events/auditlog_test.go index b76d27a0ee36a..416373e3e6951 100644 --- a/lib/events/auditlog_test.go +++ b/lib/events/auditlog_test.go @@ -154,6 +154,137 @@ func TestConcurrentStreaming(t *testing.T) { } } +func TestStreamSessionEvents(t *testing.T) { + uploader := eventstest.NewMemoryUploader() + alog, err := events.NewAuditLog(events.AuditLogConfig{ + DataDir: t.TempDir(), + Clock: clockwork.NewFakeClock(), + ServerID: "remote", + UploadHandler: uploader, + }) + require.NoError(t, err) + t.Cleanup(func() { alog.Close() }) + + ctx := context.Background() + sid := session.NewID() + sessionEvents := []apievents.AuditEvent{ + &apievents.DatabaseSessionStart{ + Metadata: apievents.Metadata{ + Type: events.DatabaseSessionStartEvent, + Code: events.DatabaseSessionStartCode, + Index: 0, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: sid.String(), + }, + }, + &apievents.DatabaseSessionEnd{ + Metadata: apievents.Metadata{ + Type: events.DatabaseSessionEndEvent, + Code: events.DatabaseSessionEndCode, + Index: 1, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: sid.String(), + }, + }, + } + + streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{ + Uploader: uploader, + }) + require.NoError(t, err) + stream, err := streamer.CreateAuditStream(ctx, sid) + require.NoError(t, err) + for _, event := range sessionEvents { + require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(event))) + } + require.NoError(t, stream.Complete(ctx)) + + type callbackResult struct { + event apievents.AuditEvent + err error + } + + t.Run("Success", func(t *testing.T) { + for name, withCallback := range map[string]bool{ + "WithCallback": true, + "WithoutCallback": false, + } { + t.Run(name, func(t *testing.T) { + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + callbackCh := make(chan callbackResult, 1) + if withCallback { + streamCtx = events.ContextWithSessionStartCallback(streamCtx, func(ae apievents.AuditEvent, err error) { + callbackCh <- callbackResult{ae, err} + }) + } + + ch, _ := alog.StreamSessionEvents(streamCtx, sid, 0) + for _, event := range sessionEvents { + select { + case receivedEvent := <-ch: + require.NotNil(t, receivedEvent) + require.Equal(t, event.GetCode(), receivedEvent.GetCode()) + require.Equal(t, event.GetType(), receivedEvent.GetType()) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive session event %q but got nothing", event.GetType()) + } + } + + if withCallback { + select { + case res := <-callbackCh: + require.NoError(t, res.err) + require.Equal(t, sessionEvents[0].GetCode(), res.event.GetCode()) + require.Equal(t, sessionEvents[0].GetType(), res.event.GetType()) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive callback result but got nothing") + } + } + }) + } + }) + + t.Run("Error", func(t *testing.T) { + for name, withCallback := range map[string]bool{ + "WithCallback": true, + "WithoutCallback": false, + } { + t.Run(name, func(t *testing.T) { + streamCtx, cancel := context.WithCancel(ctx) + defer cancel() + + callbackCh := make(chan callbackResult, 1) + if withCallback { + streamCtx = events.ContextWithSessionStartCallback(streamCtx, func(ae apievents.AuditEvent, err error) { + callbackCh <- callbackResult{ae, err} + }) + } + + _, errCh := alog.StreamSessionEvents(streamCtx, session.ID("random"), 0) + select { + case err := <-errCh: + require.Error(t, err) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to get error while stream but got nothing") + } + + if withCallback { + select { + case res := <-callbackCh: + require.Error(t, res.err) + case <-time.After(10 * time.Second): + require.Fail(t, "expected to receive callback result but got nothing") + } + } + }) + } + }) +} + func TestExternalLog(t *testing.T) { m := &eventstest.MockAuditLog{ Emitter: &eventstest.MockRecorderEmitter{}, From bc76a1b5390556d3f897d5c826dc9130950bf403 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 23 Dec 2024 16:28:38 -0300 Subject: [PATCH 4/5] chore(lib): fix lint --- lib/auth/auth_with_roles_test.go | 2 +- lib/events/auditlog.go | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 68dd311cad2c4..f6ad5315cd6fb 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -2322,7 +2322,7 @@ func TestStreamSessionEvents_SessionType(t *testing.T) { require.NoError(t, stream.Complete(ctx)) accessedFormat := teleport.PTY - clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), session.ID(sessionID), 0) + clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), sessionID, 0) // Perform the listing an eventually loop to ensure the event is emitted. var searchEvents []apievents.AuditEvent diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 4303db80a296d..220b8f316cbe1 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -512,15 +512,13 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID sessionStartCh := make(chan apievents.AuditEvent, 1) if startCb, err := SessionStartCallbackFromContext(ctx); err == nil { go func() { - select { - case evt, ok := <-sessionStartCh: - if !ok { - startCb(nil, trace.NotFound("session start event not found")) - return - } - - startCb(evt, nil) + evt, ok := <-sessionStartCh + if !ok { + startCb(nil, trace.NotFound("session start event not found")) + return } + + startCb(evt, nil) }() } From 016d1e675efbbc0e50d1ad650c865a90a6eaf364 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 23 Dec 2024 22:24:53 -0300 Subject: [PATCH 5/5] refactor(events): make private function --- lib/events/auditlog.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 220b8f316cbe1..274c3c65c56a6 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -510,7 +510,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID c := make(chan apievents.AuditEvent) sessionStartCh := make(chan apievents.AuditEvent, 1) - if startCb, err := SessionStartCallbackFromContext(ctx); err == nil { + if startCb, err := sessionStartCallbackFromContext(ctx); err == nil { go func() { evt, ok := <-sessionStartCh if !ok { @@ -713,9 +713,9 @@ func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallbac return context.WithValue(ctx, sessionStartCallbackContextKey, cb) } -// SessionStartCallbackFromContext returns the session start callback from +// sessionStartCallbackFromContext returns the session start callback from // context.Context. -func SessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) { +func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) { if ctx == nil { return nil, trace.BadParameter("context is nil") }