Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine session type on stream based on stream watcher events #50395

Merged
1 change: 1 addition & 0 deletions api/types/session_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
DatabaseSessionKind SessionKind = "db"
AppSessionKind SessionKind = "app"
WindowsDesktopSessionKind SessionKind = "desktop"
UnknownSessionKind SessionKind = ""
)

// SessionParticipantMode is the mode that determines what you can do when you join a session.
Expand Down
84 changes: 42 additions & 42 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,30 +196,14 @@ func (a *ServerWithRoles) actionWithExtendedContext(kind, verb string, extendCon
// actionForKindSession is a special checker that grants access to session
// recordings. It can allow access to a specific recording based on the
// `where` section of the user's access rule for kind `session`.
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) (types.SessionKind, error) {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)

extendContext := func(ctx *services.Context) error {
ctx.Session = sessionEnd
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) error {
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
extendContext := func(servicesCtx *services.Context) error {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)
servicesCtx.Session = sessionEnd
return trace.Wrap(err)
}

var sessionKind types.SessionKind
switch e := sessionEnd.(type) {
case *apievents.SessionEnd:
sessionKind = types.SSHSessionKind
if e.KubernetesCluster != "" {
sessionKind = types.KubernetesSessionKind
}
case *apievents.DatabaseSessionEnd:
sessionKind = types.DatabaseSessionKind
case *apievents.AppSessionEnd:
sessionKind = types.AppSessionKind
case *apievents.WindowsDesktopSessionEnd:
sessionKind = types.WindowsDesktopSessionKind
}

return sessionKind, trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext))
return trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext))
}

// localServerAction returns an access denied error if the role is not one of the builtin server roles.
Expand Down Expand Up @@ -6121,44 +6105,60 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) {
e := make(chan error, 1)
e <- trace.Wrap(err)
return nil, e
}

err := a.localServerAction()
isTeleportServer := err == nil

var sessionType types.SessionKind
if !isTeleportServer {
var err error
sessionType, err = a.actionForKindSession(ctx, sessionID)
if err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}
// StreamSessionEvents can be called internally, and when that
// happens we don't want to emit an event or check for permissions.
if isTeleportServer {
return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
}

// StreamSessionEvents can be called internally, and when that happens we don't want to emit an event.
shouldEmitAuditEvent := !isTeleportServer
if shouldEmitAuditEvent {
if err := a.actionForKindSession(ctx, sessionID); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}

// We can only determine the session type after the streaming started. For
// this reason, we delay the emit audit event until the first event or if
// the streaming returns an error.
cb := func(evt apievents.AuditEvent, _ error) {
if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{
Metadata: apievents.Metadata{
Type: events.SessionRecordingAccessEvent,
Code: events.SessionRecordingAccessCode,
},
SessionID: sessionID.String(),
UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(),
SessionType: string(sessionType),
SessionType: string(sessionTypeFromStartEvent(evt)),
Format: metadata.SessionRecordingFormatFromContext(ctx),
}); err != nil {
return createErrorChannel(err)
log.WithError(err).Errorf("Failed to emit stream session event audit event")
}
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
return a.alog.StreamSessionEvents(events.ContextWithSessionStartCallback(ctx, cb), sessionID, startIndex)
}

// sessionTypeFromStartEvent determines the session type given the session start
// event.
func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind {
switch e := sessionStart.(type) {
case *apievents.SessionStart:
if e.KubernetesCluster != "" {
return types.KubernetesSessionKind
}
return types.SSHSessionKind
case *apievents.DatabaseSessionStart:
return types.DatabaseSessionKind
case *apievents.AppSessionStart:
return types.AppSessionKind
case *apievents.WindowsDesktopSessionStart:
return types.WindowsDesktopSessionKind
default:
return types.UnknownSessionKind
}
}

// CreateApp creates a new application resource.
Expand Down
49 changes: 39 additions & 10 deletions lib/auth/auth_with_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2267,7 +2267,29 @@ func TestStreamSessionEvents(t *testing.T) {
func TestStreamSessionEvents_SessionType(t *testing.T) {
t.Parallel()

srv := newTestTLSServer(t)
authServerConfig := TestAuthServerConfig{
Dir: t.TempDir(),
Clock: clockwork.NewFakeClockAt(time.Now().Round(time.Second).UTC()),
}
require.NoError(t, authServerConfig.CheckAndSetDefaults())

uploader := eventstest.NewMemoryUploader()
localLog, err := events.NewAuditLog(events.AuditLogConfig{
DataDir: authServerConfig.Dir,
ServerID: authServerConfig.ClusterName,
Clock: authServerConfig.Clock,
UploadHandler: uploader,
})
require.NoError(t, err)
authServerConfig.AuditLog = localLog

as, err := NewTestAuthServer(authServerConfig)
require.NoError(t, err)

srv, err := as.NewTestTLSServer()
require.NoError(t, err)
t.Cleanup(func() { srv.Close() })

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

Expand All @@ -2278,22 +2300,29 @@ func TestStreamSessionEvents_SessionType(t *testing.T) {
identity := TestUser(user.GetName())
clt, err := srv.NewClient(identity)
require.NoError(t, err)
sessionID := "44c6cea8-362f-11ea-83aa-125400432324"
sessionID := session.NewID()

// Emitting a session end event will cause the listing to correctly locate
// the recording (even if there might not be a recording file to stream).
require.NoError(t, srv.Auth().EmitAuditEvent(ctx, &apievents.DatabaseSessionEnd{
streamer, err := events.NewProtoStreamer(events.ProtoStreamerConfig{
Uploader: uploader,
})
require.NoError(t, err)
stream, err := streamer.CreateAuditStream(ctx, sessionID)
require.NoError(t, err)
// The event is not required to pass through the auth server, we only need
// the upload to be present.
require.NoError(t, stream.RecordEvent(ctx, eventstest.PrepareEvent(&apievents.DatabaseSessionStart{
Metadata: apievents.Metadata{
Type: events.DatabaseSessionEndEvent,
Code: events.DatabaseSessionEndCode,
Type: events.DatabaseSessionStartEvent,
Code: events.DatabaseSessionStartCode,
},
SessionMetadata: apievents.SessionMetadata{
SessionID: sessionID,
SessionID: sessionID.String(),
},
}))
})))
require.NoError(t, stream.Complete(ctx))

accessedFormat := teleport.PTY
clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), session.ID(sessionID), 0)
clt.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, accessedFormat), sessionID, 0)

// Perform the listing an eventually loop to ensure the event is emitted.
var searchEvents []apievents.AuditEvent
Expand Down
60 changes: 60 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,23 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
e := make(chan error, 1)
c := make(chan apievents.AuditEvent)

sessionStartCh := make(chan apievents.AuditEvent, 1)
if startCb, err := sessionStartCallbackFromContext(ctx); err == nil {
go func() {
evt, ok := <-sessionStartCh
if !ok {
startCb(nil, trace.NotFound("session start event not found"))
return
}

startCb(evt, nil)
}()
}

rawSession, err := os.CreateTemp(l.playbackDir, string(sessionID)+".stream.tar.*")
if err != nil {
e <- trace.Wrap(trace.ConvertSystemError(err), "creating temporary stream file")
close(sessionStartCh)
return c, e
}
// The file is still perfectly usable after unlinking it, and the space it's
Expand All @@ -528,6 +542,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
if err := os.Remove(rawSession.Name()); err != nil {
_ = rawSession.Close()
e <- trace.Wrap(trace.ConvertSystemError(err), "removing temporary stream file")
close(sessionStartCh)
return c, e
}

Expand All @@ -538,6 +553,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
err = trace.NotFound("a recording for session %v was not found", sessionID)
}
e <- trace.Wrap(err)
close(sessionStartCh)
return c, e
}
l.log.DebugContext(ctx, "Downloaded session to a temporary file for streaming.",
Expand All @@ -547,6 +563,8 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID

go func() {
defer rawSession.Close()
defer close(sessionStartCh)

// this shouldn't be necessary as the position should be already 0 (Download
// takes an io.WriterAt), but it's better to be safe than sorry
if _, err := rawSession.Seek(0, io.SeekStart); err != nil {
Expand All @@ -557,6 +575,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
protoReader := NewProtoReader(rawSession)
defer protoReader.Close()

firstEvent := true
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
Expand All @@ -573,6 +592,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
return
}

if firstEvent {
sessionStartCh <- event
firstEvent = false
}

if event.GetIndex() >= startIndex {
select {
case c <- event:
Expand Down Expand Up @@ -667,3 +691,39 @@ func (l *AuditLog) periodicSpaceMonitor() {
}
}
}

// streamSessionEventsContextKey represent context keys used by
// StreamSessionEvents function.
type streamSessionEventsContextKey string

const (
// sessionStartCallbackContextKey is the context key used to store the
// session start callback function.
sessionStartCallbackContextKey streamSessionEventsContextKey = "session-start"
)

// SessionStartCallback is the function used when streaming reaches the start
// event. If any error, such as session not found, the event will be nil, and
// the error will be set.
type SessionStartCallback func(startEvent apievents.AuditEvent, err error)

// ContextWithSessionStartCallback returns a context.Context containing a
// session start event callback.
func ContextWithSessionStartCallback(ctx context.Context, cb SessionStartCallback) context.Context {
return context.WithValue(ctx, sessionStartCallbackContextKey, cb)
}

// sessionStartCallbackFromContext returns the session start callback from
// context.Context.
func sessionStartCallbackFromContext(ctx context.Context) (SessionStartCallback, error) {
if ctx == nil {
return nil, trace.BadParameter("context is nil")
}

cb, ok := ctx.Value(sessionStartCallbackContextKey).(SessionStartCallback)
if !ok {
return nil, trace.BadParameter("session start callback function was not found in the context")
}

return cb, nil
}
Loading
Loading