Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine session type on stream based on stream watcher events #50395

Merged
1 change: 1 addition & 0 deletions api/types/session_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
DatabaseSessionKind SessionKind = "db"
AppSessionKind SessionKind = "app"
WindowsDesktopSessionKind SessionKind = "desktop"
UnknownSessionKind SessionKind = ""
)

// SessionParticipantMode is the mode that determines what you can do when you join a session.
Expand Down
106 changes: 59 additions & 47 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,30 +196,14 @@ func (a *ServerWithRoles) actionWithExtendedContext(kind, verb string, extendCon
// actionForKindSession is a special checker that grants access to session
// recordings. It can allow access to a specific recording based on the
// `where` section of the user's access rule for kind `session`.
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) (types.SessionKind, error) {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)

extendContext := func(ctx *services.Context) error {
ctx.Session = sessionEnd
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, sid session.ID) error {
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
extendContext := func(servicesCtx *services.Context) error {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)
servicesCtx.Session = sessionEnd
return trace.Wrap(err)
}

var sessionKind types.SessionKind
switch e := sessionEnd.(type) {
case *apievents.SessionEnd:
sessionKind = types.SSHSessionKind
if e.KubernetesCluster != "" {
sessionKind = types.KubernetesSessionKind
}
case *apievents.DatabaseSessionEnd:
sessionKind = types.DatabaseSessionKind
case *apievents.AppSessionEnd:
sessionKind = types.AppSessionKind
case *apievents.WindowsDesktopSessionEnd:
sessionKind = types.WindowsDesktopSessionKind
}

return sessionKind, trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext))
return trace.Wrap(a.actionWithExtendedContext(types.KindSession, types.VerbRead, extendContext))
}

// localServerAction returns an access denied error if the role is not one of the builtin server roles.
Expand Down Expand Up @@ -6034,44 +6018,72 @@ func (a *ServerWithRoles) ReplaceRemoteLocks(ctx context.Context, clusterName st
// channel if one is encountered. Otherwise the event channel is closed when the stream ends.
// The event channel is not closed on error to prevent race conditions in downstream select statements.
func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) {
createErrorChannel := func(err error) (chan apievents.AuditEvent, chan error) {
e := make(chan error, 1)
e <- trace.Wrap(err)
return nil, e
}

err := a.localServerAction()
isTeleportServer := err == nil

var sessionType types.SessionKind
if !isTeleportServer {
var err error
sessionType, err = a.actionForKindSession(ctx, sessionID)
if err != nil {
if err := a.actionForKindSession(ctx, sessionID); err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
return c, e
}
}

// StreamSessionEvents can be called internally, and when that happens we don't want to emit an event.
shouldEmitAuditEvent := !isTeleportServer
if shouldEmitAuditEvent {
if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{
Metadata: apievents.Metadata{
Type: events.SessionRecordingAccessEvent,
Code: events.SessionRecordingAccessCode,
},
SessionID: sessionID.String(),
UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(),
SessionType: string(sessionType),
Format: metadata.SessionRecordingFormatFromContext(ctx),
}); err != nil {
return createErrorChannel(err)
}
// We can only determine the session type after the streaming started. For
// this reason, we delay the emit audit event until the first event or if
// the streaming returns an error.
watcher := &eventsWatcher{
onSessionStart: func(evt apievents.AuditEvent, _ error) {
// StreamSessionEvents can be called internally, and when that
// happens we don't want to emit an event.
if isTeleportServer {
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved
return
}

if err := a.authServer.emitter.EmitAuditEvent(a.authServer.closeCtx, &apievents.SessionRecordingAccess{
Metadata: apievents.Metadata{
Type: events.SessionRecordingAccessEvent,
Code: events.SessionRecordingAccessCode,
},
SessionID: sessionID.String(),
UserMetadata: a.context.Identity.GetIdentity().GetUserMetadata(),
SessionType: string(sessionTypeFromStartEvent(evt)),
Format: metadata.SessionRecordingFormatFromContext(ctx),
}); err != nil {
log.WithError(err).Errorf("Failed to emit stream session event audit event")
}
},
}

return a.alog.StreamSessionEvents(ctx, sessionID, startIndex)
return a.alog.StreamSessionEvents(events.ContextWithEventWatcher(ctx, watcher), sessionID, startIndex)
}

type eventsWatcher struct {
onSessionStart func(apievents.AuditEvent, error)
}
gabrielcorado marked this conversation as resolved.
Show resolved Hide resolved

func (e *eventsWatcher) OnSessionStart(evt apievents.AuditEvent, err error) {
e.onSessionStart(evt, err)
}

// sessionTypeFromStartEvent determines the session type given the session start
// event.
func sessionTypeFromStartEvent(sessionStart apievents.AuditEvent) types.SessionKind {
switch e := sessionStart.(type) {
case *apievents.SessionStart:
if e.KubernetesCluster != "" {
return types.KubernetesSessionKind
}
return types.SSHSessionKind
case *apievents.DatabaseSessionStart:
return types.DatabaseSessionKind
case *apievents.AppSessionStart:
return types.AppSessionKind
case *apievents.WindowsDesktopSessionStart:
return types.WindowsDesktopSessionKind
default:
return types.UnknownSessionKind
}
}

// CreateApp creates a new application resource.
Expand Down
56 changes: 56 additions & 0 deletions lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,25 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
"session_id", string(sessionID),
)

sessionStartCh := make(chan apievents.AuditEvent, 1)
if watcher, err := EventsWatcherFromContext(ctx); err == nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

error cases above are missing the callback?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated it. Now the error cases closes the channel (causing the callback to called if any).

go func() {
select {
case evt, ok := <-sessionStartCh:
if !ok {
watcher.OnSessionStart(nil, trace.NotFound("session start event not found"))
return
}

watcher.OnSessionStart(evt, nil)
}
}()
}

go func() {
defer rawSession.Close()
defer close(sessionStartCh)

// this shouldn't be necessary as the position should be already 0 (Download
// takes an io.WriterAt), but it's better to be safe than sorry
if _, err := rawSession.Seek(0, io.SeekStart); err != nil {
Expand All @@ -556,6 +573,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID

protoReader := NewProtoReader(rawSession)

firstEvent := true
for {
if ctx.Err() != nil {
e <- trace.Wrap(ctx.Err())
Expand All @@ -572,6 +590,11 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
return
}

if firstEvent {
sessionStartCh <- event
firstEvent = false
}

if event.GetIndex() >= startIndex {
select {
case c <- event:
Expand Down Expand Up @@ -666,3 +689,36 @@ func (l *AuditLog) periodicSpaceMonitor() {
}
}
}

// TODO
type streamSessionEventsContextKey string

const (
// TODO
eventWatcherContextKey streamSessionEventsContextKey = "watcher"
)

// TODO
type EventsWatcher interface {
// TODO
OnSessionStart(evt apievents.AuditEvent, err error)
}

// TODO
func ContextWithEventWatcher(ctx context.Context, watcher EventsWatcher) context.Context {
return context.WithValue(ctx, eventWatcherContextKey, watcher)
}

// TODO
func EventsWatcherFromContext(ctx context.Context) (EventsWatcher, error) {
if ctx == nil {
return nil, trace.BadParameter("context is nil")
}

watcher, ok := ctx.Value(eventWatcherContextKey).(EventsWatcher)
if !ok {
return nil, trace.BadParameter("events watcher was not found in the context")
}

return watcher, nil
}
Loading