diff --git a/Makefile b/Makefile index 9cffe29fd..ae96a8579 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,6 @@ build_allrpc: go_init test: go_init go test \ - -timeout 120s \ + -timeout 240s \ -coverprofile=cover.out -covermode=atomic \ -v -race ${GO_TESTPKGS} diff --git a/go.mod b/go.mod index 7ed16e458..922d0d88a 100644 --- a/go.mod +++ b/go.mod @@ -8,19 +8,20 @@ require ( github.com/gammazero/deque v0.1.0 github.com/gammazero/workerpool v1.1.2 github.com/go-logr/logr v1.0.0 + github.com/google/uuid v1.3.0 // indirect github.com/gorilla/websocket v1.4.2 github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 github.com/improbable-eng/grpc-web v0.13.0 github.com/lucsky/cuid v1.0.2 github.com/pion/dtls/v2 v2.0.9 - github.com/pion/ice/v2 v2.1.8 + github.com/pion/ice/v2 v2.1.10 github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.6 - github.com/pion/rtp/v2 v2.0.0 + github.com/pion/rtp v1.7.1 github.com/pion/sdp/v3 v3.0.4 github.com/pion/transport v0.12.3 github.com/pion/turn/v2 v2.0.5 - github.com/pion/webrtc/v3 v3.0.29 + github.com/pion/webrtc/v3 v3.1.0-beta.2.0.20210808020610-5253475ec730 github.com/prometheus/client_golang v1.9.0 github.com/rs/cors v1.7.0 // indirect github.com/rs/zerolog v1.23.0 @@ -28,7 +29,10 @@ require ( github.com/sourcegraph/jsonrpc2 v0.0.0-20210201082850-366fbb520750 github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.7.0 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect + golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d // indirect golang.org/x/sync v0.0.0-20201207232520-09787c993a3a + golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 // indirect google.golang.org/grpc v1.35.0 google.golang.org/grpc/examples v0.0.0-20201209011439-fd32f6a4fefe // indirect google.golang.org/protobuf v1.25.0 diff --git a/go.sum b/go.sum index 8449127e2..7be7fc750 100644 --- a/go.sum +++ b/go.sum @@ -148,8 +148,9 @@ github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OI github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= @@ -290,11 +291,12 @@ github.com/pion/datachannel v1.4.21 h1:3ZvhNyfmxsAqltQrApLPQMhSFNA+aT87RqyCq4OXm github.com/pion/datachannel v1.4.21/go.mod h1:oiNyP4gHx2DIwRzX/MFyH0Rz/Gz05OgBlayAI2hAWjg= github.com/pion/dtls/v2 v2.0.9 h1:7Ow+V++YSZQMYzggI0P9vLJz/hUFcffsfGMfT/Qy+u8= github.com/pion/dtls/v2 v2.0.9/go.mod h1:O0Wr7si/Zj5/EBFlDzDd6UtVxx25CE1r7XM7BQKYQho= -github.com/pion/ice/v2 v2.1.7/go.mod h1:kV4EODVD5ux2z8XncbLHIOtcXKtYXVgLVCeVqnpoeP0= -github.com/pion/ice/v2 v2.1.8 h1:3kV4XaB2C3z1gDUXZmwSB/B0PSdZ7GFFC3w4iUX9prs= -github.com/pion/ice/v2 v2.1.8/go.mod h1:kV4EODVD5ux2z8XncbLHIOtcXKtYXVgLVCeVqnpoeP0= -github.com/pion/interceptor v0.0.12 h1:eC1iVneBIAQJEfaNAfDqAncJWhMDAnaXPRCJsltdokE= -github.com/pion/interceptor v0.0.12/go.mod h1:qzeuWuD/ZXvPqOnxNcnhWfkCZ2e1kwwslicyyPnhoK4= +github.com/pion/ice/v2 v2.1.10 h1:Jt/BfUsaP+Dr6E5rbsy+w7w1JtHyFN0w2DkgfWq7Fko= +github.com/pion/ice/v2 v2.1.10/go.mod h1:kV4EODVD5ux2z8XncbLHIOtcXKtYXVgLVCeVqnpoeP0= +github.com/pion/interceptor v0.0.13 h1:fnV+b0p/KEzwwr/9z2nsSqA9IQRMsM4nF5HjrNSWwBo= +github.com/pion/interceptor v0.0.13/go.mod h1:svsW2QoLHLoGLUr4pDoSopGBEWk8FZwlfxId/OKRKzo= +github.com/pion/interceptor v0.0.15 h1:pQFkBUL8akUHiGoFr+pM94Q/15x7sLFh0K3Nj+DCC6s= +github.com/pion/interceptor v0.0.15/go.mod h1:pg3J253eGi5bqyKzA74+ej5Y19ez2jkWANVnF+Z9Dfk= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.5 h1:Q2oj/JB3NqfzY9xGZ1fPzZzK7sDSD8rZPOvcIQ10BCw= @@ -306,6 +308,9 @@ github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE github.com/pion/rtp v1.6.2/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/rtp v1.6.5 h1:o2cZf8OascA5HF/b0PAbTxRKvOWxTQxWYt7SlToxFGI= github.com/pion/rtp v1.6.5/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= +github.com/pion/rtp v1.7.0/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= +github.com/pion/rtp v1.7.1 h1:hCaxfVgPGt13eF/Tu9RhVn04c+dAcRZmhdDWqUE13oY= +github.com/pion/rtp v1.7.1/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/sctp v1.7.10/go.mod h1:EhpTUQu1/lcK3xI+eriS6/96fWetHGCvBi9MSsnaBN0= github.com/pion/sctp v1.7.12 h1:GsatLufywVruXbZZT1CKg+Jr8ZTkwiPnmUC/oO9+uuY= github.com/pion/sctp v1.7.12/go.mod h1:xFe9cLMZ5Vj6eOzpyiKjT9SwGM4KpK/8Jbw5//jc+0s= @@ -313,6 +318,8 @@ github.com/pion/sdp/v3 v3.0.4 h1:2Kf+dgrzJflNCSw3TV5v2VLeI0s/qkzy2r5jlR0wzf8= github.com/pion/sdp/v3 v3.0.4/go.mod h1:bNiSknmJE0HYBprTHXKPQ3+JjacTv5uap92ueJZKsRk= github.com/pion/srtp/v2 v2.0.2 h1:664iGzVmaY7KYS5M0gleY0DscRo9ReDfTxQrq4UgGoU= github.com/pion/srtp/v2 v2.0.2/go.mod h1:VEyLv4CuxrwGY8cxM+Ng3bmVy8ckz/1t6A0q/msKOw0= +github.com/pion/srtp/v2 v2.0.5 h1:ks3wcTvIUE/GHndO3FAvROQ9opy0uLELpwHJaQ1yqhQ= +github.com/pion/srtp/v2 v2.0.5/go.mod h1:8k6AJlal740mrZ6WYxc4Dg6qDqqhxoRG2GSjlUhDF0A= github.com/pion/stun v0.3.5 h1:uLUCBCkQby4S1cf6CGuR9QrVOKcvUwFeemaC865QHDg= github.com/pion/stun v0.3.5/go.mod h1:gDMim+47EeEtfWogA37n6qXZS88L5V6LqFcf+DZA2UA= github.com/pion/transport v0.10.1/go.mod h1:PBis1stIILMiis0PewDw91WJeLJkyIMcEk+DwKOzf4A= @@ -323,8 +330,10 @@ github.com/pion/turn/v2 v2.0.5 h1:iwMHqDfPEDEOFzwWKT56eFmh6DYC6o/+xnLAEzgISbA= github.com/pion/turn/v2 v2.0.5/go.mod h1:APg43CFyt/14Uy7heYUOGWdkem/Wu4PhCO/bjyrTqMw= github.com/pion/udp v0.1.1 h1:8UAPvyqmsxK8oOjloDk4wUt63TzFe9WEJkg5lChlj7o= github.com/pion/udp v0.1.1/go.mod h1:6AFo+CMdKQm7UiA0eUPA8/eVCTx8jBIITLZHc9DWX5M= -github.com/pion/webrtc/v3 v3.0.29 h1:pVs6mYjbbYvC8pMsztayEz35DnUEFLPswsicGXaQjxo= -github.com/pion/webrtc/v3 v3.0.29/go.mod h1:XFQeLYBf++bWWA0sJqh6zF1ouWluosxwTOMOoTZGaD0= +github.com/pion/webrtc/v3 v3.0.33-0.20210728210013-6d7756b73271 h1:bH4Z2m7IUvUD9ot6H+eL3D157pqFGKG43Od04q7Mghc= +github.com/pion/webrtc/v3 v3.0.33-0.20210728210013-6d7756b73271/go.mod h1:wX3V5dQQUGCifhT1mYftC2kCrDQX6ZJ3B7Yad0R9JK0= +github.com/pion/webrtc/v3 v3.1.0-beta.2.0.20210808020610-5253475ec730 h1:7RJ7Auu3JmyACMCS+pQk6WOVuwpqoifPPMUdkyBKsbw= +github.com/pion/webrtc/v3 v3.1.0-beta.2.0.20210808020610-5253475ec730/go.mod h1:I4O6v2pkiXdVmcn7sUhCNwHUAepGU19PVEyR204s1qc= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -371,8 +380,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.20.0 h1:38k9hgtUBdxFwE34yS8rTHmHBa4eN16E4DJlv177LNs= -github.com/rs/zerolog v1.20.0/go.mod h1:IzD0RJ65iWH0w97OQQebJEvTZYvsCUm9WVLWBQrJRjo= github.com/rs/zerolog v1.23.0 h1:UskrK+saS9P9Y789yNNulYKdARjPZuS35B8gJF2x60g= github.com/rs/zerolog v1.23.0/go.mod h1:6c7hFfxPOy7TacJc4Fcdi24/J0NKYGzjG8FWRI916Qo= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -450,8 +457,9 @@ golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 h1:/UOmuWzQfxxo9UtlXMwuQU8CMgg1eZXqTRwkSQJWKOI= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -498,8 +506,11 @@ golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwY golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210331212208-0fccb6fa2b5c/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210420210106-798c2154c571 h1:Q6Bg8xzKzpFPU4Oi1sBnBTHBwlMsLeEXpu4hYBY8rAg= -golang.org/x/net v0.0.0-20210420210106-798c2154c571/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 h1:4CSI6oo7cOjJKajidEljs9h+uP0rRZBPPPhcCbj5mw8= +golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d h1:20cMwl2fHAzkJMEA+8J4JgqBQcQGzbisXo31MIeenXI= +golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -545,8 +556,12 @@ golang.org/x/sys v0.0.0-20201214210602-f9fddec55a1e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe h1:WdX7u8s3yOigWAhHEaDl8r9G+4XwFQEQFtBMYyN+kXQ= -golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069 h1:siQdpVirKtzPhKl3lZWozZraCFObP8S1v6PRp0bLrtU= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -574,7 +589,6 @@ golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20190828213141-aed303cbaa74/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= diff --git a/pkg/buffer/bucket.go b/pkg/buffer/bucket.go index f58e7f829..83444865d 100644 --- a/pkg/buffer/bucket.go +++ b/pkg/buffer/bucket.go @@ -9,6 +9,7 @@ const maxPktSize = 1500 type Bucket struct { buf []byte + src *[]byte init bool step int @@ -16,10 +17,11 @@ type Bucket struct { maxSteps int } -func NewBucket(buf []byte) *Bucket { +func NewBucket(buf *[]byte) *Bucket { return &Bucket{ - buf: buf, - maxSteps: int(math.Floor(float64(len(buf))/float64(maxPktSize))) - 1, + src: buf, + buf: *buf, + maxSteps: int(math.Floor(float64(len(*buf))/float64(maxPktSize))) - 1, } } diff --git a/pkg/buffer/bucket_test.go b/pkg/buffer/bucket_test.go index efe40f946..706387f96 100644 --- a/pkg/buffer/bucket_test.go +++ b/pkg/buffer/bucket_test.go @@ -41,7 +41,8 @@ var TestPackets = []*rtp.Packet{ } func Test_queue(t *testing.T) { - q := NewBucket(make([]byte, 25000)) + b := make([]byte, 25000) + q := NewBucket(&b) for _, p := range TestPackets { p := p @@ -98,7 +99,8 @@ func Test_queue_edges(t *testing.T) { }, }, } - q := NewBucket(make([]byte, 25000)) + b := make([]byte, 25000) + q := NewBucket(&b) for _, p := range TestPackets { p := p assert.NotNil(t, p) diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go index 7df2f5f38..1fd124ab8 100644 --- a/pkg/buffer/buffer.go +++ b/pkg/buffer/buffer.go @@ -68,7 +68,7 @@ type Buffer struct { minPacketProbe int lastPacketRead int - maxTemporalLayer int64 + maxTemporalLayer int32 bitrate uint64 bitrateHelper uint64 lastSRNTPTime uint64 @@ -134,10 +134,10 @@ func (b *Buffer) Bind(params webrtc.RTPParameters, o Options) { switch { case strings.HasPrefix(b.mime, "audio/"): b.codecType = webrtc.RTPCodecTypeAudio - b.bucket = NewBucket(b.audioPool.Get().([]byte)) + b.bucket = NewBucket(b.audioPool.Get().(*[]byte)) case strings.HasPrefix(b.mime, "video/"): b.codecType = webrtc.RTPCodecTypeVideo - b.bucket = NewBucket(b.videoPool.Get().([]byte)) + b.bucket = NewBucket(b.videoPool.Get().(*[]byte)) default: b.codecType = webrtc.RTPCodecType(0) } @@ -253,10 +253,10 @@ func (b *Buffer) Close() error { b.closeOnce.Do(func() { if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeVideo { - b.videoPool.Put(b.bucket.buf) + b.videoPool.Put(b.bucket.src) } if b.bucket != nil && b.codecType == webrtc.RTPCodecTypeAudio { - b.audioPool.Put(b.bucket.buf) + b.audioPool.Put(b.bucket.src) } b.closed.set(true) b.onClose() @@ -345,9 +345,9 @@ func (b *Buffer) calc(pkt []byte, arrivalTime int64) { if b.mime == "video/vp8" { pld := ep.Payload.(VP8) - mtl := atomic.LoadInt64(&b.maxTemporalLayer) - if mtl < int64(pld.TID) { - atomic.StoreInt64(&b.maxTemporalLayer, int64(pld.TID)) + mtl := atomic.LoadInt32(&b.maxTemporalLayer) + if mtl < int32(pld.TID) { + atomic.StoreInt32(&b.maxTemporalLayer, int32(pld.TID)) } } @@ -525,8 +525,8 @@ func (b *Buffer) Bitrate() uint64 { return atomic.LoadUint64(&b.bitrate) } -func (b *Buffer) MaxTemporalLayer() int64 { - return atomic.LoadInt64(&b.maxTemporalLayer) +func (b *Buffer) MaxTemporalLayer() int32 { + return atomic.LoadInt32(&b.maxTemporalLayer) } func (b *Buffer) OnTransportWideCC(fn func(sn uint16, timeNS int64, marker bool)) { diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go index 5947ee800..7327aacbf 100644 --- a/pkg/buffer/buffer_test.go +++ b/pkg/buffer/buffer_test.go @@ -16,7 +16,6 @@ func CreateTestPacket(pktStamp *SequenceNumberAndTimeStamp) *rtp.Packet { if pktStamp == nil { return &rtp.Packet{ Header: rtp.Header{}, - Raw: []byte{1, 2, 3}, Payload: []byte{1, 2, 3}, } } @@ -26,7 +25,6 @@ func CreateTestPacket(pktStamp *SequenceNumberAndTimeStamp) *rtp.Packet { SequenceNumber: pktStamp.SequenceNumber, Timestamp: pktStamp.Timestamp, }, - Raw: []byte{1, 2, 3}, Payload: []byte{1, 2, 3}, } } @@ -48,7 +46,8 @@ func CreateTestListPackets(snsAndTSs []SequenceNumberAndTimeStamp) (packetList [ func TestNack(t *testing.T) { pool := &sync.Pool{ New: func() interface{} { - return make([]byte, 1500) + b := make([]byte, 1500) + return &b }, } logger.SetGlobalOptions(logger.GlobalConfig{V: 1}) // 2 - TRACE @@ -149,7 +148,8 @@ func TestNewBuffer(t *testing.T) { } pool := &sync.Pool{ New: func() interface{} { - return make([]byte, 1500) + b := make([]byte, 1500) + return &b }, } logger.SetGlobalOptions(logger.GlobalConfig{V: 2}) // 2 - TRACE diff --git a/pkg/buffer/factory.go b/pkg/buffer/factory.go index 71d214b03..627c76072 100644 --- a/pkg/buffer/factory.go +++ b/pkg/buffer/factory.go @@ -30,12 +30,14 @@ func NewBufferFactory(trackingPackets int, logger logr.Logger) *Factory { return &Factory{ videoPool: &sync.Pool{ New: func() interface{} { - return make([]byte, trackingPackets*maxPktSize) + b := make([]byte, trackingPackets*maxPktSize) + return &b }, }, audioPool: &sync.Pool{ New: func() interface{} { - return make([]byte, maxPktSize*25) + b := make([]byte, maxPktSize*25) + return &b }, }, rtpBuffers: make(map[uint32]*Buffer), diff --git a/pkg/buffer/nack.go b/pkg/buffer/nack.go index 46d634413..0050ba7ba 100644 --- a/pkg/buffer/nack.go +++ b/pkg/buffer/nack.go @@ -39,15 +39,20 @@ func (n *nackQueue) push(extSN uint32) { if i < len(n.nacks) && n.nacks[i].sn == extSN { return } - n.nacks = append(n.nacks, nack{}) - copy(n.nacks[i+1:], n.nacks[i:]) - n.nacks[i] = nack{ + + nck := nack{ sn: extSN, nacked: 0, } + if i == len(n.nacks) { + n.nacks = append(n.nacks, nck) + } else { + n.nacks = append(n.nacks[:i+1], n.nacks[i:]...) + n.nacks[i] = nck + } - if len(n.nacks) > maxNackCache { - n.nacks = n.nacks[1:] + if len(n.nacks) >= maxNackCache { + copy(n.nacks, n.nacks[1:]) } } diff --git a/pkg/relay/README.md b/pkg/relay/README.md new file mode 100644 index 000000000..87f4ade34 --- /dev/null +++ b/pkg/relay/README.md @@ -0,0 +1,81 @@ +# Relay + +`ion-sfu` supports relaying tracks to other ion-SFUs or other services using the ORTC API. + +Using this api allows to quickly send the stream to other services by signaling a single request, after that all the following negotiations are handled internally. + +## API + +### Relay Peer + +The relay peer shares common methods with the Webrtc PeerConnection, so it should be straight forward to use. To create a new relay peer follow below example +: +```go + // Meta holds all the related information of the peer you want to relay. + meta := PeerMeta{ + PeerID : "super-villain-1", + SessionID : "world-domination", + } + // config will hold pion/webrtc related structs required for the connection. + // you should fill according your requirements or leave the defaults. + config := &PeerConfig{} + peer, err := NewPeer(meta, config) + handleErr(err) + + // Now before working with the peer you need to signal the peer to + // your remote sever, the signaling can be whatever method you want (gRPC, RESt, pubsub, etc..) + signalFunc= func (meta PeerMeta, signal []byte) ([]byte, error){ + if meta.session== "world-domination"{ + return RelayToLegionOfDoom(meta, signal) + } + return nil, errors.New("not supported") + } + + // The remote peer should create a new Relay Peer with the metadata and call Answer. + if err:= peer.Offer(signalFunc); err!=nil{ + handleErr(err) + } + + // If there are no errors, relay peer offer some convenience methods to communicate with + // Relayed peer. + + // Emit will fire and forget to the request event + peer.Emit("evil-plan-1", data) + // Request will wait for a remote answer, use a time cancelled + // context to not block forever if peer does not answer + ans,err:= peer.Request(ctx, "evil-plan-2", data) + // To listen to remote event just attach the callback to peer + peer.OnRequest( func (event string, msg Message){ + // to access to request data + msg.Paylod() + // to reply the request + msg.Reply(...) + }) + + // The Relay Peer also has some convenience callbacks to manage the peer lifespan. + + // Peer OnClose is called when the remote peer connection is closed, or the Close method is called + peer.OnClose(func()) + // Peer OnReady is called when the relay peer is ready to start negotiating tracks, data channels and request + // is highly recommended to attach all the initialization logic to this callback + peer.OnReady(func()) + + // To add or receive tracks or data channels the API is similar to webrtc Peer Connection, just listen + // to the required callbacks + peer.OnDataChannel(f func(channel *webrtc.DataChannel)) + peer.OnTrack(f func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver)) + // Make sure to call below methods after the OnReady callback fired. + peer.CreateDataChannel(label string) + peer.AddTrack(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackRemote, +localTrack webrtc.TrackLocal) (*webrtc.RTPSender, error) +``` + +### ION-SFU integration + +ION-SFU offers some convenience methods for relaying peers in a very simple way. + +To relay a peer just call `Peer.Publisher().Relay(...)` then signal the data to the remote SFU and ingest the data using: + +`session.AddRelayPeer(peerID string, signalData []byte) ([]byte, error)` + +set the []byte response from the method as the response of the signaling. And is ready, everytime a peer joins to the new SFU will negotiate the relayed stream. diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go index 52ad73408..eafc15d7d 100644 --- a/pkg/relay/relay.go +++ b/pkg/relay/relay.go @@ -8,6 +8,7 @@ import ( "math/rand" "strings" "sync" + "sync/atomic" "time" "github.com/go-logr/logr" @@ -59,6 +60,15 @@ type PeerMeta struct { SessionID string `json:"sessionId"` } +type Options struct { + // RelayMiddlewareDC if set to true middleware data channels will be created and forwarded + // to the relayed peer + RelayMiddlewareDC bool + // RelaySessionDC if set to true fanout data channels will be created and forwarded to the + // relayed peer + RelaySessionDC bool +} + type Peer struct { mu sync.Mutex rmu sync.Mutex @@ -80,10 +90,11 @@ type Peer struct { gatherer *webrtc.ICEGatherer dcIndex uint16 - onReady func() - onRequest func(event string, message Message) - onDataChannel func(channel *webrtc.DataChannel) - onTrack func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta) + onReady atomic.Value // func() + onClose atomic.Value // func() + onRequest atomic.Value // func(event string, message Message) + onDataChannel atomic.Value // func(channel *webrtc.DataChannel) + onTrack atomic.Value // func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta) } func NewPeer(meta PeerMeta, conf *PeerConfig) (*Peer, error) { @@ -123,20 +134,19 @@ func NewPeer(meta PeerMeta, conf *PeerConfig) (*Peer, error) { } sctp.OnDataChannel(func(channel *webrtc.DataChannel) { - p.mu.Lock() - defer p.mu.Unlock() if channel.Label() == signalerLabel { p.signalingDC = channel channel.OnMessage(p.handleRequest) - p.ready = true - if p.onReady != nil { - p.onReady() - } + channel.OnOpen(func() { + if f := p.onReady.Load(); f != nil { + f.(func())() + } + }) return } - if p.onDataChannel != nil { - p.onDataChannel(channel) + if f := p.onDataChannel.Load(); f != nil { + f.(func(dataChannel *webrtc.DataChannel))(channel) } }) @@ -151,6 +161,10 @@ func NewPeer(meta PeerMeta, conf *PeerConfig) (*Peer, error) { return p, nil } +func (p *Peer) ID() string { + return p.meta.PeerID +} + // Offer is used for establish the connection of the local relay Peer // with the remote relay Peer. // @@ -212,17 +226,19 @@ func (p *Peer) Offer(signalFn func(meta PeerMeta, signal []byte) ([]byte, error) } p.signalingDC.OnOpen(func() { - p.mu.Lock() - p.ready = true - p.mu.Unlock() - if p.onReady != nil { - p.onReady() + if f := p.onReady.Load(); f != nil { + f.(func())() } }) p.signalingDC.OnMessage(p.handleRequest) return nil } +// OnClose sets a callback that is called when relay Peer is closed. +func (p *Peer) OnClose(fn func()) { + p.onClose.Store(fn) +} + // Answer answers the remote Peer signal signalRequest func (p *Peer) Answer(request []byte) ([]byte, error) { if p.gatherer.State() != webrtc.ICEGathererStateNew { @@ -287,32 +303,24 @@ func (p *Peer) LocalTracks() []webrtc.TrackLocal { // OnReady calls the callback when relay Peer is ready to start sending/receiving and creating DC func (p *Peer) OnReady(f func()) { - p.mu.Lock() - p.onReady = f - p.mu.Unlock() + p.onReady.Store(f) } // OnRequest calls the callback when Peer gets a request message from remote Peer func (p *Peer) OnRequest(f func(event string, msg Message)) { - p.mu.Lock() - p.onRequest = f - p.mu.Unlock() + p.onRequest.Store(f) } // OnDataChannel sets an event handler which is invoked when a data // channel message arrives from a remote Peer. func (p *Peer) OnDataChannel(f func(channel *webrtc.DataChannel)) { - p.mu.Lock() - p.onDataChannel = f - p.mu.Unlock() + p.onDataChannel.Store(f) } // OnTrack sets an event handler which is called when remote track // arrives from a remote Peer func (p *Peer) OnTrack(f func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta)) { - p.mu.Lock() - p.onTrack = f - p.mu.Unlock() + p.onTrack.Store(f) } // Close ends the relay Peer @@ -327,6 +335,10 @@ func (p *Peer) Close() error { closeErrs = append(closeErrs, p.sctp.Stop(), p.dtls.Stop(), p.ice.Stop()) + if f := p.onClose.Load(); f != nil { + f.(func())() + } + return joinErrs(closeErrs...) } @@ -371,6 +383,7 @@ func (p *Peer) start(s *signal) error { return err } } + p.ready = true return nil } @@ -402,9 +415,18 @@ func (p *Peer) receive(s *signal) error { }, }, }}); err != nil { + return err } - if p.onTrack != nil { - p.onTrack(recv.Track(), recv, s.TrackMeta) + + recv.SetRTPParameters(webrtc.RTPParameters{ + HeaderExtensions: nil, + Codecs: []webrtc.RTPCodecParameters{*s.TrackMeta.CodecParameters}, + }) + + track := recv.Track() + + if f := p.onTrack.Load(); f != nil { + f.(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *TrackMeta))(track, recv, s.TrackMeta) } p.receivers = append(p.receivers, recv) @@ -436,7 +458,7 @@ func (p *Peer) AddTrack(receiver *webrtc.RTPReceiver, remoteTrack *webrtc.TrackR } s.Encodings = &webrtc.RTPCodingParameters{ - SSRC: webrtc.SSRC(p.rand.Uint32()), + SSRC: sdr.GetParameters().Encodings[0].SSRC, PayloadType: remoteTrack.PayloadType(), } pld, err := json.Marshal(&s) @@ -484,10 +506,7 @@ func (p *Peer) Emit(event string, data []byte) error { return err } - if err = p.signalingDC.Send(msg); err != nil { - return err - } - return nil + return p.signalingDC.Send(msg) } func (p *Peer) Request(ctx context.Context, event string, data []byte) ([]byte, error) { @@ -565,16 +584,14 @@ func (p *Peer) handleRequest(msg webrtc.DataChannelMessage) { } if mr.Event != signalerRequestEvent { - p.mu.Lock() - if p.onRequest != nil { - p.onRequest(mr.Event, Message{ + if f := p.onRequest.Load(); f != nil { + f.(func(string, Message))(mr.Event, Message{ p: p, event: mr.Event, id: mr.ID, msg: mr.Payload, }) } - p.mu.Unlock() return } diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go index a50217fcb..eaded16c3 100644 --- a/pkg/sfu/downtrack.go +++ b/pkg/sfu/downtrack.go @@ -13,7 +13,7 @@ import ( "github.com/pion/webrtc/v3" ) -// DownTrackType determines the type of a track +// DownTrackType determines the type of track type DownTrackType int const ( @@ -25,7 +25,6 @@ const ( // to SFU Subscriber, the track handle the packets for simple, simulcast // and SVC Publisher. type DownTrack struct { - mu sync.RWMutex id string peerID string bound atomicBool @@ -37,10 +36,10 @@ type DownTrack struct { sequencer *sequencer trackType DownTrackType bufferFactory *buffer.Factory - payload []byte + payload *[]byte - currentSpatialLayer int - targetSpatialLayer int + currentSpatialLayer int32 + targetSpatialLayer int32 temporalLayer int32 enabled atomicBool @@ -52,8 +51,8 @@ type DownTrack struct { lastTS uint32 simulcast simulcastTrackHelpers - maxSpatialLayer int64 - maxTemporalLayer int64 + maxSpatialLayer int32 + maxTemporalLayer int32 codec webrtc.RTPCodecCapability receiver Receiver @@ -64,10 +63,9 @@ type DownTrack struct { closeOnce sync.Once // Report helpers - octetCount uint32 - packetCount uint32 - maxPacketTs uint32 - lastPacketMs int64 + octetCount uint32 + packetCount uint32 + maxPacketTs uint32 } // NewDownTrack returns a DownTrack. @@ -116,7 +114,6 @@ func (d *DownTrack) Bind(t webrtc.TrackLocalContext) (webrtc.RTPCodecParameters, // because a track has been stopped. func (d *DownTrack) Unbind(_ webrtc.TrackLocalContext) error { d.bound.set(false) - d.receiver.DeleteDownTrack(d.CurrentSpatialLayer(), d.id) return nil } @@ -155,7 +152,7 @@ func (d *DownTrack) SetTransceiver(transceiver *webrtc.RTPTransceiver) { } // WriteRTP writes a RTP Packet to the DownTrack -func (d *DownTrack) WriteRTP(p *buffer.ExtPacket) error { +func (d *DownTrack) WriteRTP(p *buffer.ExtPacket, layer int) error { if !d.enabled.get() || !d.bound.get() { return nil } @@ -163,7 +160,7 @@ func (d *DownTrack) WriteRTP(p *buffer.ExtPacket) error { case SimpleDownTrack: return d.writeSimpleRTP(p) case SimulcastDownTrack: - return d.writeSimulcastRTP(p) + return d.writeSimulcastRTP(p, layer) } return nil } @@ -196,33 +193,27 @@ func (d *DownTrack) Close() { }) } -func (d *DownTrack) SetInitialLayers(spatialLayer, temporalLayer int64) { - d.mu.Lock() - defer d.mu.Unlock() - d.currentSpatialLayer = int(spatialLayer) - d.targetSpatialLayer = d.currentSpatialLayer - atomic.StoreInt32(&d.temporalLayer, int32(temporalLayer<<16)|int32(temporalLayer)) +func (d *DownTrack) SetInitialLayers(spatialLayer, temporalLayer int32) { + atomic.StoreInt32(&d.currentSpatialLayer, spatialLayer) + atomic.StoreInt32(&d.targetSpatialLayer, spatialLayer) + atomic.StoreInt32(&d.temporalLayer, temporalLayer<<16|temporalLayer) } func (d *DownTrack) CurrentSpatialLayer() int { - d.mu.RLock() - defer d.mu.RUnlock() - return d.currentSpatialLayer + return int(atomic.LoadInt32(&d.currentSpatialLayer)) } -func (d *DownTrack) SwitchSpatialLayer(targetLayer int64, setAsMax bool) error { +func (d *DownTrack) SwitchSpatialLayer(targetLayer int32, setAsMax bool) error { if d.trackType == SimulcastDownTrack { - d.mu.Lock() - defer d.mu.Unlock() // Don't switch until previous switch is done or canceled - if d.currentSpatialLayer != d.targetSpatialLayer || - d.currentSpatialLayer == int(targetLayer) { + csl := atomic.LoadInt32(&d.currentSpatialLayer) + if csl != atomic.LoadInt32(&d.targetSpatialLayer) || csl == targetLayer { return ErrSpatialLayerBusy } if err := d.receiver.SwitchDownTrack(d, int(targetLayer)); err == nil { - d.targetSpatialLayer = int(targetLayer) + atomic.StoreInt32(&d.targetSpatialLayer, targetLayer) if setAsMax { - atomic.StoreInt64(&d.maxSpatialLayer, targetLayer) + atomic.StoreInt32(&d.maxSpatialLayer, targetLayer) } } return nil @@ -230,18 +221,14 @@ func (d *DownTrack) SwitchSpatialLayer(targetLayer int64, setAsMax bool) error { return ErrSpatialNotSupported } -func (d *DownTrack) SwitchSpatialLayerDone() { - d.mu.Lock() - d.currentSpatialLayer = d.targetSpatialLayer - d.mu.Unlock() +func (d *DownTrack) SwitchSpatialLayerDone(layer int32) { + atomic.StoreInt32(&d.currentSpatialLayer, layer) } func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) (int64, error) { if d.trackType == SimulcastDownTrack { - d.mu.RLock() currentLayer := uint16(d.currentSpatialLayer) - d.mu.RUnlock() - maxLayer := uint16(atomic.LoadInt64(&d.maxSpatialLayer)) + maxLayer := uint16(atomic.LoadInt32(&d.maxSpatialLayer)) var maxFound uint16 = 0 layerFound := false @@ -265,7 +252,7 @@ func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) (int64, error) targetLayer = minFound } if currentLayer != targetLayer { - if err := d.SwitchSpatialLayer(int64(targetLayer), false); err != nil { + if err := d.SwitchSpatialLayer(int32(targetLayer), false); err != nil { return int64(targetLayer), err } } @@ -274,7 +261,7 @@ func (d *DownTrack) UptrackLayersChange(availableLayers []uint16) (int64, error) return -1, fmt.Errorf("downtrack %s does not support simulcast", d.id) } -func (d *DownTrack) SwitchTemporalLayer(targetLayer int64, setAsMax bool) { +func (d *DownTrack) SwitchTemporalLayer(targetLayer int32, setAsMax bool) { if d.trackType == SimulcastDownTrack { layer := atomic.LoadInt32(&d.temporalLayer) currentLayer := uint16(layer) @@ -284,9 +271,9 @@ func (d *DownTrack) SwitchTemporalLayer(targetLayer int64, setAsMax bool) { if currentLayer != currentTargetLayer { return } - atomic.StoreInt32(&d.temporalLayer, int32(targetLayer<<16)|int32(currentLayer)) + atomic.StoreInt32(&d.temporalLayer, targetLayer<<16|int32(currentLayer)) if setAsMax { - atomic.StoreInt64(&d.maxTemporalLayer, targetLayer) + atomic.StoreInt32(&d.maxTemporalLayer, targetLayer) } } } @@ -325,16 +312,24 @@ func (d *DownTrack) CreateSenderReport() *rtcp.SenderReport { if !d.bound.get() { return nil } - now := time.Now().UnixNano() - nowNTP := timeToNtp(now) - lastPktMs := atomic.LoadInt64(&d.lastPacketMs) - maxPktTs := atomic.LoadUint32(&d.lastTS) - diffTs := uint32((now/1e6)-lastPktMs) * d.codec.ClockRate / 1000 + srRTP, srNTP := d.receiver.GetSenderReportTime(int(atomic.LoadInt32(&d.currentSpatialLayer))) + if srRTP == 0 { + return nil + } + + now := time.Now() + nowNTP := toNtpTime(now) + + diff := (uint64(now.Sub(ntpTime(srNTP).Time())) * uint64(d.codec.ClockRate)) / uint64(time.Second) + if diff < 0 { + diff = 0 + } octets, packets := d.getSRStats() + return &rtcp.SenderReport{ SSRC: d.ssrc, - NTPTime: nowNTP, - RTPTime: maxPktTs + diffTs, + NTPTime: uint64(nowNTP), + RTPTime: srRTP + uint32(diff), PacketCount: packets, OctetCount: octets, } @@ -356,9 +351,10 @@ func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { } } - d.snOffset = extPkt.Packet.SequenceNumber - d.lastSN - 1 - d.tsOffset = extPkt.Packet.Timestamp - d.lastTS - 1 - + if d.lastSN != 0 { + d.snOffset = extPkt.Packet.SequenceNumber - d.lastSN - 1 + d.tsOffset = extPkt.Packet.Timestamp - d.lastTS - 1 + } atomic.StoreUint32(&d.lastSSRC, extPkt.Packet.SSRC) d.reSync.set(false) } @@ -370,10 +366,9 @@ func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { if d.sequencer != nil { d.sequencer.push(extPkt.Packet.SequenceNumber, newSN, newTS, 0, extPkt.Head) } - if (newSN-d.lastSN)&0x8000 == 0 || d.lastSN == 0 { + if extPkt.Head { d.lastSN = newSN - atomic.StoreInt64(&d.lastPacketMs, extPkt.Arrival/1e6) - atomic.StoreUint32(&d.lastTS, newTS) + d.lastTS = newTS } hdr := extPkt.Packet.Header hdr.PayloadType = d.payloadType @@ -382,16 +377,19 @@ func (d *DownTrack) writeSimpleRTP(extPkt *buffer.ExtPacket) error { hdr.SSRC = d.ssrc _, err := d.writeStream.WriteRTP(&hdr, extPkt.Packet.Payload) - if err != nil { - Logger.Error(err, "Write packet err") - } return err } -func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket) error { +func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket, layer int) error { // Check if packet SSRC is different from before // if true, the video source changed reSync := d.reSync.get() + csl := d.CurrentSpatialLayer() + + if csl != layer { + return nil + } + lastSSRC := atomic.LoadUint32(&d.lastSSRC) if lastSSRC != extPkt.Packet.SSRC || reSync { // Wait for a keyframe to sync new source @@ -449,20 +447,16 @@ func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket) error { if d.simulcast.temporalSupported { if d.mime == "video/vp8" { drop := false - if picID, tlz0Idx, drop = setVP8TemporalLayer(extPkt, d); drop { + if payload, picID, tlz0Idx, drop = setVP8TemporalLayer(extPkt, d); drop { // Pkt not in temporal getLayer update sequence number offset to avoid gaps d.snOffset++ return nil } - payload = d.payload } } if d.sequencer != nil { - d.mu.RLock() - layer := d.currentSpatialLayer - d.mu.RUnlock() - if meta := d.sequencer.push(extPkt.Packet.SequenceNumber, newSN, newTS, uint8(layer), extPkt.Head); meta != nil && + if meta := d.sequencer.push(extPkt.Packet.SequenceNumber, newSN, newTS, uint8(csl), extPkt.Head); meta != nil && d.simulcast.temporalSupported && d.mime == "video/vp8" { meta.setVP8PayloadMeta(tlz0Idx, picID) } @@ -474,8 +468,6 @@ func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket) error { if extPkt.Head { d.lastSN = newSN d.lastTS = newTS - atomic.StoreInt64(&d.lastPacketMs, time.Now().UnixNano()/1e6) - atomic.StoreUint32(&d.lastTS, newTS) } // Update base d.simulcast.lTSCalc = extPkt.Arrival @@ -487,10 +479,6 @@ func (d *DownTrack) writeSimulcastRTP(extPkt *buffer.ExtPacket) error { hdr.PayloadType = d.payloadType _, err := d.writeStream.WriteRTP(&hdr, payload) - if err != nil { - Logger.Error(err, "Write packet err") - } - return err } @@ -563,14 +551,12 @@ func (d *DownTrack) handleRTCP(bytes []byte) { } func (d *DownTrack) handleLayerChange(maxRatePacketLoss uint8, expectedMinBitrate uint64) { - d.mu.RLock() - currentSpatialLayer := int64(d.currentSpatialLayer) - targetSpatialLayer := int64(d.targetSpatialLayer) - d.mu.RUnlock() + currentSpatialLayer := atomic.LoadInt32(&d.currentSpatialLayer) + targetSpatialLayer := atomic.LoadInt32(&d.targetSpatialLayer) temporalLayer := atomic.LoadInt32(&d.temporalLayer) - currentTemporalLayer := int64(temporalLayer & 0x0f) - targetTemporalLayer := int64(temporalLayer >> 16) + currentTemporalLayer := temporalLayer & 0x0f + targetTemporalLayer := temporalLayer >> 16 if targetSpatialLayer == currentSpatialLayer && currentTemporalLayer == targetTemporalLayer { if time.Now().After(d.simulcast.switchDelay) { @@ -580,12 +566,12 @@ func (d *DownTrack) handleLayerChange(maxRatePacketLoss uint8, expectedMinBitrat mctl := mtl[currentSpatialLayer] if maxRatePacketLoss <= 5 { - if currentTemporalLayer < mctl && currentTemporalLayer+1 <= atomic.LoadInt64(&d.maxTemporalLayer) && + if currentTemporalLayer < mctl && currentTemporalLayer+1 <= atomic.LoadInt32(&d.maxTemporalLayer) && expectedMinBitrate >= 3*cbr/4 { d.SwitchTemporalLayer(currentTemporalLayer+1, false) d.simulcast.switchDelay = time.Now().Add(3 * time.Second) } - if currentTemporalLayer >= mctl && expectedMinBitrate >= 3*cbr/2 && currentSpatialLayer+1 <= atomic.LoadInt64(&d.maxSpatialLayer) && + if currentTemporalLayer >= mctl && expectedMinBitrate >= 3*cbr/2 && currentSpatialLayer+1 <= atomic.LoadInt32(&d.maxSpatialLayer) && currentSpatialLayer+1 <= 2 { if err := d.SwitchSpatialLayer(currentSpatialLayer+1, false); err == nil { d.SwitchTemporalLayer(0, false) diff --git a/pkg/sfu/helpers.go b/pkg/sfu/helpers.go index 384c74785..40c1f0579 100644 --- a/pkg/sfu/helpers.go +++ b/pkg/sfu/helpers.go @@ -4,16 +4,18 @@ import ( "encoding/binary" "strings" "sync/atomic" + "time" "github.com/pion/ion-sfu/pkg/buffer" "github.com/pion/webrtc/v3" ) -const ( - ntpEpoch = 2208988800 +var ( + ntpEpoch = time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC) ) type atomicBool int32 +type ntpTime uint64 func (a *atomicBool) set(value bool) (swapped bool) { if value { @@ -29,10 +31,10 @@ func (a *atomicBool) get() bool { // setVp8TemporalLayer is a helper to detect and modify accordingly the vp8 payload to reflect // temporal changes in the SFU. // VP8 temporal layers implemented according https://tools.ietf.org/html/rfc7741 -func setVP8TemporalLayer(p *buffer.ExtPacket, d *DownTrack) (picID uint16, tlz0Idx uint8, drop bool) { +func setVP8TemporalLayer(p *buffer.ExtPacket, d *DownTrack) (buf []byte, picID uint16, tlz0Idx uint8, drop bool) { pkt, ok := p.Payload.(buffer.VP8) if !ok { - return 0, 0, false + return } layer := atomic.LoadInt32(&d.temporalLayer) @@ -48,8 +50,9 @@ func setVP8TemporalLayer(p *buffer.ExtPacket, d *DownTrack) (picID uint16, tlz0I return } - d.payload = d.payload[:len(p.Packet.Payload)] - copy(d.payload, p.Packet.Payload) + buf = *d.payload + buf = buf[:len(p.Packet.Payload)] + copy(buf, p.Packet.Payload) picID = pkt.PictureID - d.simulcast.refPicID + d.simulcast.pRefPicID + 1 tlz0Idx = pkt.TL0PICIDX - d.simulcast.refTlZIdx + d.simulcast.pRefTlZIdx + 1 @@ -59,7 +62,7 @@ func setVP8TemporalLayer(p *buffer.ExtPacket, d *DownTrack) (picID uint16, tlz0I d.simulcast.lTlZIdx = tlz0Idx } - modifyVP8TemporalPayload(d.payload, pkt.PicIDIdx, pkt.TlzIdx, picID, tlz0Idx, pkt.MBit) + modifyVP8TemporalPayload(buf, pkt.PicIDIdx, pkt.TlzIdx, picID, tlz0Idx, pkt.MBit) return } @@ -75,12 +78,6 @@ func modifyVP8TemporalPayload(payload []byte, picIDIdx, tlz0Idx int, picID uint1 payload[tlz0Idx] = tlz0ID } -func timeToNtp(ns int64) uint64 { - seconds := uint64(ns/1e9 + ntpEpoch) - fraction := uint64(((ns % 1e9) << 32) / 1e9) - return seconds<<32 | fraction -} - // Do a fuzzy find for a codec in the list of codecs // Used for lookup up a codec in an existing list to find a match func codecParametersFuzzySearch(needle webrtc.RTPCodecParameters, haystack []webrtc.RTPCodecParameters) (webrtc.RTPCodecParameters, error) { @@ -118,3 +115,28 @@ func fastForwardTimestampAmount(newestTimestamp uint32, referenceTimestamp uint3 } return newestTimestamp - referenceTimestamp } + +func (t ntpTime) Duration() time.Duration { + sec := (t >> 32) * 1e9 + frac := (t & 0xffffffff) * 1e9 + nsec := frac >> 32 + if uint32(frac) >= 0x80000000 { + nsec++ + } + return time.Duration(sec + nsec) +} + +func (t ntpTime) Time() time.Time { + return ntpEpoch.Add(t.Duration()) +} + +func toNtpTime(t time.Time) ntpTime { + nsec := uint64(t.Sub(ntpEpoch)) + sec := nsec / 1e9 + nsec = (nsec - sec*1e9) << 32 + frac := nsec / 1e9 + if nsec%1e9 >= 1e9/2 { + frac++ + } + return ntpTime(sec<<32 | frac) +} diff --git a/pkg/sfu/helpers_test.go b/pkg/sfu/helpers_test.go index 1a9d63795..f21bc84ac 100644 --- a/pkg/sfu/helpers_test.go +++ b/pkg/sfu/helpers_test.go @@ -7,7 +7,7 @@ import ( func Test_timeToNtp(t *testing.T) { type args struct { - ns int64 + ns time.Time } tests := []struct { name string @@ -17,15 +17,15 @@ func Test_timeToNtp(t *testing.T) { { name: "Must return correct NTP time", args: args{ - ns: time.Unix(1602391458, 1234).UnixNano(), + ns: time.Unix(1602391458, 1234), }, - wantNTP: 16369753560730047667, + wantNTP: 16369753560730047668, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - gotNTP := timeToNtp(tt.args.ns) + gotNTP := uint64(toNtpTime(tt.args.ns)) if gotNTP != tt.wantNTP { t.Errorf("timeToNtp() gotFraction = %v, want %v", gotNTP, tt.wantNTP) } diff --git a/pkg/sfu/publisher.go b/pkg/sfu/publisher.go index 020acecbd..e34783ae5 100644 --- a/pkg/sfu/publisher.go +++ b/pkg/sfu/publisher.go @@ -7,13 +7,16 @@ import ( "sync/atomic" "time" + "github.com/pion/ion-sfu/pkg/buffer" + "github.com/pion/transport/packetio" + "github.com/pion/ion-sfu/pkg/relay" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) type Publisher struct { - mu sync.Mutex + mu sync.RWMutex id string pc *webrtc.PeerConnection cfg *WebRTCTransportConfig @@ -21,7 +24,8 @@ type Publisher struct { router Router session Session tracks []PublisherTrack - relayPeer []*relay.Peer + relayed atomicBool + relayPeers []*relayPeer candidates []webrtc.ICECandidateInit onICEConnectionStateChangeHandler atomic.Value // func(webrtc.ICEConnectionState) @@ -30,6 +34,13 @@ type Publisher struct { closeOnce sync.Once } +type relayPeer struct { + peer *relay.Peer + dcs []*webrtc.DataChannel + withSRReports bool + relayFanOutDataChannels bool +} + type PublisherTrack struct { Track *webrtc.TrackRemote Receiver Receiver @@ -59,7 +70,7 @@ func NewPublisher(id string, session Session, cfg *WebRTCTransportConfig) (*Publ id: id, pc: pc, cfg: cfg, - router: newRouter(id, pc, session, cfg), + router: newRouter(id, session, cfg), session: session, } @@ -78,8 +89,8 @@ func NewPublisher(id string, session Session, cfg *WebRTCTransportConfig) (*Publ p.mu.Lock() publisherTrack := PublisherTrack{track, r, true} p.tracks = append(p.tracks, publisherTrack) - for _, rp := range p.relayPeer { - if err = p.createRelayTrack(track, r, rp); err != nil { + for _, rp := range p.relayPeers { + if err = p.createRelayTrack(track, r, rp.peer); err != nil { Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) } } @@ -117,6 +128,8 @@ func NewPublisher(id string, session Session, cfg *WebRTCTransportConfig) (*Publ } }) + p.router.SetRTCPWriter(p.pc.WriteRTCP) + return p, nil } @@ -150,10 +163,10 @@ func (p *Publisher) GetRouter() Router { // Close peer func (p *Publisher) Close() { p.closeOnce.Do(func() { - if len(p.relayPeer) > 0 { + if len(p.relayPeers) > 0 { p.mu.Lock() - for _, rp := range p.relayPeer { - if err := rp.Close(); err != nil { + for _, rp := range p.relayPeers { + if err := rp.peer.Close(); err != nil { Logger.Error(err, "Closing relay peer transport.") } } @@ -187,23 +200,54 @@ func (p *Publisher) PeerConnection() *webrtc.PeerConnection { return p.pc } -func (p *Publisher) Relay(ice []webrtc.ICEServer) (*relay.Peer, error) { +// Relay will relay all current and future tracks from current Publisher +func (p *Publisher) Relay(signalFn func(meta relay.PeerMeta, signal []byte) ([]byte, error), + options ...func(r *relayPeer)) (*relay.Peer, error) { + lrp := &relayPeer{} + for _, o := range options { + o(lrp) + } + rp, err := relay.NewPeer(relay.PeerMeta{ PeerID: p.id, SessionID: p.session.ID(), }, &relay.PeerConfig{ SettingEngine: p.cfg.Setting, - ICEServers: ice, + ICEServers: p.cfg.Configuration.ICEServers, Logger: Logger, }) if err != nil { return nil, fmt.Errorf("relay: %w", err) } + lrp.peer = rp rp.OnReady(func() { - for _, lbl := range p.session.GetDataChannelLabels() { - if _, err := rp.CreateDataChannel(lbl); err != nil { - Logger.V(1).Error(err, "Creating data channels.", "peer_id", p.id) + peer := p.session.GetPeer(p.id) + + p.relayed.set(true) + if lrp.relayFanOutDataChannels { + for _, lbl := range p.session.GetFanOutDataChannelLabels() { + lbl := lbl + dc, err := rp.CreateDataChannel(lbl) + if err != nil { + Logger.V(1).Error(err, "Creating data channels.", "peer_id", p.id) + } + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + if peer == nil || peer.Subscriber() == nil { + return + } + if sdc := peer.Subscriber().DataChannel(lbl); sdc != nil { + if msg.IsString { + if err = sdc.SendText(string(msg.Data)); err != nil { + Logger.Error(err, "Sending dc message err") + } + } else { + if err = sdc.Send(msg.Data); err != nil { + Logger.Error(err, "Sending dc message err") + } + } + } + }) } } @@ -217,13 +261,26 @@ func (p *Publisher) Relay(ice []webrtc.ICEServer) (*relay.Peer, error) { Logger.V(1).Error(err, "Creating relay track.", "peer_id", p.id) } } - p.relayPeer = append(p.relayPeer, rp) + p.relayPeers = append(p.relayPeers, lrp) + p.mu.Unlock() + + if lrp.withSRReports { + go p.relayReports(rp) + } + }) + + rp.OnDataChannel(func(channel *webrtc.DataChannel) { + if !lrp.relayFanOutDataChannels { + return + } + p.mu.Lock() + lrp.dcs = append(lrp.dcs, channel) p.mu.Unlock() - go p.relayReports(rp) + p.session.AddDatachannel("", channel) }) - if err = rp.Offer(p.cfg.Relay); err != nil { + if err = rp.Offer(signalFn); err != nil { return nil, fmt.Errorf("relay: %w", err) } @@ -241,9 +298,54 @@ func (p *Publisher) PublisherTracks() []PublisherTrack { return tracks } +// AddRelayFanOutDataChannel adds fan out data channel to relayed peers +func (p *Publisher) AddRelayFanOutDataChannel(label string) { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, rp := range p.relayPeers { + for _, dc := range rp.dcs { + if dc.Label() == label { + continue + } + } + + dc, err := rp.peer.CreateDataChannel(label) + if err != nil { + Logger.V(1).Error(err, "Creating data channels.", "peer_id", p.id) + } + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + p.session.FanOutMessage("", label, msg) + }) + } +} + +// GetRelayedDataChannels Returns a slice of data channels that belongs to relayed +// peers +func (p *Publisher) GetRelayedDataChannels(label string) []*webrtc.DataChannel { + p.mu.RLock() + defer p.mu.RUnlock() + + dcs := make([]*webrtc.DataChannel, 0, len(p.relayPeers)) + for _, rp := range p.relayPeers { + for _, dc := range rp.dcs { + if dc.Label() == label { + dcs = append(dcs, dc) + break + } + } + } + return dcs +} + +// Relayed returns true if the publisher has been relayed at least once +func (p *Publisher) Relayed() bool { + return p.relayed.get() +} + func (p *Publisher) Tracks() []*webrtc.TrackRemote { - p.mu.Lock() - defer p.mu.Unlock() + p.mu.RLock() + defer p.mu.RUnlock() tracks := make([]*webrtc.TrackRemote, len(p.tracks)) for idx, track := range p.tracks { @@ -281,6 +383,32 @@ func (p *Publisher) createRelayTrack(track *webrtc.TrackRemote, receiver Receive return fmt.Errorf("relay: %w", err) } + p.cfg.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, + uint32(sdr.GetParameters().Encodings[0].SSRC)).(*buffer.RTCPReader).OnPacket(func(bytes []byte) { + pkts, err := rtcp.Unmarshal(bytes) + if err != nil { + Logger.V(1).Error(err, "Unmarshal rtcp reports", "peer_id", p.id) + return + } + var rpkts []rtcp.Packet + for _, pkt := range pkts { + switch pk := pkt.(type) { + case *rtcp.PictureLossIndication: + rpkts = append(rpkts, &rtcp.PictureLossIndication{ + SenderSSRC: pk.MediaSSRC, + MediaSSRC: uint32(track.SSRC()), + }) + } + } + + if len(rpkts) > 0 { + if err := p.pc.WriteRTCP(rpkts); err != nil { + Logger.V(1).Error(err, "Sending rtcp relay reports", "peer_id", p.id) + } + } + + }) + downTrack.OnCloseHandler(func() { if err = sdr.Stop(); err != nil { Logger.V(1).Error(err, "Stopping relay sender.", "peer_id", p.id) @@ -301,7 +429,9 @@ func (p *Publisher) relayReports(rp *relay.Peer) { if !dt.bound.get() { continue } - r = append(r, dt.CreateSenderReport()) + if sr := dt.CreateSenderReport(); sr != nil { + r = append(r, sr) + } } } diff --git a/pkg/sfu/receiver.go b/pkg/sfu/receiver.go index d9db6f53e..078d6096d 100644 --- a/pkg/sfu/receiver.go +++ b/pkg/sfu/receiver.go @@ -4,16 +4,16 @@ import ( "io" "math/rand" "sync" + "sync/atomic" "time" - "github.com/rs/zerolog/log" - "github.com/gammazero/workerpool" "github.com/pion/ion-sfu/pkg/buffer" "github.com/pion/ion-sfu/pkg/stats" "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/webrtc/v3" + "github.com/rs/zerolog/log" ) // Receiver defines a interface for a track receivers @@ -23,22 +23,23 @@ type Receiver interface { Codec() webrtc.RTPCodecParameters Kind() webrtc.RTPCodecType SSRC(layer int) uint32 + SetTrackMeta(trackID, streamID string) AddUpTrack(track *webrtc.TrackRemote, buffer *buffer.Buffer, bestQualityFirst bool) AddDownTrack(track *DownTrack, bestQualityFirst bool) SwitchDownTrack(track *DownTrack, layer int) error GetBitrate() [3]uint64 - GetMaxTemporalLayer() [3]int64 + GetMaxTemporalLayer() [3]int32 RetransmitPackets(track *DownTrack, packets []packetMeta) error DeleteDownTrack(layer int, id string) OnCloseHandler(fn func()) SendRTCP(p []rtcp.Packet) SetRTCPCh(ch chan []rtcp.Packet) + GetSenderReportTime(layer int) (rtpTS uint32, ntpTS uint64) } // WebRTCReceiver receives a video track type WebRTCReceiver struct { sync.Mutex - rtcpMu sync.Mutex closeOnce sync.Once peerID string @@ -52,11 +53,12 @@ type WebRTCReceiver struct { receiver *webrtc.RTPReceiver codec webrtc.RTPCodecParameters rtcpCh chan []rtcp.Packet - locks [3]sync.Mutex buffers [3]*buffer.Buffer upTracks [3]*webrtc.TrackRemote stats [3]*stats.Stream - downTracks [3][]*DownTrack + available [3]atomicBool + downTracks [3]atomic.Value // []*DownTrack + pending [3]atomicBool pendingTracks [3][]*DownTrack nackWorker *workerpool.WorkerPool isSimulcast bool @@ -77,6 +79,11 @@ func NewWebRTCReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote, } } +func (w *WebRTCReceiver) SetTrackMeta(trackID, streamID string) { + w.streamID = streamID + w.trackID = trackID +} + func (w *WebRTCReceiver) StreamID() string { return w.streamID } @@ -118,56 +125,43 @@ func (w *WebRTCReceiver) AddUpTrack(track *webrtc.TrackRemote, buff *buffer.Buff w.Lock() w.upTracks[layer] = track w.buffers[layer] = buff - w.downTracks[layer] = make([]*DownTrack, 0, 10) + w.available[layer].set(true) + w.downTracks[layer].Store(make([]*DownTrack, 0, 10)) + w.pendingTracks[layer] = make([]*DownTrack, 0, 10) w.Unlock() subBestQuality := func(targetLayer int) { for l := 0; l < targetLayer; l++ { - w.locks[l].Lock() - for _, dt := range w.downTracks[l] { - dt.SwitchSpatialLayer(int64(targetLayer), false) + dts := w.downTracks[l].Load() + if dts == nil { + continue + } + for _, dt := range dts.([]*DownTrack) { + _ = dt.SwitchSpatialLayer(int32(targetLayer), false) } - w.locks[l].Unlock() } } subLowestQuality := func(targetLayer int) { for l := 2; l != targetLayer; l-- { - w.locks[l].Lock() - for _, dt := range w.downTracks[l] { - dt.SwitchSpatialLayer(int64(targetLayer), false) + dts := w.downTracks[l].Load() + if dts == nil { + continue + } + for _, dt := range dts.([]*DownTrack) { + _ = dt.SwitchSpatialLayer(int32(targetLayer), false) } - w.locks[l].Unlock() } } if w.isSimulcast { - if bestQualityFirst { - if layer < 2 { - w.locks[layer+1].Lock() - t := w.downTracks[layer+1] - w.locks[layer+1].Unlock() - if t == nil { - subBestQuality(layer) - } - } else { - subBestQuality(layer) - } - } else { - if layer > 0 { - w.locks[layer-1].Lock() - t := w.downTracks[layer-1] - w.locks[layer-1].Unlock() - if t == nil { - subLowestQuality(layer) - } - } else { - subLowestQuality(layer) - } + if bestQualityFirst && (!w.available[2].get() || layer == 2) { + subBestQuality(layer) + } else if !bestQualityFirst && (!w.available[0].get() || layer == 0) { + subLowestQuality(layer) } } go w.writeRTP(layer) - } func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { @@ -177,50 +171,44 @@ func (w *WebRTCReceiver) AddDownTrack(track *DownTrack, bestQualityFirst bool) { layer := 0 if w.isSimulcast { - w.Lock() - for i, t := range w.upTracks { - if t != nil { + for i, t := range w.available { + if t.get() { layer = i if !bestQualityFirst { break } } } - w.Unlock() - w.locks[layer].Lock() - if downTrackSubscribed(w.downTracks[layer], track) { - w.locks[layer].Unlock() + if w.downTrackSubscribed(layer, track) { return } - track.SetInitialLayers(int64(layer), 2) + track.SetInitialLayers(int32(layer), 2) track.maxSpatialLayer = 2 track.maxTemporalLayer = 2 track.lastSSRC = w.SSRC(layer) track.trackType = SimulcastDownTrack - track.payload = packetFactory.Get().([]byte) + track.payload = packetFactory.Get().(*[]byte) } else { - w.locks[layer].Lock() - if downTrackSubscribed(w.downTracks[layer], track) { - w.locks[layer].Unlock() + if w.downTrackSubscribed(layer, track) { return } track.SetInitialLayers(0, 0) track.trackType = SimpleDownTrack } - - w.downTracks[layer] = append(w.downTracks[layer], track) - w.locks[layer].Unlock() + w.Lock() + w.storeDownTrack(layer, track) + w.Unlock() } func (w *WebRTCReceiver) SwitchDownTrack(track *DownTrack, layer int) error { if w.closed.get() { return errNoReceiverFound } - - if buf := w.buffers[layer]; buf != nil { - w.locks[layer].Lock() + if w.available[layer].get() { + w.Lock() + w.pending[layer].set(true) w.pendingTracks[layer] = append(w.pendingTracks[layer], track) - w.locks[layer].Unlock() + w.Unlock() return nil } return errNoReceiverFound @@ -236,11 +224,11 @@ func (w *WebRTCReceiver) GetBitrate() [3]uint64 { return br } -func (w *WebRTCReceiver) GetMaxTemporalLayer() [3]int64 { - var tls [3]int64 - for i, buff := range w.buffers { - if buff != nil { - tls[i] = buff.MaxTemporalLayer() +func (w *WebRTCReceiver) GetMaxTemporalLayer() [3]int32 { + var tls [3]int32 + for i, a := range w.available { + if a.get() { + tls[i] = w.buffers[i].MaxTemporalLayer() } } return tls @@ -256,33 +244,28 @@ func (w *WebRTCReceiver) DeleteDownTrack(layer int, id string) { if w.closed.get() { return } + w.Lock() + w.deleteDownTrack(layer, id) + w.Unlock() +} - w.locks[layer].Lock() - idx := -1 - for i, dt := range w.downTracks[layer] { - if dt.peerID == id { - idx = i - break +func (w *WebRTCReceiver) deleteDownTrack(layer int, id string) { + dts := w.downTracks[layer].Load().([]*DownTrack) + ndts := make([]*DownTrack, 0, len(dts)) + for _, dt := range dts { + if dt.id != id { + ndts = append(ndts, dt) } } - if idx == -1 { - w.locks[layer].Unlock() - return - } - w.downTracks[layer][idx] = w.downTracks[layer][len(w.downTracks[layer])-1] - w.downTracks[layer][len(w.downTracks[layer])-1] = nil - w.downTracks[layer] = w.downTracks[layer][:len(w.downTracks[layer])-1] - w.locks[layer].Unlock() + w.downTracks[layer].Store(ndts) } func (w *WebRTCReceiver) SendRTCP(p []rtcp.Packet) { if _, ok := p[0].(*rtcp.PictureLossIndication); ok { - w.rtcpMu.Lock() - defer w.rtcpMu.Unlock() - if time.Now().UnixNano()-w.lastPli < 500e6 { + if time.Now().UnixNano()-atomic.LoadInt64(&w.lastPli) < 500e6 { return } - w.lastPli = time.Now().UnixNano() + atomic.StoreInt64(&w.lastPli, time.Now().UnixNano()) } w.rtcpCh <- p @@ -292,13 +275,19 @@ func (w *WebRTCReceiver) SetRTCPCh(ch chan []rtcp.Packet) { w.rtcpCh = ch } +func (w *WebRTCReceiver) GetSenderReportTime(layer int) (rtpTS uint32, ntpTS uint64) { + rtpTS, ntpTS, _ = w.buffers[layer].GetSenderReportData() + return +} + func (w *WebRTCReceiver) RetransmitPackets(track *DownTrack, packets []packetMeta) error { if w.nackWorker.Stopped() { return io.ErrClosedPipe } w.nackWorker.Submit(func() { + src := packetFactory.Get().(*[]byte) for _, meta := range packets { - pktBuff := packetFactory.Get().([]byte) + pktBuff := *src buff := w.buffers[meta.layer] if buff == nil { break @@ -335,9 +324,8 @@ func (w *WebRTCReceiver) RetransmitPackets(track *DownTrack, packets []packetMet } else { track.UpdateStats(uint32(i)) } - - packetFactory.Put(pktBuff) } + packetFactory.Put(src) }) return nil } @@ -360,40 +348,48 @@ func (w *WebRTCReceiver) writeRTP(layer int) { return } - w.locks[layer].Lock() - - if w.isSimulcast && len(w.pendingTracks[layer]) > 0 { - if pkt.KeyFrame { - for _, dt := range w.pendingTracks[layer] { - w.downTracks[layer] = append(w.downTracks[layer], dt) - w.DeleteDownTrack(dt.CurrentSpatialLayer(), dt.peerID) - dt.SwitchSpatialLayerDone() + if w.isSimulcast { + if w.pending[layer].get() { + if pkt.KeyFrame { + w.Lock() + for idx, dt := range w.pendingTracks[layer] { + w.deleteDownTrack(dt.CurrentSpatialLayer(), dt.peerID) + w.storeDownTrack(layer, dt) + dt.SwitchSpatialLayerDone(int32(layer)) + w.pendingTracks[layer][idx] = nil + } + w.pendingTracks[layer] = w.pendingTracks[layer][:0] + w.pending[layer].set(false) + w.Unlock() + } else { + w.SendRTCP(pli) } - w.pendingTracks[layer] = w.pendingTracks[layer][:0] - } else { - w.SendRTCP(pli) } } - for _, dt := range w.downTracks[layer] { - if err = dt.WriteRTP(pkt); err != nil { + for _, dt := range w.downTracks[layer].Load().([]*DownTrack) { + if err = dt.WriteRTP(pkt, layer); err != nil { + if err == io.EOF && err == io.ErrClosedPipe { + w.Lock() + w.deleteDownTrack(layer, dt.id) + w.Unlock() + } log.Error().Err(err).Str("id", dt.id).Msg("Error writing to down track") } } - w.locks[layer].Unlock() } } // closeTracks close all tracks from Receiver func (w *WebRTCReceiver) closeTracks() { - for idx, layer := range w.downTracks { - w.locks[idx].Lock() - for _, dt := range layer { + for idx, a := range w.available { + if !a.get() { + continue + } + for _, dt := range w.downTracks[idx].Load().([]*DownTrack) { dt.Close() } - w.downTracks[idx] = w.downTracks[idx][:0] - w.locks[idx].Unlock() } w.nackWorker.StopWait() if w.onCloseHandler != nil { @@ -401,7 +397,8 @@ func (w *WebRTCReceiver) closeTracks() { } } -func downTrackSubscribed(dts []*DownTrack, dt *DownTrack) bool { +func (w *WebRTCReceiver) downTrackSubscribed(layer int, dt *DownTrack) bool { + dts := w.downTracks[layer].Load().([]*DownTrack) for _, cdt := range dts { if cdt == dt { return true @@ -409,3 +406,11 @@ func downTrackSubscribed(dts []*DownTrack, dt *DownTrack) bool { } return false } + +func (w *WebRTCReceiver) storeDownTrack(layer int, dt *DownTrack) { + dts := w.downTracks[layer].Load().([]*DownTrack) + ndts := make([]*DownTrack, len(dts)+1) + copy(ndts, dts) + ndts[len(ndts)-1] = dt + w.downTracks[layer].Store(ndts) +} diff --git a/pkg/sfu/relay.go b/pkg/sfu/relay.go new file mode 100644 index 000000000..aa3f0133f --- /dev/null +++ b/pkg/sfu/relay.go @@ -0,0 +1,13 @@ +package sfu + +func RelayWithFanOutDataChannels() func(r *relayPeer) { + return func(r *relayPeer) { + r.relayFanOutDataChannels = true + } +} + +func RelayWithSenderReports() func(r *relayPeer) { + return func(r *relayPeer) { + r.withSRReports = true + } +} diff --git a/pkg/sfu/relaypeer.go b/pkg/sfu/relaypeer.go new file mode 100644 index 000000000..18502e31a --- /dev/null +++ b/pkg/sfu/relaypeer.go @@ -0,0 +1,205 @@ +package sfu + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/pion/ion-sfu/pkg/buffer" + "github.com/pion/ion-sfu/pkg/relay" + "github.com/pion/rtcp" + "github.com/pion/transport/packetio" + "github.com/pion/webrtc/v3" +) + +type RelayPeer struct { + mu sync.RWMutex + + peer *relay.Peer + session Session + router Router + config *WebRTCTransportConfig + tracks []PublisherTrack + relayPeers []*relay.Peer + dataChannels []*webrtc.DataChannel +} + +func NewRelayPeer(peer *relay.Peer, session Session, config *WebRTCTransportConfig) *RelayPeer { + r := newRouter(peer.ID(), session, config) + r.SetRTCPWriter(peer.WriteRTCP) + + rp := &RelayPeer{ + peer: peer, + router: r, + config: config, + session: session, + } + + peer.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver, meta *relay.TrackMeta) { + if recv, pub := r.AddReceiver(receiver, track); pub { + recv.SetTrackMeta(meta.TrackID, meta.StreamID) + session.Publish(r, recv) + rp.mu.Lock() + rp.tracks = append(rp.tracks, PublisherTrack{track, recv, true}) + for _, lrp := range rp.relayPeers { + if err := rp.createRelayTrack(track, recv, lrp); err != nil { + Logger.V(1).Error(err, "Creating relay track.", "peer_id", peer.ID()) + } + } + rp.mu.Unlock() + } else { + rp.mu.Lock() + rp.tracks = append(rp.tracks, PublisherTrack{track, recv, false}) + rp.mu.Unlock() + } + }) + + return rp +} + +func (r *RelayPeer) GetRouter() Router { + return r.router +} + +func (r *RelayPeer) ID() string { + return r.peer.ID() +} + +func (r *RelayPeer) Relay(signalFn func(meta relay.PeerMeta, signal []byte) ([]byte, error)) (*relay.Peer, error) { + rp, err := relay.NewPeer(relay.PeerMeta{ + PeerID: r.peer.ID(), + SessionID: r.session.ID(), + }, &relay.PeerConfig{ + SettingEngine: r.config.Setting, + ICEServers: r.config.Configuration.ICEServers, + Logger: Logger, + }) + if err != nil { + return nil, fmt.Errorf("relay: %w", err) + } + + rp.OnReady(func() { + r.mu.Lock() + for _, tp := range r.tracks { + if !tp.clientRelay { + // simulcast will just relay client track for now + continue + } + if err = r.createRelayTrack(tp.Track, tp.Receiver, rp); err != nil { + Logger.V(1).Error(err, "Creating relay track.", "peer_id", r.ID()) + } + } + r.relayPeers = append(r.relayPeers, rp) + r.mu.Unlock() + go r.relayReports(rp) + }) + + rp.OnDataChannel(func(channel *webrtc.DataChannel) { + r.mu.Lock() + r.dataChannels = append(r.dataChannels, channel) + r.mu.Unlock() + r.session.AddDatachannel("", channel) + }) + + if err = rp.Offer(signalFn); err != nil { + return nil, fmt.Errorf("relay: %w", err) + } + + return rp, nil +} + +func (r *RelayPeer) DataChannel(label string) *webrtc.DataChannel { + r.mu.RLock() + defer r.mu.RUnlock() + for _, dc := range r.dataChannels { + if dc.Label() == label { + return dc + } + } + return nil +} + +func (r *RelayPeer) createRelayTrack(track *webrtc.TrackRemote, receiver Receiver, rp *relay.Peer) error { + codec := track.Codec() + downTrack, err := NewDownTrack(webrtc.RTPCodecCapability{ + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: []webrtc.RTCPFeedback{{"nack", ""}, {"nack", "pli"}}, + }, receiver, r.config.BufferFactory, r.ID(), r.config.Router.MaxPacketTrack) + if err != nil { + Logger.V(1).Error(err, "Create Relay downtrack err", "peer_id", r.ID()) + return err + } + + sdr, err := rp.AddTrack(receiver.(*WebRTCReceiver).receiver, track, downTrack) + if err != nil { + Logger.V(1).Error(err, "Relaying track.", "peer_id", r.ID()) + return fmt.Errorf("relay: %w", err) + } + + r.config.BufferFactory.GetOrNew(packetio.RTCPBufferPacket, + uint32(sdr.GetParameters().Encodings[0].SSRC)).(*buffer.RTCPReader).OnPacket(func(bytes []byte) { + pkts, err := rtcp.Unmarshal(bytes) + if err != nil { + Logger.V(1).Error(err, "Unmarshal rtcp reports", "peer_id", r.ID()) + return + } + var rpkts []rtcp.Packet + for _, pkt := range pkts { + switch pk := pkt.(type) { + case *rtcp.PictureLossIndication: + rpkts = append(rpkts, &rtcp.PictureLossIndication{ + SenderSSRC: pk.MediaSSRC, + MediaSSRC: uint32(track.SSRC()), + }) + } + } + + if len(rpkts) > 0 { + if err := r.peer.WriteRTCP(rpkts); err != nil { + Logger.V(1).Error(err, "Sending rtcp relay reports", "peer_id", r.ID()) + } + } + }) + + downTrack.OnCloseHandler(func() { + if err = sdr.Stop(); err != nil { + Logger.V(1).Error(err, "Stopping relay sender.", "peer_id", r.ID()) + } + }) + + receiver.AddDownTrack(downTrack, true) + return nil +} + +func (r *RelayPeer) relayReports(rp *relay.Peer) { + for { + time.Sleep(5 * time.Second) + + var packets []rtcp.Packet + for _, t := range rp.LocalTracks() { + if dt, ok := t.(*DownTrack); ok { + if !dt.bound.get() { + continue + } + if sr := dt.CreateSenderReport(); sr != nil { + packets = append(packets, sr) + } + } + } + + if len(packets) == 0 { + continue + } + + if err := rp.WriteRTCP(packets); err != nil { + if err == io.EOF || err == io.ErrClosedPipe { + return + } + Logger.Error(err, "Sending downtrack reports err") + } + } +} diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go index e7c353911..7bdb28843 100644 --- a/pkg/sfu/router.go +++ b/pkg/sfu/router.go @@ -15,6 +15,7 @@ type Router interface { ID() string AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) (Receiver, bool) AddDownTracks(s *Subscriber, r Receiver) error + SetRTCPWriter(func([]rtcp.Packet) error) AddDownTrack(s *Subscriber, r Receiver) (*DownTrack, error) Stop() } @@ -34,7 +35,6 @@ type router struct { sync.RWMutex id string twcc *twcc.Responder - peer *webrtc.PeerConnection stats map[uint32]*stats.Stream rtcpCh chan []rtcp.Packet stopCh chan struct{} @@ -42,14 +42,14 @@ type router struct { session Session receivers map[string]Receiver bufferFactory *buffer.Factory + writeRTCP func([]rtcp.Packet) error } // newRouter for routing rtp/rtcp packets -func newRouter(id string, peer *webrtc.PeerConnection, session Session, config *WebRTCTransportConfig) Router { +func newRouter(id string, session Session, config *WebRTCTransportConfig) Router { ch := make(chan []rtcp.Packet, 10) r := &router{ id: id, - peer: peer, rtcpCh: ch, stopCh: make(chan struct{}), config: config.Router, @@ -63,7 +63,6 @@ func newRouter(id string, peer *webrtc.PeerConnection, session Session, config * stats.Peers.Inc() } - go r.sendRTCP() return r } @@ -184,7 +183,6 @@ func (r *router) AddReceiver(receiver *webrtc.RTPReceiver, track *webrtc.TrackRe return recv, publish } -// AddWebRTCSender to Router func (r *router) AddDownTracks(s *Subscriber, recv Receiver) error { r.Lock() defer r.Unlock() @@ -213,6 +211,11 @@ func (r *router) AddDownTracks(s *Subscriber, recv Receiver) error { return nil } +func (r *router) SetRTCPWriter(fn func(packet []rtcp.Packet) error) { + r.writeRTCP = fn + go r.sendRTCP() +} + func (r *router) AddDownTrack(sub *Subscriber, recv Receiver) (*DownTrack, error) { for _, dt := range sub.GetDownTracks(recv.StreamID()) { if dt.ID() == recv.TrackID() { @@ -277,7 +280,7 @@ func (r *router) sendRTCP() { for { select { case pkts := <-r.rtcpCh: - if err := r.peer.WriteRTCP(pkts); err != nil { + if err := r.writeRTCP(pkts); err != nil { Logger.Error(err, "Write rtcp to peer err", "peer_id", r.id) } case <-r.stopCh: diff --git a/pkg/sfu/session.go b/pkg/sfu/session.go index 96340d6b3..22f6c97fa 100644 --- a/pkg/sfu/session.go +++ b/pkg/sfu/session.go @@ -1,11 +1,16 @@ package sfu import ( - "context" "encoding/json" "sync" "time" + "github.com/pion/ion-sfu/pkg/logger" + "github.com/pion/ion-sfu/pkg/relay" + "github.com/rs/zerolog/log" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/pion/webrtc/v3" ) @@ -16,19 +21,25 @@ type Session interface { Publish(router Router, r Receiver) Subscribe(peer Peer) AddPeer(peer Peer) + GetPeer(peerID string) Peer RemovePeer(peer Peer) + AddRelayPeer(peerID string, signalData []byte) ([]byte, error) AudioObserver() *AudioObserver AddDatachannel(owner string, dc *webrtc.DataChannel) GetDCMiddlewares() []*Datachannel - GetDataChannelLabels() []string - GetDataChannels(origin, label string) (dcs []*webrtc.DataChannel) + GetFanOutDataChannelLabels() []string + GetDataChannels(peerID, label string) (dcs []*webrtc.DataChannel) + FanOutMessage(origin, label string, msg webrtc.DataChannelMessage) Peers() []Peer + RelayPeers() []*RelayPeer } type SessionLocal struct { id string mu sync.RWMutex + config WebRTCTransportConfig peers map[string]Peer + relayPeers map[string]*RelayPeer closed atomicBool audioObs *AudioObserver fanOutDCs []string @@ -45,7 +56,9 @@ func NewSession(id string, dcs []*Datachannel, cfg WebRTCTransportConfig) Sessio s := &SessionLocal{ id: id, peers: make(map[string]Peer), + relayPeers: make(map[string]*RelayPeer), datachannels: dcs, + config: cfg, audioObs: NewAudioObserver(cfg.Router.AudioLevelThreshold, cfg.Router.AudioLevelInterval, cfg.Router.AudioLevelFilter), } go s.audioLevelObserver(cfg.Router.AudioLevelInterval) @@ -65,15 +78,12 @@ func (s *SessionLocal) GetDCMiddlewares() []*Datachannel { return s.datachannels } -func (s *SessionLocal) GetDataChannelLabels() []string { +func (s *SessionLocal) GetFanOutDataChannelLabels() []string { s.mu.RLock() defer s.mu.RUnlock() - res := make([]string, 0, len(s.datachannels)+len(s.fanOutDCs)) - copy(res, s.fanOutDCs) - for _, dc := range s.datachannels { - res = append(res, dc.Label) - } - return res + fanout := make([]string, len(s.fanOutDCs)) + copy(fanout, s.fanOutDCs) + return fanout } func (s *SessionLocal) AddPeer(peer Peer) { @@ -82,7 +92,55 @@ func (s *SessionLocal) AddPeer(peer Peer) { s.mu.Unlock() } -// RemovePeer removes a transport from the SessionLocal +func (s *SessionLocal) GetPeer(peerID string) Peer { + s.mu.RLock() + defer s.mu.RUnlock() + return s.peers[peerID] +} + +func (s *SessionLocal) AddRelayPeer(peerID string, signalData []byte) ([]byte, error) { + p, err := relay.NewPeer(relay.PeerMeta{ + PeerID: peerID, + SessionID: s.id, + }, &relay.PeerConfig{ + SettingEngine: s.config.Setting, + ICEServers: s.config.Configuration.ICEServers, + Logger: logger.New(), + }) + if err != nil { + log.Err(err).Msg("Creating relay peer") + return nil, status.Error(codes.Internal, err.Error()) + } + + resp, err := p.Answer(signalData) + if err != nil { + log.Err(err).Msg("Creating answer for relay") + return nil, err + } + + p.OnReady(func() { + rp := NewRelayPeer(p, s, &s.config) + s.mu.Lock() + s.relayPeers[peerID] = rp + s.mu.Unlock() + }) + + p.OnClose(func() { + s.mu.Lock() + delete(s.relayPeers, peerID) + s.mu.Unlock() + }) + + return resp, nil +} + +func (s *SessionLocal) GetRelayPeer(peerID string) *RelayPeer { + s.mu.RLock() + defer s.mu.RUnlock() + return s.relayPeers[peerID] +} + +// RemovePeer removes Peer from the SessionLocal func (s *SessionLocal) RemovePeer(p Peer) { pid := p.ID() Logger.V(0).Info("RemovePeer from SessionLocal", "peer_id", pid, "session_id", s.id) @@ -103,45 +161,64 @@ func (s *SessionLocal) AddDatachannel(owner string, dc *webrtc.DataChannel) { label := dc.Label() s.mu.Lock() - s.fanOutDCs = append(s.fanOutDCs, label) - peerOwner := s.peers[owner] - peers := make([]Peer, 0, len(s.peers)) - for _, p := range s.peers { - if p == peerOwner || p.Subscriber() == nil { - continue + for _, lbl := range s.fanOutDCs { + if label == lbl { + return } - peers = append(peers, p) } + s.fanOutDCs = append(s.fanOutDCs, label) + peerOwner := s.peers[owner] s.mu.Unlock() + peers := s.Peers() peerOwner.Subscriber().RegisterDatachannel(label, dc) dc.OnMessage(func(msg webrtc.DataChannelMessage) { - s.onMessage(owner, label, msg) + s.FanOutMessage(owner, label, msg) }) for _, p := range peers { - n, err := p.Subscriber().AddDataChannel(label) + peer := p + if peer.ID() == owner || peer.Subscriber() == nil { + continue + } + ndc, err := peer.Subscriber().AddDataChannel(label) if err != nil { Logger.Error(err, "error adding datachannel") continue } - pid := p.ID() - n.OnMessage(func(msg webrtc.DataChannelMessage) { - s.onMessage(pid, label, msg) + if peer.Publisher() != nil && peer.Publisher().Relayed() { + peer.Publisher().AddRelayFanOutDataChannel(label) + } + + pid := peer.ID() + ndc.OnMessage(func(msg webrtc.DataChannelMessage) { + s.FanOutMessage(pid, label, msg) + + if peer.Publisher().Relayed() { + for _, rdc := range peer.Publisher().GetRelayedDataChannels(label) { + if msg.IsString { + if err = rdc.SendText(string(msg.Data)); err != nil { + Logger.Error(err, "Sending dc message err") + } + } else { + if err = rdc.Send(msg.Data); err != nil { + Logger.Error(err, "Sending dc message err") + } + } + } + } }) - p.Subscriber().negotiate() + peer.Subscriber().negotiate() } } // Publish will add a Sender to all peers in current SessionLocal from given // Receiver func (s *SessionLocal) Publish(router Router, r Receiver) { - peers := s.Peers() - - for _, p := range peers { + for _, p := range s.Peers() { // Don't sub to self if router.ID() == p.ID() || p.Subscriber() == nil { continue @@ -170,16 +247,31 @@ func (s *SessionLocal) Subscribe(peer Peer) { } s.mu.RUnlock() - // Subscribe to fan out datachannels + // Subscribe to fan out data channels for _, label := range fdc { - n, err := peer.Subscriber().AddDataChannel(label) + dc, err := peer.Subscriber().AddDataChannel(label) if err != nil { Logger.Error(err, "error adding datachannel") continue } l := label - n.OnMessage(func(msg webrtc.DataChannelMessage) { - s.onMessage(peer.ID(), l, msg) + dc.OnMessage(func(msg webrtc.DataChannelMessage) { + s.FanOutMessage(peer.ID(), l, msg) + + if peer.Publisher().Relayed() { + for _, rdc := range peer.Publisher().GetRelayedDataChannels(l) { + if msg.IsString { + if err = rdc.SendText(string(msg.Data)); err != nil { + Logger.Error(err, "Sending dc message err") + } + } else { + if err = rdc.Send(msg.Data); err != nil { + Logger.Error(err, "Sending dc message err") + } + } + + } + } }) } @@ -192,6 +284,15 @@ func (s *SessionLocal) Subscribe(peer Peer) { } } + // Subscribe to relay streams + for _, p := range s.RelayPeers() { + err := p.GetRouter().AddDownTracks(peer.Subscriber(), nil) + if err != nil { + Logger.Error(err, "Subscribing to Router err") + continue + } + } + peer.Subscriber().negotiate() } @@ -206,6 +307,17 @@ func (s *SessionLocal) Peers() []Peer { return p } +// RelayPeers returns relay peers in this SessionLocal +func (s *SessionLocal) RelayPeers() []*RelayPeer { + s.mu.RLock() + defer s.mu.RUnlock() + p := make([]*RelayPeer, 0, len(s.peers)) + for _, peer := range s.relayPeers { + p = append(p, peer) + } + return p +} + // OnClose is called when the SessionLocal is closed func (s *SessionLocal) OnClose(f func()) { s.onCloseHandler = f @@ -220,34 +332,44 @@ func (s *SessionLocal) Close() { } } -func (s *SessionLocal) setRelayedDatachannel(peerID string, datachannel *webrtc.DataChannel) { - label := datachannel.Label() - for _, dc := range s.datachannels { - dc := dc - if dc.Label == label { - mws := newDCChain(dc.middlewares) - p := mws.Process(ProcessFunc(func(ctx context.Context, args ProcessArgs) { - if dc.onMessage != nil { - dc.onMessage(ctx, args) - } - })) - s.mu.RLock() - peer := s.peers[peerID] - s.mu.RUnlock() - datachannel.OnMessage(func(msg webrtc.DataChannelMessage) { - p.Process(context.Background(), ProcessArgs{ - Peer: peer, - Message: msg, - DataChannel: datachannel, - }) - }) +func (s *SessionLocal) FanOutMessage(origin, label string, msg webrtc.DataChannelMessage) { + dcs := s.GetDataChannels(origin, label) + for _, dc := range dcs { + if msg.IsString { + if err := dc.SendText(string(msg.Data)); err != nil { + Logger.Error(err, "Sending dc message err") + } + } else { + if err := dc.Send(msg.Data); err != nil { + Logger.Error(err, "Sending dc message err") + } } - return } +} - datachannel.OnMessage(func(msg webrtc.DataChannelMessage) { - s.onMessage(peerID, label, msg) - }) +func (s *SessionLocal) GetDataChannels(peerID, label string) []*webrtc.DataChannel { + s.mu.RLock() + defer s.mu.RUnlock() + dcs := make([]*webrtc.DataChannel, 0, len(s.peers)) + for pid, p := range s.peers { + if peerID == pid { + continue + } + + if p.Subscriber() != nil { + if dc := p.Subscriber().DataChannel(label); dc != nil && dc.ReadyState() == webrtc.DataChannelStateOpen { + dcs = append(dcs, dc) + } + } + + } + for _, rp := range s.relayPeers { + if dc := rp.DataChannel(label); dc != nil { + dcs = append(dcs, dc) + } + } + + return dcs } func (s *SessionLocal) audioLevelObserver(audioLevelInterval int) { @@ -289,34 +411,3 @@ func (s *SessionLocal) audioLevelObserver(audioLevelInterval int) { } } } - -func (s *SessionLocal) onMessage(origin, label string, msg webrtc.DataChannelMessage) { - dcs := s.GetDataChannels(origin, label) - for _, dc := range dcs { - if msg.IsString { - if err := dc.SendText(string(msg.Data)); err != nil { - Logger.Error(err, "Sending dc message err") - } - } else { - if err := dc.Send(msg.Data); err != nil { - Logger.Error(err, "Sending dc message err") - } - } - } -} - -func (s *SessionLocal) GetDataChannels(origin, label string) []*webrtc.DataChannel { - s.mu.RLock() - defer s.mu.RUnlock() - dcs := make([]*webrtc.DataChannel, 0, len(s.peers)) - for pid, p := range s.peers { - if origin == pid || p.Subscriber() == nil { - continue - } - - if dc := p.Subscriber().DataChannel(label); dc != nil && dc.ReadyState() == webrtc.DataChannelStateOpen { - dcs = append(dcs, dc) - } - } - return dcs -} diff --git a/pkg/sfu/sfu.go b/pkg/sfu/sfu.go index 7fce71701..8ba73ea63 100644 --- a/pkg/sfu/sfu.go +++ b/pkg/sfu/sfu.go @@ -8,8 +8,6 @@ import ( "sync" "time" - "github.com/pion/ion-sfu/pkg/relay" - "github.com/go-logr/logr" "github.com/pion/ice/v2" "github.com/pion/ion-sfu/pkg/buffer" @@ -38,7 +36,6 @@ type WebRTCTransportConfig struct { Configuration webrtc.Configuration Setting webrtc.SettingEngine Router RouterConfig - Relay func(meta relay.PeerMeta, signal []byte) ([]byte, error) BufferFactory *buffer.Factory } @@ -68,7 +65,6 @@ type Config struct { WebRTC WebRTCConfig `mapstructure:"webrtc"` Router RouterConfig `mapstructure:"Router"` Turn TurnConfig `mapstructure:"turn"` - Relay func(meta relay.PeerMeta, signal []byte) ([]byte, error) BufferFactory *buffer.Factory TurnAuth func(username string, realm string, srcAddr net.Addr) ([]byte, bool) } @@ -162,7 +158,6 @@ func NewWebRTCTransportConfig(c Config) WebRTCTransportConfig { }, Setting: se, Router: c.Router, - Relay: c.Relay, BufferFactory: c.BufferFactory, } @@ -186,7 +181,8 @@ func init() { // Init packet factory packetFactory = &sync.Pool{ New: func() interface{} { - return make([]byte, 1460) + b := make([]byte, 1460) + return &b }, } } diff --git a/pkg/sfu/sfu_test.go b/pkg/sfu/sfu_test.go index c982cadbe..243d6c473 100644 --- a/pkg/sfu/sfu_test.go +++ b/pkg/sfu/sfu_test.go @@ -293,10 +293,12 @@ func TestSFU_SessionScenarios(t *testing.T) { func() { switch action.kind { case "join": - me := webrtc.MediaEngine{} + me, _ := getPublisherMediaEngine() + se := webrtc.SettingEngine{} + se.DisableMediaEngineCopy(true) err := me.RegisterDefaultCodecs() assert.NoError(t, err) - api := webrtc.NewAPI(webrtc.WithMediaEngine(&me)) + api := webrtc.NewAPI(webrtc.WithMediaEngine(me), webrtc.WithSettingEngine(se)) pub, err := api.NewPeerConnection(webrtc.Configuration{}) assert.NoError(t, err) sub, err := api.NewPeerConnection(webrtc.Configuration{}) diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go index 4b1995bf2..63ea9e0c2 100644 --- a/pkg/sfu/subscriber.go +++ b/pkg/sfu/subscriber.go @@ -260,7 +260,9 @@ func (s *Subscriber) downTracksReports() { if !dt.bound.get() { continue } - r = append(r, dt.CreateSenderReport()) + if sr := dt.CreateSenderReport(); sr != nil { + r = append(r, sr) + } sd = append(sd, dt.CreateSourceDescriptionChunks()...) } }