Skip to content

Commit

Permalink
Merge pull request #4 from linuxfoundation/ems/refactor
Browse files Browse the repository at this point in the history
Refactor code and pipeline for best practices
  • Loading branch information
emsearcy authored Apr 9, 2024
2 parents 4c25ab9 + 3d428be commit 7f3c7d4
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 73 deletions.
3 changes: 3 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Copyright The Linux Foundation and each contributor.
# SPDX-License-Identifier: MIT

.git
.env
/bin/*
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
name: Publish
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.22.x'
- uses: actions/checkout@v4
- uses: ko-build/[email protected]
- run: |
ko build --bare --platform linux/amd64,linux/arm64 -t latest -t ${{ github.sha }} --sbom spdx .
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright The Linux Foundation and its contributors.
# SPDX-License-Identifier: MIT

FROM --platform=$BUILDPLATFORM cgr.dev/chainguard/go AS builder
FROM --platform=$BUILDPLATFORM cgr.dev/chainguard/go:latest AS builder

# Set necessary environment variables needed for our image. Allow building to
# other architectures via cross-compliation build-arg.
Expand Down
9 changes: 5 additions & 4 deletions auth0_clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -91,7 +91,7 @@ func getAuth0Clients(ctx context.Context) ([]auth0ClientStub, error) {

for {
uri := fmt.Sprintf("https://%s.auth0.com/api/v2/clients?%s", cfg.Auth0Tenant, v.Encode())
req, err := http.NewRequestWithContext(ctx, "GET", uri, nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
if err != nil {
return nil, err
}
Expand All @@ -101,7 +101,7 @@ func getAuth0Clients(ctx context.Context) ([]auth0ClientStub, error) {
}
defer resp.Body.Close()

bodyBytes, err := ioutil.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -156,7 +156,8 @@ func getAuth0ClientByService(ctx context.Context, serviceURL string) (*auth0Clie
// Compare against cached globs.
if item, exists := auth0Cache.Get("cas-service-globs"); exists {
for glob, client := range item.(map[string]auth0ClientStub) {
match, err := doublestar.Match(glob, serviceURL)
var match bool
match, err = doublestar.Match(glob, serviceURL)
if err != nil {
appLogger(ctx).WithFields(logrus.Fields{
"pattern": glob,
Expand Down
58 changes: 29 additions & 29 deletions cas.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"encoding/base64"
"encoding/json"
"fmt"
"io/ioutil"
"io"
"net/http"
"net/url"
"path/filepath"
Expand Down Expand Up @@ -49,7 +49,7 @@ func casLogin(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Expires", "Sun, 19 Nov 1978 05:00:00 GMT")
w.Header().Set("X-Content-Type-Options", "nosniff")

if r.Method != "GET" && r.Method != "POST" {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.NotFound(w, r)
return
}
Expand Down Expand Up @@ -136,7 +136,7 @@ func casLogout(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Expires", "Sun, 19 Nov 1978 05:00:00 GMT")
w.Header().Set("X-Content-Type-Options", "nosniff")

if r.Method != "GET" && r.Method != "POST" {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.NotFound(w, r)
return
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func casServiceValidate(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Expires", "Sun, 19 Nov 1978 05:00:00 GMT")
w.Header().Set("X-Content-Type-Options", "nosniff")

if r.Method != "GET" && r.Method != "POST" {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.NotFound(w, r)
return
}
Expand All @@ -180,54 +180,54 @@ func casServiceValidate(w http.ResponseWriter, r *http.Request) {
case formatParam == "JSON":
useJSON = true
case formatParam != "" && formatParam != "XML":
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "invalid format", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "invalid format", useJSON)
return
}

service := params.Get("service")
if service == "" {
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "service parameter is required", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "service parameter is required", useJSON)
return
}

ticket := params.Get("ticket")
if ticket == "" {
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "ticket parameter is required", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "ticket parameter is required", useJSON)
return
}

pgtURL := params.Get("pgtUrl")
if pgtURL != "" {
outputFailure(r.Context(), w, r, nil, "INTERNAL_ERROR", "proxy callbacks not implemented", useJSON)
outputFailure(r.Context(), w, nil, "INTERNAL_ERROR", "proxy callbacks not implemented", useJSON)
return
}

if strings.HasPrefix(ticket, "PT-") {
// We don't issue proxy tickets (/proxy always returns
// UNAUTHORIZED_SERVICE), so any proxy ticket is not recognized.
outputFailure(r.Context(), w, r, nil, "INVALID_TICKET", "ticket not recognized", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_TICKET", "ticket not recognized", useJSON)
return
}

if !strings.HasPrefix(ticket, "ST-") {
outputFailure(r.Context(), w, r, nil, "INVALID_TICKET_SPEC", "invalid ticket spec", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_TICKET_SPEC", "invalid ticket spec", useJSON)
return
}

authCode := strings.TrimPrefix(ticket, "ST-A")
if authCode == ticket {
// Not having an ST-A prefix means the ticket is unknown; see oauth2Callback.
outputFailure(r.Context(), w, r, nil, "INVALID_TICKET", "foreign ticket not recognized", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_TICKET", "foreign ticket not recognized", useJSON)
return
}

casClient, err := getAuth0ClientByService(r.Context(), service)
if err != nil {
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "error looking up service", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "error looking up service", useJSON)
return
}
if casClient == nil {
outputFailure(r.Context(), w, r, nil, "INVALID_SERVICE", "unknown service", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_SERVICE", "unknown service", useJSON)
return
}

Expand All @@ -247,45 +247,45 @@ func casServiceValidate(w http.ResponseWriter, r *http.Request) {
// Rather than decoding the JSON payload, we can assume a 403 means the
// auth code (as provided as a CAS service ticket) was invalid.
appLogger(r.Context()).WithError(err).Debug("auth code exchange 403 response")
outputFailure(r.Context(), w, r, nil, "INVALID_TICKET", "invalid ticket", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_TICKET", "invalid ticket", useJSON)
return
}
}
// Handle any other error (non-403 responses or HTTP errors).
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "error validating ticket", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "error validating ticket", useJSON)
return
}

uri := fmt.Sprintf("https://%s/userinfo", cfg.Auth0Domain)
req, err := http.NewRequestWithContext(r.Context(), "GET", uri, nil)
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, uri, nil)
if err != nil {
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "error creating user profile request", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "error creating user profile request", useJSON)
return
}
token.SetAuthHeader(req)
resp, err := httpClient.Do(req)
if err != nil {
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "error fetching user profile", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "error fetching user profile", useJSON)
return
}
defer resp.Body.Close()

bodyBytes, err := ioutil.ReadAll(resp.Body)
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "error reading user profile response", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "error reading user profile response", useJSON)
return
}

if resp.StatusCode != http.StatusOK {
err := fmt.Errorf("userinfo returned %v: %s", resp.StatusCode, string(bodyBytes))
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "user profile response error", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "user profile response error", useJSON)
return
}

user := new(userAttributes)
err = json.Unmarshal(bodyBytes, user)
if err != nil {
outputFailure(r.Context(), w, r, err, "INTERNAL_ERROR", "user profile parse error", useJSON)
outputFailure(r.Context(), w, err, "INTERNAL_ERROR", "user profile parse error", useJSON)
return
}

Expand Down Expand Up @@ -325,7 +325,7 @@ func casProxy(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Expires", "Sun, 19 Nov 1978 05:00:00 GMT")
w.Header().Set("X-Content-Type-Options", "nosniff")

if r.Method != "GET" && r.Method != "POST" {
if r.Method != http.MethodGet && r.Method != http.MethodPost {
http.NotFound(w, r)
return
}
Expand All @@ -338,24 +338,24 @@ func casProxy(w http.ResponseWriter, r *http.Request) {
case formatParam == "JSON":
useJSON = true
case formatParam != "" && formatParam != "XML":
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "invalid format", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "invalid format", useJSON)
return
}

pgt := params.Get("pgt")
if pgt == "" {
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "pgt parameter is required", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "pgt parameter is required", useJSON)
return
}

targetService := params.Get("targetService")
if targetService == "" {
outputFailure(r.Context(), w, r, nil, "INVALID_REQUEST", "targetService parameter is required", useJSON)
outputFailure(r.Context(), w, nil, "INVALID_REQUEST", "targetService parameter is required", useJSON)
return
}

// Deny all proxy-grant-ticket requests.
outputFailure(r.Context(), w, r, nil, "UNAUTHORIZED_SERVICE", "not authorized for proxy requests", useJSON)
outputFailure(r.Context(), w, nil, "UNAUTHORIZED_SERVICE", "not authorized for proxy requests", useJSON)
}

func oauth2Callback(w http.ResponseWriter, r *http.Request) {
Expand All @@ -364,7 +364,7 @@ func oauth2Callback(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Expires", "Sun, 19 Nov 1978 05:00:00 GMT")
w.Header().Set("X-Content-Type-Options", "nosniff")

if r.Method != "GET" {
if r.Method != http.MethodGet {
http.NotFound(w, r)
return
}
Expand Down Expand Up @@ -515,7 +515,7 @@ func getLogoutParams(ctx context.Context, returnTo string) *url.Values {
// error. This logs the issue, and formats and outputs the response (default
// 200 status code). If the response cannot be formatted, an additional error
// is logged and a plain-text message and 500 response is output.
func outputFailure(ctx context.Context, w http.ResponseWriter, r *http.Request, err error, code, description string, useJSON bool) {
func outputFailure(ctx context.Context, w http.ResponseWriter, err error, code, description string, useJSON bool) {
switch {
case err != nil:
appLogger(ctx).WithError(err).Error(description)
Expand Down
68 changes: 33 additions & 35 deletions check-headers.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,29 @@
# Exits with a 0 if all source files have license headers
# Exits with a 1 if one or more source files are missing a license header

# These are the file patterns we should exclude - these are typically transient files not checked into source control
exclude_pattern='node_modules|.venv|.pytest_cache|.idea'
# Exclude code coming from a third-party. Typically these won't be checked into
# source control, but occassionally "vendored" code is committed.
exclude_pattern='^(.*/)?(node_modules|vendor)/'

# Include build definitions.
filetypes=(Makefile Dockerfile .gitignore .dockerignore)
# Include Golang files.
filetypes+=("*.go" go.mod)
# Include Python files.
filetypes+=("*.py")
# Include HTML, CSS, JS, TS, SCSS.
filetypes+=("*.html" "*.htm" "*.css" "*.ts" "*.js" "*.scss")
# Include shell scripts.
filetypes+=("*.sh" "*.bash" "*.ksh" "*.csh" "*.tcsh" "*.fsh")
# Include text.
filetypes+=("*.txt")
# Include YAML and TOML files.
filetypes+=("*.yaml" "*.yml" "*.toml")
# Include SQL scripts and definitions.
filetypes+=("*.sql")

files=()
echo "Scanning source code..."
# Adjust this filters based on the source files you are interested in checking
# Loads all the filenames into an array
# We need optimize this, possibly use: -name '*.go' -o -name '*.txt' - not working as expected on mac
echo 'Searching *.go|go.mod files...'
files+=($(find . -type f \( -name '*.go' -o -name 'go.mod' \) -print | egrep -v ${exclude_pattern}))
echo "Searching python files..."
files+=($(find . -type f -name '*.py' -print | egrep -v ${exclude_pattern}))
echo "Searching html|css|ts|js files..."
files+=($(find . -type f \( -name '*.html' -o -name '*.css' -o -name '*.ts' -o -name '*.js' -o -name '*.scss' \) -print | egrep -v ${exclude_pattern})) # NOTE There must be a space between the parens and its contents or it won't work.
echo "Searching shell files..."
files+=($(find . -type f \( -name '*.sh' -o -name '*.bash' -o -name '*.ksh' -o -name '*.csh' -o -name '*.tcsh' -o -name '*.fsh' \) -print | egrep -v ${exclude_pattern})) # NOTE There must be a space between the parens and its contents or it won't work.
echo "Searching make files..."
files+=($(find . -type f -name 'Makefile' -print | egrep -v ${exclude_pattern}))
echo "Searching txt files..."
files+=($(find . -type f -name '*.txt' -print | egrep -v ${exclude_pattern}))
echo "Searching yaml|yml files..."
files+=($(find . -type f \( -name '*.yaml' -o -name '*.yml' \) -print | egrep -v ${exclude_pattern})) # NOTE There must be a space between the parens and its contents or it won't work.
files+=($(find . -type f -name '.gitignore' -print | egrep -v ${exclude_pattern}))
echo "Searching SQL files..."
files+=($(find . -type f -name '*.sql' -print | egrep -v ${exclude_pattern}))
while IFS='' read -r line; do files+=("$line"); done < <(git ls-files -c "${filetypes[@]}" | grep -E -v "${exclude_pattern}")

# This is the copyright line to look for - adjust as necessary
copyright_line="Copyright The Linux Foundation"
Expand All @@ -42,24 +40,24 @@ missing_license_header=0
# For each file...
echo "Checking ${#files[@]} source code files for the license header..."
for file in "${files[@]}"; do
# echo "Processing file ${file}..."
# echo "Processing file ${file}..."

# Header is typically one of the first few lines in the file...
head -4 "${file}" | grep -q "${copyright_line}"
# Find it? exit code value of 0 indicates the grep found a match
exit_code=$?
if [[ ${exit_code} -ne 0 ]]; then
echo "${file} is missing the license header"
# update our flag - we'll fail the test
missing_license_header=1
fi
# Header is typically one of the first few lines in the file...
head -4 "${file}" | grep -q "${copyright_line}"
# Find it? exit code value of 0 indicates the grep found a match
exit_code=$?
if [[ ${exit_code} -ne 0 ]]; then
echo "${file} is missing the license header"
# update our flag - we'll fail the test
missing_license_header=1
fi
done

# Summary
if [[ ${missing_license_header} -eq 1 ]]; then
echo "One or more source files is missing the license header."
echo "One or more source files is missing the license header."
else
echo "License check passed."
echo "License check passed."
fi

# Exit with status code 0 = success, 1 = failed
Expand Down
10 changes: 8 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"net/http"
"os"
"strconv"
"time"

"github.com/evalphobia/logrus_fluent"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -113,7 +114,7 @@ func main() {
}

// Support GET/POST monitoring "ping".
http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc("/ping", func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprintf(w, "OK\n")
})

Expand Down Expand Up @@ -156,7 +157,12 @@ func main() {
} else {
addr = *bind + ":" + *port
}
err := http.ListenAndServe(addr, mux)
server := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 3 * time.Second,
}
err := server.ListenAndServe()
if err != nil {
logrus.WithError(err).Fatal("http listener error")
}
Expand Down
Loading

0 comments on commit 7f3c7d4

Please sign in to comment.