Skip to content

Commit

Permalink
Allow sending audio track along with screensharing (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 authored Mar 21, 2022
1 parent 1659c8c commit e2ffece
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 39 deletions.
14 changes: 8 additions & 6 deletions server/channel_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ type callStats struct {
}

type callState struct {
ID string `json:"id"`
StartAt int64 `json:"create_at"`
Users map[string]*userState `json:"users,omitempty"`
ThreadID string `json:"thread_id"`
ScreenSharingID string `json:"screen_sharing_id"`
Stats callStats `json:"stats"`
ID string `json:"id"`
StartAt int64 `json:"create_at"`
Users map[string]*userState `json:"users,omitempty"`
ThreadID string `json:"thread_id"`
ScreenSharingID string `json:"screen_sharing_id"`
ScreenTrackID string `json:"screen_track_id"`
ScreenAudioTrackID string `json:"screen_audio_track_id"`
Stats callStats `json:"stats"`
}

type channelState struct {
Expand Down
3 changes: 3 additions & 0 deletions server/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type session struct {
outVoiceTrack *webrtc.TrackLocalStaticRTP
outVoiceTrackEnabled bool
outScreenTrack *webrtc.TrackLocalStaticRTP
outScreenAudioTrack *webrtc.TrackLocalStaticRTP
remoteScreenTrack *webrtc.TrackRemote
rtcConn *webrtc.PeerConnection
tracksCh chan *webrtc.TrackLocalStaticRTP
Expand Down Expand Up @@ -120,6 +121,8 @@ func (p *Plugin) removeUserSession(userID, channelID string) (channelState, chan

if state.Call.ScreenSharingID == userID {
state.Call.ScreenSharingID = ""
state.Call.ScreenTrackID = ""
state.Call.ScreenAudioTrackID = ""
if call := p.getCall(channelID); call != nil {
call.setScreenSession(nil)
}
Expand Down
65 changes: 47 additions & 18 deletions server/sfu.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,26 +354,46 @@ func (p *Plugin) initRTCConn(userID string) {
p.LogDebug(fmt.Sprintf("%+v", remoteTrack.Codec().RTPCodecCapability))
p.LogDebug(fmt.Sprintf("Track has started, of type %d: %s", remoteTrack.PayloadType(), remoteTrack.Codec().MimeType))

trackID := remoteTrack.ID()
state, err := p.kvGetChannelState(userSession.channelID)
if err != nil {
p.LogError(err.Error())
return
}
if state.Call == nil {
p.LogError("call state should not be nil")
return
}

if remoteTrack.Codec().MimeType == rtpAudioCodec.MimeType {
outVoiceTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, "voice", model.NewId())
trackType := "voice"
if trackID != "" && trackID == state.Call.ScreenAudioTrackID {
p.LogDebug("received screen sharing audio track")
trackType = "screen-audio"
}
outAudioTrack, err := webrtc.NewTrackLocalStaticRTP(rtpAudioCodec, trackType, model.NewId())
if err != nil {
p.LogError(err.Error())
return
}

userSession.mut.Lock()
userSession.outVoiceTrack = outVoiceTrack
userSession.outVoiceTrackEnabled = true
if trackType == "voice" {
userSession.outVoiceTrack = outAudioTrack
userSession.outVoiceTrackEnabled = true
} else {
userSession.outScreenAudioTrack = outAudioTrack
}
userSession.mut.Unlock()

p.iterSessions(userSession.channelID, func(s *session) {
if s.userID == userSession.userID {
return
}
select {
case s.tracksCh <- outVoiceTrack:
case s.tracksCh <- outAudioTrack:
default:
p.LogError("failed to send voice track, channel is full", "userID", userID, "trackUserID", s.userID)
p.LogError("failed to send audio track, channel is full", "userID", userID, "trackUserID", s.userID)
}
})

Expand All @@ -384,18 +404,19 @@ func (p *Plugin) initRTCConn(userID string) {
return
}

p.metrics.RTPPacketCounters.With(prometheus.Labels{"direction": "in", "type": "voice"}).Inc()
p.metrics.RTPPacketBytesCounters.With(prometheus.Labels{"direction": "in", "type": "voice"}).Add(float64(len(rtp.Payload)))

userSession.mut.RLock()
isEnabled := userSession.outVoiceTrackEnabled
userSession.mut.RUnlock()
p.metrics.RTPPacketCounters.With(prometheus.Labels{"direction": "in", "type": trackType}).Inc()
p.metrics.RTPPacketBytesCounters.With(prometheus.Labels{"direction": "in", "type": trackType}).Add(float64(len(rtp.Payload)))

if !isEnabled {
continue
if trackType == "voice" {
userSession.mut.RLock()
isEnabled := userSession.outVoiceTrackEnabled
userSession.mut.RUnlock()
if !isEnabled {
continue
}
}

if err := outVoiceTrack.WriteRTP(rtp); err != nil && !errors.Is(err, io.ErrClosedPipe) {
if err := outAudioTrack.WriteRTP(rtp); err != nil && !errors.Is(err, io.ErrClosedPipe) {
p.LogError(err.Error())
return
}
Expand All @@ -405,13 +426,18 @@ func (p *Plugin) initRTCConn(userID string) {
if s.userID == userSession.userID {
return
}
p.metrics.RTPPacketCounters.With(prometheus.Labels{"direction": "out", "type": "voice"}).Inc()
p.metrics.RTPPacketBytesCounters.With(prometheus.Labels{"direction": "out", "type": "voice"}).Add(float64(len(rtp.Payload)))
p.metrics.RTPPacketCounters.With(prometheus.Labels{"direction": "out", "type": trackType}).Inc()
p.metrics.RTPPacketBytesCounters.With(prometheus.Labels{"direction": "out", "type": trackType}).Add(float64(len(rtp.Payload)))
})

}
} else if remoteTrack.Codec().MimeType == rtpVideoCodecVP8.MimeType {
// TODO: actually check if the userID matches the expected publisher.
if trackID == "" || trackID != state.Call.ScreenTrackID {
p.LogError("received unexpected video track", "trackID", trackID)
return
}

p.LogDebug("received screen sharing track")
call := p.getCall(userSession.channelID)
if call == nil {
p.LogError("call should not be nil")
Expand All @@ -422,7 +448,6 @@ func (p *Plugin) initRTCConn(userID string) {
return
}
call.setScreenSession(userSession)

p.API.PublishWebSocketEvent(wsEventUserScreenOn, map[string]interface{}{
"userID": userID,
}, &model.WebsocketBroadcast{ChannelId: userSession.channelID})
Expand Down Expand Up @@ -498,13 +523,17 @@ func (p *Plugin) handleTracks(us *session) {
outVoiceTrack := s.outVoiceTrack
isEnabled := s.outVoiceTrackEnabled
outScreenTrack := s.outScreenTrack
outScreenAudioTrack := s.outScreenAudioTrack
s.mut.RUnlock()
if outVoiceTrack != nil {
p.addTrack(us, outVoiceTrack, isEnabled)
}
if outScreenTrack != nil {
p.addTrack(us, outScreenTrack, true)
}
if outScreenAudioTrack != nil {
p.addTrack(us, outScreenAudioTrack, true)
}
})

for {
Expand Down
5 changes: 4 additions & 1 deletion server/slash_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ func getAutocompleteData() *model.AutocompleteData {
data.AddCommand(model.NewAutocompleteData(joinCommandTrigger, "", "Joins or starts a call in the current channel"))
data.AddCommand(model.NewAutocompleteData(leaveCommandTrigger, "", "Leaves a call in the current channel"))
data.AddCommand(model.NewAutocompleteData(linkCommandTrigger, "", "Generates a link to join a call in the current channel"))
data.AddCommand(model.NewAutocompleteData(experimentalCommandTrigger, "", "Turns on/off experimental features"))

experimentalCmdData := model.NewAutocompleteData(experimentalCommandTrigger, "", "Turns on/off experimental features")
experimentalCmdData.AddTextArgument("Available options: on, off", "", "on|off")
data.AddCommand(experimentalCmdData)
return data
}

Expand Down
16 changes: 14 additions & 2 deletions server/websocket.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"encoding/json"
"fmt"
"strings"
"sync"
Expand Down Expand Up @@ -30,6 +31,13 @@ const (
)

func (p *Plugin) handleClientMessageTypeScreen(msg clientMessage, channelID, userID string) error {
data := map[string]string{}
if msg.Type == clientMessageTypeScreenOn {
if err := json.Unmarshal(msg.Data, &data); err != nil {
p.LogError(err.Error())
}
}

if err := p.kvSetAtomicChannelState(channelID, func(state *channelState) (*channelState, error) {
if state == nil {
return nil, fmt.Errorf("channel state is missing from store")
Expand All @@ -43,11 +51,15 @@ func (p *Plugin) handleClientMessageTypeScreen(msg clientMessage, channelID, use
return nil, fmt.Errorf("cannot start screen sharing, someone else is sharing already: %q", state.Call.ScreenSharingID)
}
state.Call.ScreenSharingID = userID
state.Call.ScreenTrackID = data["screenTrackID"]
state.Call.ScreenAudioTrackID = data["screenAudioTrackID"]
} else {
if state.Call.ScreenSharingID != userID {
return nil, fmt.Errorf("cannot stop screen sharing, someone else is sharing already: %q", state.Call.ScreenSharingID)
}
state.Call.ScreenSharingID = ""
state.Call.ScreenTrackID = ""
state.Call.ScreenAudioTrackID = ""
if call := p.getCall(channelID); call != nil {
call.setScreenSession(nil)
}
Expand Down Expand Up @@ -457,10 +469,10 @@ func (p *Plugin) WebSocketMessageHasBeenPosted(connID, userID string, req *model
return
}
msg.Data = data
case clientMessageTypeICE:
case clientMessageTypeICE, clientMessageTypeScreenOn:
msgData, ok := req.Data["data"].(string)
if !ok {
p.LogError("invalid or missing ice data")
p.LogError("invalid or missing data")
return
}
msg.Data = []byte(msgData)
Expand Down
25 changes: 21 additions & 4 deletions webapp/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ export default class CallsClient extends EventEmitter {
if (remoteStream.getAudioTracks().length > 0) {
this.emit('remoteVoiceStream', remoteStream);
} else if (remoteStream.getVideoTracks().length > 0) {
console.log(remoteStream.getTracks());
this.emit('remoteScreenStream', remoteStream);
this.remoteScreenTrack = remoteStream.getVideoTracks()[0];
}
Expand Down Expand Up @@ -388,10 +389,21 @@ export default class CallsClient extends EventEmitter {

const screenTrack = screenStream.getVideoTracks()[0];
this.localScreenTrack = screenTrack;
screenStream = new MediaStream([screenTrack]);

const screenAudioTrack = screenStream.getAudioTracks()[0];
if (screenAudioTrack) {
screenStream = new MediaStream([screenTrack, screenAudioTrack]);
} else {
screenStream = new MediaStream([screenTrack]);
}

this.streams.push(screenStream);

screenTrack.onended = () => {
if (screenAudioTrack) {
screenAudioTrack.stop();
}

this.localScreenTrack = null;

if (!this.ws || !this.peer) {
Expand All @@ -405,15 +417,20 @@ export default class CallsClient extends EventEmitter {

this.peer.addStream(screenStream);

this.ws.send('screen_on');
this.ws.send('screen_on', {
data: JSON.stringify({
screenTrackID: screenTrack.id,
screenAudioTrackID: screenAudioTrack?.id,
}),
});
}

public async shareScreen(sourceID?: string) {
public async shareScreen(sourceID?: string, withAudio?: boolean) {
if (!this.ws || !this.peer) {
return null;
}

const screenStream = await getScreenStream(sourceID);
const screenStream = await getScreenStream(sourceID, withAudio);
if (screenStream === null) {
return null;
}
Expand Down
4 changes: 2 additions & 2 deletions webapp/src/components/call_widget/component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {IDMappedObjects} from 'mattermost-redux/types/utilities';
import {changeOpacity} from 'mattermost-redux/utils/theme_utils';

import {UserState} from 'src/types/types';
import {getUserDisplayName, isPublicChannel, isPrivateChannel, isDMChannel, isGMChannel} from 'src/utils';
import {getUserDisplayName, isPublicChannel, isPrivateChannel, isDMChannel, isGMChannel, hasExperimentalFlag} from 'src/utils';

import Avatar from '../avatar/avatar';
import {pluginId} from '../../manifest';
Expand Down Expand Up @@ -383,7 +383,7 @@ export default class CallWidget extends React.PureComponent<Props, State> {
if (window.desktop && compareSemVer(window.desktop.version, '5.1.0') >= 0) {
this.props.showScreenSourceModal();
} else {
const stream = await window.callsClient.shareScreen();
const stream = await window.callsClient.shareScreen('', hasExperimentalFlag());
state.screenStream = stream;
}
}
Expand Down
4 changes: 2 additions & 2 deletions webapp/src/components/expanded_view/component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {OverlayTrigger, Tooltip} from 'react-bootstrap';
import {UserProfile} from 'mattermost-redux/types/users';
import {Channel} from 'mattermost-redux/types/channels';

import {getUserDisplayName, getScreenStream, isDMChannel} from 'src/utils';
import {getUserDisplayName, getScreenStream, isDMChannel, hasExperimentalFlag} from 'src/utils';
import {UserState} from 'src/types/types';

import Avatar from '../avatar/avatar';
Expand Down Expand Up @@ -102,7 +102,7 @@ export default class ExpandedView extends React.PureComponent<Props, State> {
if (window.desktop && compareSemVer(window.desktop.version, '5.1.0') >= 0) {
this.props.showScreenSourceModal();
} else {
const stream = await getScreenStream();
const stream = await getScreenStream('', hasExperimentalFlag());
if (window.opener && stream) {
window.screenSharingTrackId = stream.getVideoTracks()[0].id;
}
Expand Down
4 changes: 3 additions & 1 deletion webapp/src/components/screen_source_modal/component.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import {Channel} from 'mattermost-redux/types/channels';

import {changeOpacity} from 'mattermost-redux/utils/theme_utils';

import {hasExperimentalFlag} from '../../utils';

import CompassIcon from '../../components/icons/compassIcon';

import './component.scss';
Expand Down Expand Up @@ -152,7 +154,7 @@ export default class ScreenSourceModal extends React.PureComponent<Props, State>
}

private shareScreen = () => {
window.callsClient.shareScreen(this.state.selected);
window.callsClient.shareScreen(this.state.selected, hasExperimentalFlag());
this.hide();
}

Expand Down
6 changes: 3 additions & 3 deletions webapp/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ export function stateSortProfiles(profiles: UserProfile[], statuses: {[key: stri
};
}

export async function getScreenStream(sourceID?: string): Promise<MediaStream|null> {
export async function getScreenStream(sourceID?: string, withAudio?: boolean): Promise<MediaStream|null> {
let screenStream: MediaStream|null = null;

if (window.desktop) {
Expand All @@ -203,10 +203,10 @@ export async function getScreenStream(sourceID?: string): Promise<MediaStream|nu
options.chromeMediaSourceId = sourceID;
}
screenStream = await navigator.mediaDevices.getUserMedia({
audio: false,
video: {
mandatory: options,
} as any,
audio: withAudio ? {mandatory: options} as any : false,
});
} catch (err) {
console.log(err);
Expand All @@ -217,7 +217,7 @@ export async function getScreenStream(sourceID?: string): Promise<MediaStream|nu
try {
screenStream = await navigator.mediaDevices.getDisplayMedia({
video: true,
audio: false,
audio: Boolean(withAudio),
});
} catch (err) {
console.log(err);
Expand Down

0 comments on commit e2ffece

Please sign in to comment.