diff --git a/.github/hooks/pre-commit.sh b/.github/hooks/pre-commit.sh index cc318d7..d5a1ce5 100755 --- a/.github/hooks/pre-commit.sh +++ b/.github/hooks/pre-commit.sh @@ -10,3 +10,4 @@ exec 1>&2 .github/lint-disallowed-functions-in-library.sh +.github/lint-no-trailing-newline-in-log-messages.sh diff --git a/.github/lint-no-trailing-newline-in-log-messages.sh b/.github/lint-no-trailing-newline-in-log-messages.sh new file mode 100755 index 0000000..29cd4a2 --- /dev/null +++ b/.github/lint-no-trailing-newline-in-log-messages.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# DO NOT EDIT THIS FILE +# +# It is automatically copied from https://github.com/pion/.goassets repository. +# +# If you want to update the shared CI config, send a PR to +# https://github.com/pion/.goassets instead of this repository. +# + +set -e + +# Disallow usages of functions that cause the program to exit in the library code +SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) +if [ -f ${SCRIPT_PATH}/.ci.conf ] +then + . ${SCRIPT_PATH}/.ci.conf +fi + +files=$( + find "$SCRIPT_PATH/.." -name "*.go" \ + | while read file + do + excluded=false + for ex in $EXCLUDE_DIRECTORIES + do + if [[ $file == */$ex/* ]] + then + excluded=true + break + fi + done + $excluded || echo "$file" + done +) + +if grep -E '\.(Trace|Debug|Info|Warn|Error)f?\("[^"]*\\n"\)?' $files | grep -v -e 'nolint'; then + echo "Log format strings should have trailing new-line" + exit 1 +fi \ No newline at end of file diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 0000000..cec0d7c --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,40 @@ +name: "CodeQL" + +on: + workflow_dispatch: + schedule: + - cron: '23 5 * * 0' + pull_request: + branches: + - master + paths: + - '**.go' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + steps: + - name: Checkout repo + uses: actions/checkout@v3 + + # The code in examples/ might intentionally do things like log credentials + # in order to show how the library is used, aid in debugging etc. We + # should ignore those for CodeQL scanning, and only focus on the package + # itself. + - name: Remove example code + run: | + rm -rf examples/ + + - name: Initialize CodeQL + uses: github/codeql-action/init@v2 + with: + languages: 'go' + + - name: CodeQL Analysis + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/generate-authors.yml b/.github/workflows/generate-authors.yml index 9a80a48..c7a8404 100644 --- a/.github/workflows/generate-authors.yml +++ b/.github/workflows/generate-authors.yml @@ -16,6 +16,8 @@ on: jobs: checksecret: + permissions: + contents: none runs-on: ubuntu-latest outputs: is_PIONBOT_PRIVATE_KEY_set: ${{ steps.checksecret_job.outputs.is_PIONBOT_PRIVATE_KEY_set }} @@ -28,6 +30,8 @@ jobs: echo "::set-output name=is_PIONBOT_PRIVATE_KEY_set::${{ env.PIONBOT_PRIVATE_KEY != '' }}" generate-authors: + permissions: + contents: write needs: [checksecret] if: needs.checksecret.outputs.is_PIONBOT_PRIVATE_KEY_set == 'true' runs-on: ubuntu-latest diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 438443f..11b6336 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -16,6 +16,10 @@ on: - opened - edited - synchronize + +permissions: + contents: read + jobs: lint-commit-message: name: Metadata @@ -36,8 +40,14 @@ jobs: - name: Functions run: .github/lint-disallowed-functions-in-library.sh + - name: Logging messages should not have trailing newlines + run: .github/lint-no-trailing-newline-in-log-messages.sh + lint-go: name: Go + permissions: + contents: read + pull-requests: read runs-on: ubuntu-latest strategy: fail-fast: false @@ -47,5 +57,5 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: - version: v1.31 + version: v1.45.2 args: $GOLANGCI_LINT_EXRA_ARGS diff --git a/.github/workflows/renovate-go-mod-fix.yaml b/.github/workflows/renovate-go-mod-fix.yaml index 5991822..0804642 100644 --- a/.github/workflows/renovate-go-mod-fix.yaml +++ b/.github/workflows/renovate-go-mod-fix.yaml @@ -15,6 +15,9 @@ on: branches: - renovate/* +permissions: + contents: write + jobs: go-mod-fix: runs-on: ubuntu-latest diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cd788c9..300fac6 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -17,18 +17,22 @@ on: pull_request: branches: - master + +permissions: + contents: read + jobs: test: runs-on: ubuntu-latest strategy: matrix: - go: ["1.16", "1.17"] + go: ["1.17", "1.18"] fail-fast: false name: Go ${{ matrix.go }} steps: - uses: actions/checkout@v3 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -44,18 +48,31 @@ jobs: go-version: ${{ matrix.go }} - name: Setup go-acc - run: | - go get github.com/ory/go-acc - git checkout go.mod go.sum + run: go install github.com/ory/go-acc@latest + + - name: Set up gotestfmt + uses: haveyoudebuggedit/gotestfmt-action@v2 + with: + token: ${{ secrets.GITHUB_TOKEN }} # Avoid getting rate limited - name: Run test run: | TEST_BENCH_OPTION="-bench=." if [ -f .github/.ci.conf ]; then . .github/.ci.conf; fi + set -euo pipefail go-acc -o cover.out ./... -- \ ${TEST_BENCH_OPTION} \ - -v -race + -json \ + -v -race 2>&1 | grep -v '^go: downloading' | tee /tmp/gotest.log | gotestfmt + + - name: Upload test log + uses: actions/upload-artifact@v2 + if: always() + with: + name: test-log-${{ matrix.go }} + path: /tmp/gotest.log + if-no-files-found: error - name: Run TEST_HOOK run: | @@ -64,7 +81,6 @@ jobs: - uses: codecov/codecov-action@v2 with: - file: ./cover.out name: codecov-umbrella fail_ci_if_error: true flags: go @@ -73,13 +89,13 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - go: ["1.16", "1.17"] + go: ["1.17", "1.18"] fail-fast: false name: Go i386 ${{ matrix.go }} steps: - uses: actions/checkout@v3 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -117,7 +133,7 @@ jobs: with: node-version: '16.x' - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: | ~/go/pkg/mod @@ -153,7 +169,6 @@ jobs: - uses: codecov/codecov-action@v2 with: - file: ./cover.out name: codecov-umbrella fail_ci_if_error: true flags: wasm diff --git a/.github/workflows/tidy-check.yaml b/.github/workflows/tidy-check.yaml index 3ab2c35..fa52ce9 100644 --- a/.github/workflows/tidy-check.yaml +++ b/.github/workflows/tidy-check.yaml @@ -18,6 +18,9 @@ on: branches: - master +permissions: + contents: read + jobs: Check: runs-on: ubuntu-latest diff --git a/.golangci.yml b/.golangci.yml index d6162c9..d7a88ec 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -15,14 +15,22 @@ linters-settings: linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers + - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully + - contextcheck # check the function whether use a non-inherited context - deadcode # Finds unused code + - decorder # check declaration order and count of types, constants, variables and functions - depguard # Go linter that checks if package imports are in a list of acceptable packages - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection + - durationcheck # check for two durations multiplied together - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases + - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. + - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. + - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables + - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - gochecknoinits # Checks that no init functions are present in Go code @@ -35,40 +43,62 @@ linters: - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - - golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes + - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - grouper # An analyzer to analyze expression groups. + - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length + - nilerr # Finds the code that returns nil even if it checks that the error is not nil. + - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - noctx # noctx finds sending http request without context.Context - - scopelint # Scopelint checks for unpinned variables in go programs + - predeclared # find code that shadows one of Go's predeclared identifiers + - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - structcheck # Finds unused struct fields - stylecheck # Stylecheck is a replacement for golint + - tagliatelle # Checks the struct tags. + - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 + - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varcheck # Finds unused global variables and constants + - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: + - containedctx # containedctx is a linter that detects struct contained context.Context field + - cyclop # checks function and package cyclomatic complexity + - exhaustivestruct # Checks if all struct's fields are initialized + - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - gomnd # An analyzer to detect magic numbers. + - ifshort # Checks that your code uses short syntax for if-statements whenever possible + - ireturn # Accept Interfaces, Return Concrete Types - lll # Reports long lines + - maintidx # maintidx measures the maintainability index of each function. + - makezero # Finds slice declarations with non-zero initial length - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - nestif # Reports deeply nested if statements - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - nolintlint # Reports ill-formed or insufficient nolint directives + - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated + - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package + - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers + - varnamelen # checks that the length of a variable's name matches its scope + - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: diff --git a/AUTHORS.txt b/AUTHORS.txt index 3606ac9..2b8d45c 100644 --- a/AUTHORS.txt +++ b/AUTHORS.txt @@ -4,8 +4,10 @@ # This file is auto generated, using git to list all individuals contributors. # see `.github/generate-authors.sh` for the scripting adamroach +Adrian Cable Agniva De Sarker Antoine Baché +Antoine Baché Atsushi Watanabe backkem chenkaiC4 diff --git a/context.go b/context.go index e0ad48c..bf871b2 100644 --- a/context.go +++ b/context.go @@ -86,9 +86,9 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts switch profile { case ProtectionProfileAeadAes128Gcm: - c.cipher, err = newSrtpCipherAeadAesGcm(masterKey, masterSalt) - case ProtectionProfileAes128CmHmacSha1_80: - c.cipher, err = newSrtpCipherAesCmHmacSha1(masterKey, masterSalt) + c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: + c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) default: return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) } @@ -112,13 +112,13 @@ func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts } // https://tools.ietf.org/html/rfc3550#appendix-A.1 -func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, func()) { +func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, int32) { seq := int32(sequenceNumber) localRoc := uint32(s.index >> 16) localSeq := int32(s.index & (seqNumMax - 1)) guessRoc := localRoc - var difference int32 = 0 + var difference int32 if s.rolloverHasProcessed { // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian @@ -147,15 +147,17 @@ func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (uint32, func() } } - return guessRoc, func() { - if !s.rolloverHasProcessed { - s.index |= uint64(sequenceNumber) - s.rolloverHasProcessed = true - return - } - if difference > 0 { - s.index += uint64(difference) - } + return guessRoc, difference +} + +func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) { + if !s.rolloverHasProcessed { + s.index |= uint64(sequenceNumber) + s.rolloverHasProcessed = true + return + } + if difference > 0 { + s.index += uint64(difference) } } diff --git a/context_test.go b/context_test.go index e5b95f4..b3fecea 100644 --- a/context_test.go +++ b/context_test.go @@ -5,7 +5,7 @@ import ( ) func TestContextROC(t *testing.T) { - c, err := CreateContext(make([]byte, 16), make([]byte, 14), cipherContextAlgo) + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) if err != nil { t.Fatal(err) } @@ -24,7 +24,7 @@ func TestContextROC(t *testing.T) { } func TestContextIndex(t *testing.T) { - c, err := CreateContext(make([]byte, 16), make([]byte, 14), cipherContextAlgo) + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) if err != nil { t.Fatal(err) } diff --git a/crypto.go b/crypto.go new file mode 100644 index 0000000..e2eb9b2 --- /dev/null +++ b/crypto.go @@ -0,0 +1,58 @@ +package srtp + +import ( + "crypto/cipher" +) + +// xorBytes computes the exclusive-or of src1 and src2 and stores it in dst. +// It returns the number of bytes written. +func xorBytes(dst, src1, src2 []byte) int { + n := len(src1) + if len(src2) < n { + n = len(src2) + } + if len(dst) < n { + n = len(dst) + } + + for i := 0; i < n; i++ { + dst[i] = src1[i] ^ src2[i] + } + + return n +} + +// incrementCTR increments a big-endian integer of arbitrary size. +func incrementCTR(ctr []byte) { + for i := len(ctr) - 1; i >= 0; i-- { + ctr[i]++ + if ctr[i] != 0 { + break + } + } +} + +// xorBytesCTR performs CTR encryption and decryption. +// It is equivalent to cipher.NewCTR followed by XORKeyStream. +func xorBytesCTR(block cipher.Block, iv []byte, dst, src []byte) error { + if len(iv) != block.BlockSize() { + return errBadIVLength + } + + ctr := make([]byte, len(iv)) + copy(ctr, iv) + bs := block.BlockSize() + stream := make([]byte, bs) + + i := 0 + for i < len(src) { + block.Encrypt(stream, ctr) + incrementCTR(ctr) + n := xorBytes(dst[i:], src[i:], stream) + if n == 0 { + break + } + i += n + } + return nil +} diff --git a/crypto_test.go b/crypto_test.go new file mode 100644 index 0000000..4a5bf8f --- /dev/null +++ b/crypto_test.go @@ -0,0 +1,83 @@ +package srtp + +import ( + "crypto/aes" + "crypto/cipher" + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func xorBytesCTRReference(block cipher.Block, iv []byte, dst, src []byte) { + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(dst, src) +} + +func TestXorBytesCTR(t *testing.T) { + for keysize := 16; keysize < 64; keysize *= 2 { + key := make([]byte, keysize) + _, err := rand.Read(key) //nolint: gosec + require.NoError(t, err) + + block, err := aes.NewCipher(key) + require.NoError(t, err) + + iv := make([]byte, block.BlockSize()) + for i := 0; i < 1500; i++ { + src := make([]byte, i) + dst := make([]byte, i) + reference := make([]byte, i) + _, err = rand.Read(iv) //nolint: gosec + require.NoError(t, err) + + _, err = rand.Read(src) //nolint: gosec + require.NoError(t, err) + + assert.NoError(t, xorBytesCTR(block, iv, dst, src)) + xorBytesCTRReference(block, iv, reference, src) + require.Equal(t, dst, reference) + + // test overlap + assert.NoError(t, xorBytesCTR(block, iv, dst, dst)) + xorBytesCTRReference(block, iv, reference, reference) + require.Equal(t, dst, reference) + } + } +} + +func TestXorBytesCTRInvalidIvLength(t *testing.T) { + key := make([]byte, 16) + block, err := aes.NewCipher(key) + require.NoError(t, err) + + src := make([]byte, 1024) + dst := make([]byte, 1024) + + test := func(iv []byte) { + assert.Error(t, errBadIVLength, xorBytesCTR(block, iv, dst, src)) + } + + test(make([]byte, block.BlockSize()-1)) + test(make([]byte, block.BlockSize()+1)) +} + +func TestXorBytesBufferSize(t *testing.T) { + a := []byte{3} + b := []byte{5, 6} + dst := make([]byte, 3) + + xorBytes(dst, a, b) + require.Equal(t, dst, []byte{6, 0, 0}) + + xorBytes(dst, b, a) + require.Equal(t, dst, []byte{6, 0, 0}) + + a = []byte{1, 1, 1, 1} + b = []byte{2, 2, 2, 2} + dst = make([]byte, 3) + + xorBytes(dst, a, b) + require.Equal(t, dst, []byte{3, 3, 3}) +} diff --git a/errors.go b/errors.go index a702621..55a67bc 100644 --- a/errors.go +++ b/errors.go @@ -18,6 +18,7 @@ var ( errTooShortRTCP = errors.New("packet is too short to be rtcp packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") + errBadIVLength = errors.New("bad iv length in xorBytesCTR") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") @@ -25,16 +26,16 @@ var ( errFailedTypeAssertion = errors.New("failed to cast child") ) -type errorDuplicated struct { +type duplicatedError struct { Proto string // srtp or srtcp SSRC uint32 Index uint32 // sequence number or index } -func (e *errorDuplicated) Error() string { +func (e *duplicatedError) Error() string { return fmt.Sprintf("%s ssrc=%d index=%d: %v", e.Proto, e.SSRC, e.Index, errDuplicated) } -func (e *errorDuplicated) Unwrap() error { +func (e *duplicatedError) Unwrap() error { return errDuplicated } diff --git a/go.mod b/go.mod index 6af247e..9f0e9a6 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.14 require ( github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.9 - github.com/pion/rtp v1.7.4 + github.com/pion/rtp v1.7.13 github.com/pion/transport v0.13.0 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.7.1 ) diff --git a/go.sum b/go.sum index 836d704..4b2d50e 100644 --- a/go.sum +++ b/go.sum @@ -6,15 +6,16 @@ github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.9 h1:1ujStwg++IOLIEoOiIQ2s+qBuJ1VN81KW+9pMPsif+U= github.com/pion/rtcp v1.2.9/go.mod h1:qVPhiCzAm4D/rxb6XzKeyZiQK69yJpbUDJSF7TgrqNo= -github.com/pion/rtp v1.7.4 h1:4dMbjb1SuynU5OpA3kz1zHK+u+eOCQjW3MAeVHf1ODA= -github.com/pion/rtp v1.7.4/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= +github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA= +github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/transport v0.13.0 h1:KWTA5ZrQogizzYwPEciGtHPLwpAjE91FgXnyu+Hv2uY= github.com/pion/transport v0.13.0/go.mod h1:yxm9uXpK9bpBBWkITk13cLo1y5/ur5VQpG22ny6EP7g= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c h1:WtYZ93XtWSO5KlOMgPZu7hXY9WhMZpprvlm5VwvAl8c= golang.org/x/net v0.0.0-20211201190559-0a0e4e1bb54c/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/key_derivation.go b/key_derivation.go index 26386d4..09bc58b 100644 --- a/key_derivation.go +++ b/key_derivation.go @@ -48,16 +48,19 @@ func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr i // - passing through 65,535 // i = 2^16 * ROC + SEQ // IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16) -func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) [16]byte { - var counter [16]byte +func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) (counter [16]byte) { + copy(counter[:], sessionSalt) - binary.BigEndian.PutUint32(counter[4:], ssrc) - binary.BigEndian.PutUint32(counter[8:], rolloverCounter) - binary.BigEndian.PutUint32(counter[12:], uint32(sequenceNumber)<<16) - - for i := range sessionSalt { - counter[i] ^= sessionSalt[i] - } + counter[4] ^= byte(ssrc >> 24) + counter[5] ^= byte(ssrc >> 16) + counter[6] ^= byte(ssrc >> 8) + counter[7] ^= byte(ssrc) + counter[8] ^= byte(rolloverCounter >> 24) + counter[9] ^= byte(rolloverCounter >> 16) + counter[10] ^= byte(rolloverCounter >> 8) + counter[11] ^= byte(rolloverCounter) + counter[12] ^= byte(sequenceNumber >> 8) + counter[13] ^= byte(sequenceNumber) return counter } diff --git a/key_derivation_test.go b/key_derivation_test.go index 42bf568..a2c7aa5 100644 --- a/key_derivation_test.go +++ b/key_derivation_test.go @@ -46,3 +46,19 @@ func TestIndexOverKDR(t *testing.T) { _, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, []byte{}, []byte{}, 1, 0) assert.Error(t, err) } + +func BenchmarkGenerateCounter(b *testing.B) { + masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + s := &srtpSSRCState{ssrc: 4160032510} + + srtpSessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)) + assert.NoError(b, err) + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) + } +} diff --git a/option.go b/option.go index d6159f1..86ecd8e 100644 --- a/option.go +++ b/option.go @@ -8,7 +8,7 @@ import ( type ContextOption func(*Context) error // SRTPReplayProtection sets SRTP replay protection window size. -func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:golint +func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return replaydetector.WithWrap(windowSize, maxSequenceNumber) @@ -28,7 +28,7 @@ func SRTCPReplayProtection(windowSize uint) ContextOption { } // SRTPNoReplayProtection disables SRTP replay protection. -func SRTPNoReplayProtection() ContextOption { // nolint:golint +func SRTPNoReplayProtection() ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return &nopReplayDetector{} diff --git a/protection_profile.go b/protection_profile.go index 322eca0..71d9ac5 100644 --- a/protection_profile.go +++ b/protection_profile.go @@ -9,14 +9,13 @@ type ProtectionProfile uint16 // See https://www.iana.org/assignments/srtp-protection/srtp-protection.xhtml const ( ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001 + ProtectionProfileAes128CmHmacSha1_32 ProtectionProfile = 0x0002 ProtectionProfileAeadAes128Gcm ProtectionProfile = 0x0007 ) func (p ProtectionProfile) keyLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_80: - fallthrough - case ProtectionProfileAeadAes128Gcm: + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm: return 16, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) @@ -25,7 +24,7 @@ func (p ProtectionProfile) keyLen() (int, error) { func (p ProtectionProfile) saltLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 14, nil case ProtectionProfileAeadAes128Gcm: return 12, nil @@ -34,12 +33,25 @@ func (p ProtectionProfile) saltLen() (int, error) { } } -func (p ProtectionProfile) authTagLen() (int, error) { +func (p ProtectionProfile) rtpAuthTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: - return (&srtpCipherAesCmHmacSha1{}).authTagLen(), nil + return 10, nil + case ProtectionProfileAes128CmHmacSha1_32: + return 4, nil + case ProtectionProfileAeadAes128Gcm: + return 0, nil + default: + return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) + } +} + +func (p ProtectionProfile) rtcpAuthTagLen() (int, error) { + switch p { + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: + return 10, nil case ProtectionProfileAeadAes128Gcm: - return (&srtpCipherAeadAesGcm{}).authTagLen(), nil + return 0, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } @@ -47,10 +59,10 @@ func (p ProtectionProfile) authTagLen() (int, error) { func (p ProtectionProfile) aeadAuthTagLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_80: - return (&srtpCipherAesCmHmacSha1{}).aeadAuthTagLen(), nil + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: + return 0, nil case ProtectionProfileAeadAes128Gcm: - return (&srtpCipherAeadAesGcm{}).aeadAuthTagLen(), nil + return 16, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } @@ -58,7 +70,7 @@ func (p ProtectionProfile) aeadAuthTagLen() (int, error) { func (p ProtectionProfile) authKeyLen() (int, error) { switch p { - case ProtectionProfileAes128CmHmacSha1_80: + case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 20, nil case ProtectionProfileAeadAes128Gcm: return 0, nil diff --git a/session.go b/session.go index 770c203..7cd8c6f 100644 --- a/session.go +++ b/session.go @@ -1,6 +1,7 @@ package srtp import ( + "errors" "io" "net" "sync" @@ -136,7 +137,7 @@ func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remote var i int i, err = s.nextConn.Read(b) if err != nil { - if err != io.EOF { + if !errors.Is(err, io.EOF) { s.log.Error(err.Error()) } return diff --git a/session_srtcp.go b/session_srtcp.go index 2e6e168..8ad13f2 100644 --- a/session_srtcp.go +++ b/session_srtcp.go @@ -114,8 +114,11 @@ func (s *SessionSRTCP) write(buf []byte) (int, error) { return 0, errStartedChannelUsedIncorrectly } + ibuf := bufferpool.Get() + defer bufferpool.Put(ibuf) + s.session.localContextMutex.Lock() - encrypted, err := s.localContext.EncryptRTCP(nil, buf, nil) + encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil) s.session.localContextMutex.Unlock() if err != nil { diff --git a/session_srtcp_test.go b/session_srtcp_test.go index 1fb03f3..b2043d7 100644 --- a/session_srtcp_test.go +++ b/session_srtcp_test.go @@ -2,6 +2,7 @@ package srtp import ( "bytes" + "errors" "io" "net" "reflect" @@ -259,7 +260,7 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { for { if ssrc, perr := getSenderSSRC(t, bReadStream); perr == nil { receivedSSRC = append(receivedSSRC, ssrc) - } else if perr == io.EOF { + } else if errors.Is(perr, io.EOF) { return } } @@ -301,7 +302,7 @@ func TestSessionSRTCPReplayProtection(t *testing.T) { } func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { - authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.authTagLen() + authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.rtcpAuthTagLen() if err != nil { return 0, err } @@ -309,7 +310,7 @@ func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err erro const pliPacketSize = 8 readBuffer := make([]byte, pliPacketSize+authTagSize+srtcpIndexSize) n, _, err := stream.ReadRTCP(readBuffer) - if err == io.EOF { + if errors.Is(err, io.EOF) { return 0, err } if err != nil { diff --git a/session_srtp.go b/session_srtp.go index 2de27a9..66ac060 100644 --- a/session_srtp.go +++ b/session_srtp.go @@ -2,6 +2,7 @@ package srtp import ( "net" + "sync" "time" "github.com/pion/logging" @@ -111,21 +112,41 @@ func (s *SessionSRTP) Close() error { func (s *SessionSRTP) write(b []byte) (int, error) { packet := &rtp.Packet{} - err := packet.Unmarshal(b) - if err != nil { - return 0, nil + if err := packet.Unmarshal(b); err != nil { + return 0, err } return s.writeRTP(&packet.Header, packet.Payload) } +// bufferpool is a global pool of buffers used for encrypted packets in +// writeRTP below. Since it's global, buffers can be shared between +// different sessions, which amortizes the cost of allocating the pool. +// +// 1472 is the maximum Ethernet UDP payload. We give ourselves 20 bytes +// of slack for any authentication tags, which is more than enough for +// either CTR or GCM. If the buffer is too small, no harm, it will just +// get expanded by growBuffer. +var bufferpool = sync.Pool{ // nolint:gochecknoglobals + New: func() interface{} { + return make([]byte, 1492) + }, +} + func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) { if _, ok := <-s.session.started; ok { return 0, errStartedChannelUsedIncorrectly } + // encryptRTP will either return our buffer, or, if it is too + // small, allocate a new buffer itself. In either case, it is + // safe to put the buffer back into the pool, but only after + // nextConn.Write has returned. + ibuf := bufferpool.Get() + defer bufferpool.Put(ibuf) + s.session.localContextMutex.Lock() - encrypted, err := s.localContext.encryptRTP(nil, header, payload) + encrypted, err := s.localContext.encryptRTP(ibuf.([]byte), header, payload) s.session.localContextMutex.Unlock() if err != nil { diff --git a/session_srtp_test.go b/session_srtp_test.go index 6649d27..a3798ba 100644 --- a/session_srtp_test.go +++ b/session_srtp_test.go @@ -2,6 +2,7 @@ package srtp import ( "bytes" + "errors" "io" "net" "reflect" @@ -313,7 +314,7 @@ func TestSessionSRTPReplayProtection(t *testing.T) { for { if seq, perr := assertPayloadSRTP(t, bReadStream, rtpHeaderSize, testPayload); perr == nil { receivedSequenceNumber = append(receivedSequenceNumber, seq) - } else if perr == io.EOF { + } else if errors.Is(perr, io.EOF) { return } } @@ -357,7 +358,7 @@ func TestSessionSRTPReplayProtection(t *testing.T) { func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) { readBuffer := make([]byte, headerSize+len(expectedPayload)) n, hdr, err := stream.ReadRTP(readBuffer) - if err == io.EOF { + if errors.Is(err, io.EOF) { return 0, err } if err != nil { diff --git a/srtcp.go b/srtcp.go index dbf5125..d3e387b 100644 --- a/srtcp.go +++ b/srtcp.go @@ -11,9 +11,18 @@ const maxSRTCPIndex = 0x7FFFFFFF func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { out := allocateIfMismatch(dst, encrypted) - tailOffset := len(encrypted) - (c.cipher.authTagLen() + srtcpIndexSize) - if tailOffset < 0 { + authTagLen, err := c.cipher.rtcpAuthTagLen() + if err != nil { + return nil, err + } + aeadAuthTagLen, err := c.cipher.aeadAuthTagLen() + if err != nil { + return nil, err + } + tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + + if tailOffset < aeadAuthTagLen { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { return out, nil @@ -25,10 +34,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { s := c.getSRTCPSSRCState(ssrc) markAsValid, ok := s.replayDetector.Check(uint64(index)) if !ok { - return nil, &errorDuplicated{Proto: "srtcp", SSRC: ssrc, Index: index} + return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } - out, err := c.cipher.decryptRTCP(out, encrypted, index, ssrc) + out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } diff --git a/srtcp_test.go b/srtcp_test.go index ffd34bf..f2e870d 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -226,7 +226,7 @@ func TestRTCPLifecycleInPlace(t *testing.T) { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) - authTagLen, err := testCase.algo.authTagLen() + authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() @@ -288,11 +288,7 @@ func TestRTCPLifecyclePartialAllocation(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { - if testCase.algo == ProtectionProfileAeadAes128Gcm { - t.Skip("FIXME: DecryptRTCP(nil, input, nil) for ProtectionProfileAeadAes128Gcm changes input data") - } assert := assert.New(t) - encryptHeader := &rtcp.Header{} encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { @@ -338,7 +334,7 @@ func TestRTCPInvalidAuthTag(t *testing.T) { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) - authTagLen, err := testCase.algo.authTagLen() + authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() @@ -420,7 +416,7 @@ func TestEncryptRTCPSeparation(t *testing.T) { encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) assert.NoError(err) - authTagLen, err := testCase.algo.authTagLen() + authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) decryptContext, err := CreateContext( @@ -464,3 +460,22 @@ func TestEncryptRTCPSeparation(t *testing.T) { }) } } + +func TestRTCPDecryptShortenedPacket(t *testing.T) { + for caseName, testCase := range rtcpTestCasesSingle() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + pkt := testCase.packets[0] + for i := 1; i < len(pkt.encrypted)-1; i++ { + packet := pkt.encrypted[:i] + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + assert.NotPanics(t, func() { + _, _ = decryptContext.DecryptRTCP(nil, packet, nil) + }, "Panic on length %d/%d", i, len(pkt.encrypted)) + } + }) + } +} diff --git a/srtp.go b/srtp.go index 2628372..0feaf0f 100644 --- a/srtp.go +++ b/srtp.go @@ -10,21 +10,25 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL markAsValid, ok := s.replayDetector.Check(uint64(header.SequenceNumber)) if !ok { - return nil, &errorDuplicated{ + return nil, &duplicatedError{ Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber), } } - dst = growBufferSize(dst, len(ciphertext)-c.cipher.authTagLen()) - roc, updateROC := s.nextRolloverCount(header.SequenceNumber) + authTagLen, err := c.cipher.rtpAuthTagLen() + if err != nil { + return nil, err + } + dst = growBufferSize(dst, len(ciphertext)-authTagLen) + roc, diff := s.nextRolloverCount(header.SequenceNumber) - dst, err := c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) + dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } markAsValid() - updateROC() + s.updateRolloverCount(header.SequenceNumber, diff) return dst, nil } @@ -63,8 +67,8 @@ func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ( // Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload. func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) { s := c.getSRTPSSRCState(header.SSRC) - roc, updateROC := s.nextRolloverCount(header.SequenceNumber) - updateROC() + roc, diff := s.nextRolloverCount(header.SequenceNumber) + s.updateRolloverCount(header.SequenceNumber, diff) return c.cipher.encryptRTP(dst, header, payload, roc) } diff --git a/srtp_cipher.go b/srtp_cipher.go index 2cdf325..c272310 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -7,10 +7,11 @@ import "github.com/pion/rtp" type srtpCipher interface { // authTagLen returns auth key length of the cipher. // See the note below. - authTagLen() int + rtpAuthTagLen() (int, error) + rtcpAuthTagLen() (int, error) // aeadAuthTagLen returns AEAD auth key length of the cipher. // See the note below. - aeadAuthTagLen() int + aeadAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index d76c404..110b80a 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -13,13 +13,15 @@ const ( ) type srtpCipherAeadAesGcm struct { + ProtectionProfile + srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte } -func newSrtpCipherAeadAesGcm(masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { - s := &srtpCipherAeadAesGcm{} +func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { + s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -60,33 +62,31 @@ func newSrtpCipherAeadAesGcm(masterKey, masterSalt []byte) (*srtpCipherAeadAesGc return s, nil } -func (s *srtpCipherAeadAesGcm) authTagLen() int { - return 0 -} - -func (s *srtpCipherAeadAesGcm) aeadAuthTagLen() int { - return 16 -} - func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+s.aeadAuthTagLen()) + authTagLen, err := s.aeadAuthTagLen() + if err != nil { + return nil, err + } + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) - hdr, err := header.Marshal() + n, err := header.MarshalTo(dst) if err != nil { return nil, err } iv := s.rtpInitializationVector(header, roc) - nHdr := len(hdr) - s.srtpCipher.Seal(dst[nHdr:nHdr], iv[:], payload, hdr) - copy(dst[:nHdr], hdr) + s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Grow the given buffer to fit the output. - nDst := len(ciphertext) - s.aeadAuthTagLen() + authTagLen, err := s.aeadAuthTagLen() + if err != nil { + return nil, err + } + nDst := len(ciphertext) - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. return nil, errFailedToVerifyAuthTag @@ -106,7 +106,11 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He } func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { - aadPos := len(decrypted) + s.aeadAuthTagLen() + authTagLen, err := s.aeadAuthTagLen() + if err != nil { + return nil, err + } + aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. dst = growBufferSize(dst, aadPos+srtcpIndexSize) @@ -123,7 +127,11 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { aadPos := len(encrypted) - srtcpIndexSize // Grow the given buffer to fit the output. - nDst := aadPos - s.aeadAuthTagLen() + authTagLen, err := s.aeadAuthTagLen() + if err != nil { + return nil, err + } + nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. return nil, errFailedToVerifyAuthTag diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index 8ca7317..0e3af50 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -13,6 +13,8 @@ import ( //nolint:gci ) type srtpCipherAesCmHmacSha1 struct { + ProtectionProfile + srtpSessionSalt []byte srtpSessionAuth hash.Hash srtpBlock cipher.Block @@ -22,8 +24,8 @@ type srtpCipherAesCmHmacSha1 struct { srtcpBlock cipher.Block } -func newSrtpCipherAesCmHmacSha1(masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { - s := &srtpCipherAesCmHmacSha1{} +func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { + s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err @@ -44,7 +46,7 @@ func newSrtpCipherAesCmHmacSha1(masterKey, masterSalt []byte) (*srtpCipherAesCmH return nil, err } - authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen() + authKeyLen, err := profile.authKeyLen() if err != nil { return nil, err } @@ -64,17 +66,13 @@ func newSrtpCipherAesCmHmacSha1(masterKey, masterSalt []byte) (*srtpCipherAesCmH return s, nil } -func (s *srtpCipherAesCmHmacSha1) authTagLen() int { - return 10 -} - -func (s *srtpCipherAesCmHmacSha1) aeadAuthTagLen() int { - return 0 -} - func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+s.authTagLen()) + authTagLen, err := s.rtpAuthTagLen() + if err != nil { + return nil, err + } + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) @@ -84,8 +82,9 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay // Encrypt the payload counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) - stream := cipher.NewCTR(s.srtpBlock, counter[:]) - stream.XORKeyStream(dst[n:], payload) + if err = xorBytesCTR(s.srtpBlock, counter[:], dst[n:], payload); err != nil { + return nil, err + } n += len(payload) // Generate the auth tag. @@ -102,8 +101,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Split the auth tag and the cipher text into two parts. - actualTag := ciphertext[len(ciphertext)-s.authTagLen():] - ciphertext = ciphertext[:len(ciphertext)-s.authTagLen()] + authTagLen, err := s.rtpAuthTagLen() + if err != nil { + return nil, err + } + actualTag := ciphertext[len(ciphertext)-authTagLen:] + ciphertext = ciphertext[:len(ciphertext)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) @@ -122,9 +125,10 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp // Decrypt the ciphertext for the payload. counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) - stream := cipher.NewCTR(s.srtpBlock, counter[:]) - stream.XORKeyStream(dst[headerLen:], ciphertext[headerLen:]) - return dst, nil + err = xorBytesCTR( + s.srtpBlock, counter[:], dst[headerLen:], ciphertext[headerLen:], + ) + return dst, err } func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { @@ -132,8 +136,9 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex // Encrypt everything after header counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) - stream := cipher.NewCTR(s.srtcpBlock, counter[:]) - stream.XORKeyStream(dst[8:], dst[8:]) + if err := xorBytesCTR(s.srtcpBlock, counter[:], dst[8:], dst[8:]); err != nil { + return nil, err + } // Add SRTCP Index and set Encryption bit dst = append(dst, make([]byte, 4)...) @@ -148,24 +153,27 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex } func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc uint32) ([]byte, error) { - tailOffset := len(encrypted) - (s.authTagLen() + srtcpIndexSize) + authTagLen, err := s.rtcpAuthTagLen() + if err != nil { + return nil, err + } + tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) out = out[0:tailOffset] - expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-s.authTagLen()]) + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) if err != nil { return nil, err } - actualTag := encrypted[len(encrypted)-s.authTagLen():] + actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { return nil, errFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) - stream := cipher.NewCTR(s.srtcpBlock, counter[:]) - stream.XORKeyStream(out[8:], out[8:]) + err = xorBytesCTR(s.srtcpBlock, counter[:], out[8:], out[8:]) - return out, nil + return out, err } func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([]byte, error) { @@ -198,8 +206,12 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([ return nil, err } - // Truncate the hash to the first 10 bytes. - return s.srtpSessionAuth.Sum(nil)[0:s.authTagLen()], nil + // Truncate the hash to the size indicated by the profile + authTagLen, err := s.rtpAuthTagLen() + if err != nil { + return nil, err + } + return s.srtpSessionAuth.Sum(nil)[0:authTagLen], nil } func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, error) { @@ -219,12 +231,17 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro if _, err := s.srtcpSessionAuth.Write(buf); err != nil { return nil, err } + authTagLen, err := s.rtcpAuthTagLen() + if err != nil { + return nil, err + } - return s.srtcpSessionAuth.Sum(nil)[0:s.authTagLen()], nil + return s.srtcpSessionAuth.Sum(nil)[0:authTagLen], nil } func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { - tailOffset := len(in) - (s.authTagLen() + srtcpIndexSize) + authTagLen, _ := s.rtcpAuthTagLen() + tailOffset := len(in) - (authTagLen + srtcpIndexSize) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } diff --git a/srtp_test.go b/srtp_test.go index 0f9236f..8333b48 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -10,35 +10,53 @@ import ( ) const ( - cipherContextAlgo = ProtectionProfileAes128CmHmacSha1_80 - defaultSsrc = 0 + profileCTR = ProtectionProfileAes128CmHmacSha1_80 + profileGCM = ProtectionProfileAeadAes128Gcm + defaultSsrc = 0 ) type rtpTestCase struct { sequenceNumber uint16 - encrypted []byte + encryptedCTR []byte + encryptedGCM []byte } -func TestKeyLen(t *testing.T) { - keyLen, err := cipherContextAlgo.keyLen() +func (tc rtpTestCase) encrypted(profile ProtectionProfile) []byte { + switch profile { + case profileCTR: + return tc.encryptedCTR + case profileGCM: + return tc.encryptedGCM + default: + panic("unknown profile") + } +} + +func testKeyLen(t *testing.T, profile ProtectionProfile) { + keyLen, err := profile.keyLen() assert.NoError(t, err) - saltLen, err := cipherContextAlgo.saltLen() + saltLen, err := profile.saltLen() assert.NoError(t, err) - if _, err := CreateContext([]byte{}, make([]byte, saltLen), cipherContextAlgo); err == nil { + if _, err := CreateContext([]byte{}, make([]byte, saltLen), profile); err == nil { t.Errorf("CreateContext accepted a 0 length key") } - if _, err := CreateContext(make([]byte, keyLen), []byte{}, cipherContextAlgo); err == nil { + if _, err := CreateContext(make([]byte, keyLen), []byte{}, profile); err == nil { t.Errorf("CreateContext accepted a 0 length salt") } - if _, err := CreateContext(make([]byte, keyLen), make([]byte, saltLen), cipherContextAlgo); err != nil { + if _, err := CreateContext(make([]byte, keyLen), make([]byte, saltLen), profile); err != nil { t.Errorf("CreateContext failed with a valid length key and salt: %v", err) } } +func TestKeyLen(t *testing.T) { + t.Run("CTR", func(t *testing.T) { testKeyLen(t, profileCTR) }) + t.Run("GCM", func(t *testing.T) { testKeyLen(t, profileGCM) }) +} + func TestValidPacketCounter(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} @@ -58,11 +76,11 @@ func TestRolloverCount(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} // Set initial seqnum - roc, update := s.nextRolloverCount(65530) + roc, diff := s.nextRolloverCount(65530) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } - update() + s.updateRolloverCount(65530, diff) // Invalid packets never update ROC _, _ = s.nextRolloverCount(0) @@ -72,73 +90,84 @@ func TestRolloverCount(t *testing.T) { _, _ = s.nextRolloverCount(0) // We rolled over to 0 - roc, update = s.nextRolloverCount(0) + roc, diff = s.nextRolloverCount(0) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } - update() + s.updateRolloverCount(0, diff) - roc, update = s.nextRolloverCount(65530) + roc, diff = s.nextRolloverCount(65530) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } - update() + s.updateRolloverCount(65530, diff) - roc, update = s.nextRolloverCount(5) + roc, diff = s.nextRolloverCount(5) if roc != 1 { t.Errorf("rolloverCounter was not updated when it rolled over initial, to handle out of order") } - update() + s.updateRolloverCount(5, diff) - _, update = s.nextRolloverCount(6) - update() - _, update = s.nextRolloverCount(7) - update() - roc, update = s.nextRolloverCount(8) + _, diff = s.nextRolloverCount(6) + s.updateRolloverCount(6, diff) + _, diff = s.nextRolloverCount(7) + s.updateRolloverCount(7, diff) + roc, diff = s.nextRolloverCount(8) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() + s.updateRolloverCount(8, diff) // valid packets never update ROC - roc, update = s.nextRolloverCount(0x4000) + roc, diff = s.nextRolloverCount(0x4000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(0x8000) + s.updateRolloverCount(0x4000, diff) + roc, diff = s.nextRolloverCount(0x8000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(0xFFFF) + s.updateRolloverCount(0x8000, diff) + roc, diff = s.nextRolloverCount(0xFFFF) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() + s.updateRolloverCount(0xFFFF, diff) roc, _ = s.nextRolloverCount(0) if roc != 2 { t.Errorf("rolloverCounter must be incremented after wrapping, got %d", roc) } } -func buildTestContext(opts ...ContextOption) (*Context, error) { +func buildTestContext(profile ProtectionProfile, opts ...ContextOption) (*Context, error) { + keyLen, err := profile.keyLen() + if err != nil { + return nil, err + } + saltLen, err := profile.saltLen() + if err != nil { + return nil, err + } + masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterKey = masterKey[:keyLen] masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + masterSalt = masterSalt[:saltLen] - return CreateContext(masterKey, masterSalt, cipherContextAlgo, opts...) + return CreateContext(masterKey, masterSalt, profile, opts...) } func TestRTPInvalidAuth(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} invalidSalt := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} - encryptContext, err := buildTestContext() + encryptContext, err := buildTestContext(profileCTR) if err != nil { t.Fatal(err) } - invalidContext, err := CreateContext(masterKey, invalidSalt, cipherContextAlgo) + invalidContext, err := CreateContext(masterKey, invalidSalt, profileCTR) if err != nil { t.Errorf("CreateContext failed: %v", err) } @@ -167,44 +196,50 @@ func rtpTestCases() []rtpTestCase { return []rtpTestCase{ { sequenceNumber: 5000, - encrypted: []byte{0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, 0xfd, 0x8}, + encryptedCTR: []byte{0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, 0xfd, 0x8}, + encryptedGCM: []byte{0x05, 0x39, 0x62, 0xbb, 0x50, 0x2a, 0x08, 0x19, 0xc7, 0xcc, 0xc9, 0x24, 0xb8, 0xd9, 0x7a, 0xe5, 0xad, 0x99, 0x06, 0xc7, 0x3b, 0}, }, { sequenceNumber: 5001, - encrypted: []byte{0xda, 0x47, 0xb, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, 0x4b, 0x6d, 0x23, 0xf3, 0xde}, + encryptedCTR: []byte{0xda, 0x47, 0xb, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, 0x4b, 0x6d, 0x23, 0xf3, 0xde}, + encryptedGCM: []byte{0xb0, 0xbc, 0xfc, 0xb0, 0x15, 0x2c, 0xa0, 0x15, 0xb5, 0xa8, 0xcd, 0x0d, 0x65, 0xfa, 0x98, 0xb3, 0x09, 0xb1, 0xf8, 0x4b, 0x1c, 0xfa}, }, { sequenceNumber: 5002, - encrypted: []byte{0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x2, 0x1c, 0xd1, 0x60, 0x76, 0xc1, 0xe}, + encryptedCTR: []byte{0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x2, 0x1c, 0xd1, 0x60, 0x76, 0xc1, 0xe}, + encryptedGCM: []byte{0x5e, 0x20, 0x6a, 0xbf, 0x58, 0x7e, 0x24, 0xc0, 0x15, 0x94, 0x7a, 0xe2, 0x49, 0x25, 0xd4, 0xd4, 0x08, 0xe2, 0xf1, 0x47, 0x7a, 0x33}, }, { sequenceNumber: 5003, - encrypted: []byte{0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, 0xe0, 0x76, 0x9f, 0x8, 0xdc}, + encryptedCTR: []byte{0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, 0xe0, 0x76, 0x9f, 0x8, 0xdc}, + encryptedGCM: []byte{0xb0, 0x63, 0x14, 0xe7, 0xd2, 0x29, 0xca, 0x92, 0x8c, 0x97, 0x25, 0xd2, 0x50, 0x69, 0x6e, 0x1b, 0x04, 0xb9, 0x37, 0xa5, 0xa1, 0xc5}, }, { sequenceNumber: 5004, - encrypted: []byte{0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, 0x12, 0x68, 0xb, 0x57, 0x49}, + encryptedCTR: []byte{0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, 0x12, 0x68, 0xb, 0x57, 0x49}, + encryptedGCM: []byte{0xb2, 0x4f, 0x19, 0x53, 0x79, 0x8a, 0x9b, 0x9e, 0xe5, 0x22, 0x93, 0x14, 0x50, 0x8a, 0x8c, 0xd5, 0xfc, 0x61, 0xbf, 0x95, 0xd1, 0xfb}, }, { sequenceNumber: 65535, // upper boundary - encrypted: []byte{0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, 0x15, 0xe, 0x44, 0xca, 0x36}, + encryptedCTR: []byte{0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, 0x15, 0xe, 0x44, 0xca, 0x36}, + encryptedGCM: []byte{0x40, 0x44, 0x6c, 0xd1, 0x33, 0x5f, 0xca, 0x9b, 0x2e, 0xa3, 0xe5, 0x03, 0xd7, 0x82, 0x36, 0xd8, 0xb7, 0xe8, 0x97, 0x3c, 0xe6, 0xb6}, }, } } -func TestRTPLifecyleNewAlloc(t *testing.T) { +func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) - authTagLen, err := ProtectionProfileAes128CmHmacSha1_80.authTagLen() + authTagLen, err := profile.rtpAuthTagLen() assert.NoError(err) for _, testCase := range rtpTestCases() { - encryptContext, err := buildTestContext() + encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } - decryptContext, err := buildTestContext() + decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } @@ -215,7 +250,7 @@ func TestRTPLifecyleNewAlloc(t *testing.T) { t.Fatal(err) } - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted, Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) @@ -238,16 +273,21 @@ func TestRTPLifecyleNewAlloc(t *testing.T) { } } -func TestRTPLifecyleInPlace(t *testing.T) { +func TestRTPLifecycleNewAlloc(t *testing.T) { + t.Run("CTR", func(t *testing.T) { testRTPLifecyleNewAlloc(t, profileCTR) }) + t.Run("GCM", func(t *testing.T) { testRTPLifecyleNewAlloc(t, profileGCM) }) +} + +func testRTPLifecyleInPlace(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) for _, testCase := range rtpTestCases() { - encryptContext, err := buildTestContext() + encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } - decryptContext, err := buildTestContext() + decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } @@ -260,14 +300,18 @@ func TestRTPLifecyleInPlace(t *testing.T) { } encryptHeader := &rtp.Header{} - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted, Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } // Copy packet, asserts that everything was done in place - encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+10) + slack := 10 + if profile == profileGCM { + slack = 16 + } + encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+slack) copy(encryptInput, decryptedRaw) actualEncrypted, err := encryptContext.EncryptRTP(encryptInput, encryptInput, encryptHeader) @@ -275,9 +319,9 @@ func TestRTPLifecyleInPlace(t *testing.T) { case err != nil: t.Fatal(err) case &encryptInput[0] != &actualEncrypted[0]: - t.Fatal("EncryptRTP failed to encrypt in place") + t.Errorf("EncryptRTP failed to encrypt in place") case encryptHeader.SequenceNumber != testCase.sequenceNumber: - t.Fatal("EncryptRTP failed to populate input rtp.Header") + t.Errorf("EncryptRTP failed to populate input rtp.Header") } assert.Equalf(actualEncrypted, encryptedRaw, "RTP packet with SeqNum invalid encryption: %d", testCase.sequenceNumber) @@ -290,24 +334,31 @@ func TestRTPLifecyleInPlace(t *testing.T) { case err != nil: t.Fatal(err) case &decryptInput[0] != &actualDecrypted[0]: - t.Fatal("DecryptRTP failed to decrypt in place") + t.Errorf("DecryptRTP failed to decrypt in place") case decryptHeader.SequenceNumber != testCase.sequenceNumber: - t.Fatal("DecryptRTP failed to populate input rtp.Header") + t.Errorf("DecryptRTP failed to populate input rtp.Header") } assert.Equalf(actualDecrypted, decryptedRaw, "RTP packet with SeqNum invalid decryption: %d", testCase.sequenceNumber) } } -func TestRTPReplayProtection(t *testing.T) { +func TestRTPLifecycleInPlace(t *testing.T) { + t.Run("CTR", func(t *testing.T) { testRTPLifecyleInPlace(t, profileCTR) }) + t.Run("GCM", func(t *testing.T) { testRTPLifecyleInPlace(t, profileGCM) }) +} + +func testRTPReplayProtection(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) for _, testCase := range rtpTestCases() { - encryptContext, err := buildTestContext() + encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } - decryptContext, err := buildTestContext(SRTPReplayProtection(64)) + decryptContext, err := buildTestContext( + profile, SRTPReplayProtection(64), + ) if err != nil { t.Fatal(err) } @@ -320,14 +371,18 @@ func TestRTPReplayProtection(t *testing.T) { } encryptHeader := &rtp.Header{} - encryptedPkt := &rtp.Packet{Payload: testCase.encrypted, Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } // Copy packet, asserts that everything was done in place - encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+10) + slack := 10 + if profile == profileGCM { + slack = 16 + } + encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+slack) copy(encryptInput, decryptedRaw) actualEncrypted, err := encryptContext.EncryptRTP(encryptInput, encryptInput, encryptHeader) @@ -335,7 +390,7 @@ func TestRTPReplayProtection(t *testing.T) { case err != nil: t.Fatal(err) case &encryptInput[0] != &actualEncrypted[0]: - t.Fatal("EncryptRTP failed to encrypt in place") + t.Errorf("EncryptRTP failed to encrypt in place") case encryptHeader.SequenceNumber != testCase.sequenceNumber: t.Fatal("EncryptRTP failed to populate input rtp.Header") } @@ -350,9 +405,9 @@ func TestRTPReplayProtection(t *testing.T) { case err != nil: t.Fatal(err) case &decryptInput[0] != &actualDecrypted[0]: - t.Fatal("DecryptRTP failed to decrypt in place") + t.Errorf("DecryptRTP failed to decrypt in place") case decryptHeader.SequenceNumber != testCase.sequenceNumber: - t.Fatal("DecryptRTP failed to populate input rtp.Header") + t.Errorf("DecryptRTP failed to populate input rtp.Header") } assert.Equalf(actualDecrypted, decryptedRaw, "RTP packet with SeqNum invalid decryption: %d", testCase.sequenceNumber) @@ -363,18 +418,24 @@ func TestRTPReplayProtection(t *testing.T) { } } -func BenchmarkEncryptRTP(b *testing.B) { - encryptContext, err := buildTestContext() +func TestRTPReplayProtection(t *testing.T) { + t.Run("CTR", func(t *testing.T) { testRTPReplayProtection(t, profileCTR) }) + t.Run("GCM", func(t *testing.T) { testRTPReplayProtection(t, profileGCM) }) +} + +func benchmarkEncryptRTP(b *testing.B, profile ProtectionProfile, size int) { + encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) } - pkt := &rtp.Packet{Payload: make([]byte, 100)} + pkt := &rtp.Packet{Payload: make([]byte, size)} pktRaw, err := pkt.Marshal() if err != nil { b.Fatal(err) } + b.SetBytes(int64(len(pktRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -385,13 +446,28 @@ func BenchmarkEncryptRTP(b *testing.B) { } } -func BenchmarkEncryptRTPInPlace(b *testing.B) { - encryptContext, err := buildTestContext() +func BenchmarkEncryptRTP(b *testing.B) { + b.Run("CTR-100", func(b *testing.B) { + benchmarkEncryptRTP(b, profileCTR, 100) + }) + b.Run("CTR-1000", func(b *testing.B) { + benchmarkEncryptRTP(b, profileCTR, 1000) + }) + b.Run("GCM-100", func(b *testing.B) { + benchmarkEncryptRTP(b, profileGCM, 100) + }) + b.Run("GCM-1000", func(b *testing.B) { + benchmarkEncryptRTP(b, profileGCM, 1000) + }) +} + +func benchmarkEncryptRTPInPlace(b *testing.B, profile ProtectionProfile, size int) { + encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) } - pkt := &rtp.Packet{Payload: make([]byte, 100)} + pkt := &rtp.Packet{Payload: make([]byte, size)} pktRaw, err := pkt.Marshal() if err != nil { b.Fatal(err) @@ -399,6 +475,7 @@ func BenchmarkEncryptRTPInPlace(b *testing.B) { buf := make([]byte, 0, len(pktRaw)+10) + b.SetBytes(int64(len(pktRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -409,9 +486,24 @@ func BenchmarkEncryptRTPInPlace(b *testing.B) { } } -func BenchmarkDecryptRTP(b *testing.B) { +func BenchmarkEncryptRTPInPlace(b *testing.B) { + b.Run("CTR-100", func(b *testing.B) { + benchmarkEncryptRTPInPlace(b, profileCTR, 100) + }) + b.Run("CTR-1000", func(b *testing.B) { + benchmarkEncryptRTPInPlace(b, profileCTR, 1000) + }) + b.Run("GCM-100", func(b *testing.B) { + benchmarkEncryptRTPInPlace(b, profileGCM, 100) + }) + b.Run("GCM-1000", func(b *testing.B) { + benchmarkEncryptRTPInPlace(b, profileGCM, 1000) + }) +} + +func benchmarkDecryptRTP(b *testing.B, profile ProtectionProfile) { sequenceNumber := uint16(5000) - encrypted := []byte{0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, 0xfd, 0x8} + encrypted := rtpTestCases()[0].encrypted(profile) encryptedPkt := &rtp.Packet{ Payload: encrypted, @@ -425,11 +517,12 @@ func BenchmarkDecryptRTP(b *testing.B) { b.Fatal(err) } - context, err := buildTestContext() + context, err := buildTestContext(profile) if err != nil { b.Fatal(err) } + b.SetBytes(int64(len(encryptedRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -440,59 +533,130 @@ func BenchmarkDecryptRTP(b *testing.B) { } } +func BenchmarkDecryptRTP(b *testing.B) { + b.Run("CTR", func(b *testing.B) { benchmarkDecryptRTP(b, profileCTR) }) + b.Run("GCM", func(b *testing.B) { benchmarkDecryptRTP(b, profileGCM) }) +} + func TestRolloverCount2(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} - roc, update := s.nextRolloverCount(30123) + roc, diff := s.nextRolloverCount(30123) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } - update() + s.updateRolloverCount(30123, diff) - roc, update = s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 + roc, diff = s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } - update() - roc, update = s.nextRolloverCount(204) + s.updateRolloverCount(62892, diff) + roc, diff = s.nextRolloverCount(204) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } - update() - roc, update = s.nextRolloverCount(64535) + s.updateRolloverCount(62892, diff) + roc, diff = s.nextRolloverCount(64535) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } - update() - roc, update = s.nextRolloverCount(205) + s.updateRolloverCount(64535, diff) + roc, diff = s.nextRolloverCount(205) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(1) + s.updateRolloverCount(205, diff) + roc, diff = s.nextRolloverCount(1) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() + s.updateRolloverCount(1, diff) - roc, update = s.nextRolloverCount(64532) + roc, diff = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(65534) + s.updateRolloverCount(64532, diff) + roc, diff = s.nextRolloverCount(65534) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(64532) + s.updateRolloverCount(65534, diff) + roc, diff = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } - update() - roc, update = s.nextRolloverCount(205) + s.updateRolloverCount(65532, diff) + roc, diff = s.nextRolloverCount(205) if roc != 1 { t.Errorf("index was not updated after it crossed 0") } - update() + s.updateRolloverCount(65532, diff) +} + +func TestProtectionProfileAes128CmHmacSha1_32(t *testing.T) { + masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + if err != nil { + t.Fatal(err) + } + + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: 5000}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + decrypted, err := decryptContext.DecryptRTP(nil, out, nil) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(decrypted, pktRaw) { + t.Errorf("Decrypted % 02x does not match original % 02x", decrypted, pktRaw) + } +} + +func TestRTPDecryptShotenedPacket(t *testing.T) { + profiles := map[string]ProtectionProfile{ + "CTR": profileCTR, + "GCM": profileGCM, + } + for name, profile := range profiles { + profile := profile + t.Run(name, func(t *testing.T) { + for _, testCase := range rtpTestCases() { + decryptContext, err := buildTestContext(profile) + if err != nil { + t.Fatal(err) + } + + encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + encryptedRaw, err := encryptedPkt.Marshal() + if err != nil { + t.Fatal(err) + } + + for i := 1; i < len(encryptedRaw)-1; i++ { + packet := encryptedRaw[:i] + assert.NotPanics(t, func() { + _, _ = decryptContext.DecryptRTP(nil, packet, nil) + }, "Panic on length %d/%d", i, len(encryptedRaw)) + } + } + }) + } } diff --git a/stream_srtp_test.go b/stream_srtp_test.go index 1fbe760..6907e92 100644 --- a/stream_srtp_test.go +++ b/stream_srtp_test.go @@ -61,17 +61,26 @@ func TestBufferFactory(t *testing.T) { wg.Wait() } -func BenchmarkWrite(b *testing.B) { +func benchmarkWrite(b *testing.B, profile ProtectionProfile, size int) { conn := newNoopConn() + keyLen, err := profile.keyLen() + if err != nil { + b.Fatal(err) + } + saltLen, err := profile.saltLen() + if err != nil { + b.Fatal(err) + } + config := &Config{ Keys: SessionKeys{ - LocalMasterKey: make([]byte, 16), - LocalMasterSalt: make([]byte, 14), - RemoteMasterKey: make([]byte, 16), - RemoteMasterSalt: make([]byte, 14), + LocalMasterKey: make([]byte, keyLen), + LocalMasterSalt: make([]byte, saltLen), + RemoteMasterKey: make([]byte, keyLen), + RemoteMasterSalt: make([]byte, saltLen), }, - Profile: ProtectionProfileAes128CmHmacSha1_80, + Profile: profile, } session, err := NewSessionSRTP(conn, config) @@ -89,7 +98,7 @@ func BenchmarkWrite(b *testing.B) { Version: 2, SSRC: 322, }, - Payload: make([]byte, 100), + Payload: make([]byte, size), } packetRaw, err := packet.Marshal() @@ -97,6 +106,7 @@ func BenchmarkWrite(b *testing.B) { b.Fatal(err) } + b.SetBytes(int64(len(packetRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -114,19 +124,43 @@ func BenchmarkWrite(b *testing.B) { } } -func BenchmarkWriteRTP(b *testing.B) { +func BenchmarkWrite(b *testing.B) { + b.Run("CTR-100", func(b *testing.B) { + benchmarkWrite(b, profileCTR, 100) + }) + b.Run("CTR-1000", func(b *testing.B) { + benchmarkWrite(b, profileCTR, 1000) + }) + b.Run("GCM-100", func(b *testing.B) { + benchmarkWrite(b, profileGCM, 100) + }) + b.Run("GCM-1000", func(b *testing.B) { + benchmarkWrite(b, profileGCM, 1000) + }) +} + +func benchmarkWriteRTP(b *testing.B, profile ProtectionProfile, size int) { conn := &noopConn{ closed: make(chan struct{}), } + keyLen, err := profile.keyLen() + if err != nil { + b.Fatal(err) + } + saltLen, err := profile.saltLen() + if err != nil { + b.Fatal(err) + } + config := &Config{ Keys: SessionKeys{ - LocalMasterKey: make([]byte, 16), - LocalMasterSalt: make([]byte, 14), - RemoteMasterKey: make([]byte, 16), - RemoteMasterSalt: make([]byte, 14), + LocalMasterKey: make([]byte, keyLen), + LocalMasterSalt: make([]byte, saltLen), + RemoteMasterKey: make([]byte, keyLen), + RemoteMasterSalt: make([]byte, saltLen), }, - Profile: ProtectionProfileAes128CmHmacSha1_80, + Profile: profile, } session, err := NewSessionSRTP(conn, config) @@ -144,8 +178,9 @@ func BenchmarkWriteRTP(b *testing.B) { SSRC: 322, } - payload := make([]byte, 100) + payload := make([]byte, size) + b.SetBytes(int64(header.MarshalSize() + len(payload))) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -162,3 +197,18 @@ func BenchmarkWriteRTP(b *testing.B) { b.Fatal(err) } } + +func BenchmarkWriteRTP(b *testing.B) { + b.Run("CTR-100", func(b *testing.B) { + benchmarkWriteRTP(b, profileCTR, 100) + }) + b.Run("CTR-1000", func(b *testing.B) { + benchmarkWriteRTP(b, profileCTR, 1000) + }) + b.Run("GCM-100", func(b *testing.B) { + benchmarkWriteRTP(b, profileGCM, 100) + }) + b.Run("GCM-1000", func(b *testing.B) { + benchmarkWriteRTP(b, profileGCM, 1000) + }) +}