Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Authorization Callback Function #108

Merged
merged 11 commits into from
Oct 21, 2024
28 changes: 26 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package knox

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path"
"reflect"
"runtime"
"sync/atomic"
"testing"
)
Expand Down Expand Up @@ -75,7 +78,7 @@ func buildServer(code int, body []byte, a func(r *http.Request)) *httptest.Serve
}))
}

func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byte) *httptest.Server {
func buildConcurrentServer(code int, a func(r *http.Request) []byte) *httptest.Server {
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := a(r)
w.WriteHeader(code)
Expand All @@ -84,6 +87,23 @@ func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byt
}))
}

func isKnoxDaemonRunning() bool {
if runtime.GOOS != "linux" {
return false
}

cmd := exec.Command("systemctl", "is-active", "--quiet", "knox")

var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err == nil {
return true
}

return false
}

func TestGetKey(t *testing.T) {
expected := Key{
ID: "testkey",
Expand Down Expand Up @@ -357,7 +377,7 @@ func TestPutAccess(t *testing.T) {

func TestConcurrentDeletes(t *testing.T) {
var ops uint64
srv := buildConcurrentServer(200, t, func(r *http.Request) []byte {
srv := buildConcurrentServer(200, func(r *http.Request) []byte {
if r.Method != "DELETE" {
t.Fatalf("%s is not DELETE", r.Method)
}
Expand Down Expand Up @@ -511,6 +531,10 @@ func TestGetInvalidKeys(t *testing.T) {
}

func TestNewFileClient(t *testing.T) {
if isKnoxDaemonRunning() {
t.Skip("Knox daemon is running, skipping the test.")
}

_, err := NewFileClient("ThisKeyDoesNotExistSoWeExpectAnError")
if (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exit status 1") && (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exec: \"knox\": executable file not found in $PATH") {
t.Fatal("Unexpected error", err.Error())
Expand Down
24 changes: 24 additions & 0 deletions knox.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,14 @@ type Principal interface {
CanAccess(ACL, AccessType) bool
GetID() string
Type() string
Raw() []RawPrincipal
}

// RawPrincipal is a serializable version of a principal for passing to
// access callbacks.
type RawPrincipal struct {
ID string `json:"id"`
Type string `json:"type"`
}

// PrincipalMux provides a Principal Interface over multiple Principals.
Expand Down Expand Up @@ -564,6 +572,15 @@ func (p PrincipalMux) Default() Principal {
return p.defaultPrincipal
}

// Raw returns the raw version of all the principals.
func (p PrincipalMux) Raw() []RawPrincipal {
raw := []RawPrincipal{}
for _, principal := range p.allPrincipals {
raw = append(raw, principal.Raw()...)
}
return raw
}

// NewPrincipalMux returns a Principal that represents many principals.
func NewPrincipalMux(defaultPrincipal Principal, allPrincipals map[string]Principal) Principal {
return PrincipalMux{
Expand Down Expand Up @@ -599,3 +616,10 @@ type Response struct {
Message string `json:"message"`
Data interface{} `json:"data"`
}

// AccessCallbackInput is the input to the access callback function.
type AccessCallbackInput struct {
Key Key `json:"key"`
Principals []RawPrincipal `json:"principals"`
AccessType AccessType `json:"access_type"`
}
7 changes: 7 additions & 0 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ func AddDefaultAccess(a *knox.Access) {
defaultAccess = append(defaultAccess, *a)
}

var accessCallback func(knox.AccessCallbackInput) (bool, error)
krockpot marked this conversation as resolved.
Show resolved Hide resolved

henryluo marked this conversation as resolved.
Show resolved Hide resolved
// SetAccessCallback adds a callback.
func SetAccessCallback(callback func(knox.AccessCallbackInput) (bool, error)) {
accessCallback = callback
}

// Extra validators to apply on principals submitted to Knox.
var extraPrincipalValidators []knox.PrincipalValidator

Expand Down
25 changes: 25 additions & 0 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ func additionalMockHandler(m KeyManager, principal knox.Principal, parameters ma
return "The meaning of life is 42", nil
}

func mockAccessCallback(input knox.AccessCallbackInput) (bool, error) {
return true, nil
}

func mockRoute() Route {
return Route{
Method: "GET",
Expand Down Expand Up @@ -105,6 +109,27 @@ func TestAddDefaultAccess(t *testing.T) {

}

func TestSetAccessCallback(t *testing.T) {
defer SetAccessCallback(nil)

SetAccessCallback(mockAccessCallback)

input := knox.AccessCallbackInput{}

if accessCallback == nil {
t.Fatal("accessCallback should not be nil")
}

canAccess, err := accessCallback(input)
if err != nil {
t.Fatal("accessCallback should not return an error")
}

if !canAccess {
t.Fatal("accessCallback should return true")
}
}

func TestParseFormParameter(t *testing.T) {
p := PostParameter("key")

Expand Down
27 changes: 27 additions & 0 deletions server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,15 @@ func (u user) GetID() string {
return u.ID
}

func (u user) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: u.GetID(),
Type: u.Type(),
},
}
}

// Type returns the underlying type of a principal, for logging/debugging purposes.
func (u user) Type() string {
return "user"
Expand Down Expand Up @@ -415,6 +424,15 @@ func (m machine) Type() string {
return "machine"
}

func (m machine) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: m.GetID(),
Type: m.Type(),
},
}
}

// CanAccess determines if a Machine can access an object represented by the ACL
// with a certain AccessType. It compares Machine hostname and hostname prefix.
func (m machine) CanAccess(acl knox.ACL, t knox.AccessType) bool {
Expand Down Expand Up @@ -450,6 +468,15 @@ func (s service) Type() string {
return "service"
}

func (s service) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: s.GetID(),
Type: s.Type(),
},
}
}

// CanAccess determines if a Service can access an object represented by the ACL
// with a certain AccessType. It compares Service id and id prefix.
func (s service) CanAccess(acl knox.ACL, t knox.AccessType) bool {
Expand Down
59 changes: 54 additions & 5 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"

"github.com/pinterest/knox"
"github.com/pinterest/knox/log"
"github.com/pinterest/knox/server/auth"
)

Expand Down Expand Up @@ -211,9 +212,15 @@ func getKeyHandler(m KeyManager, principal knox.Principal, parameters map[string
}

// Authorize access to data
if !principal.CanAccess(key.ACL, knox.Read) {
authorized, authzErr := authorizeRequest(key, principal, knox.Read)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to read %s", principal.GetID(), keyID))
}

// Zero ACL for key response, in order to avoid caching unnecessarily
key.ACL = knox.ACL{}
return key, nil
Expand All @@ -234,7 +241,12 @@ func deleteKeyHandler(m KeyManager, principal knox.Principal, parameters map[str
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Admin) {
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to delete %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -314,7 +326,12 @@ func putAccessHandler(m KeyManager, principal knox.Principal, parameters map[str
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Admin) {
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to update access for %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -371,7 +388,12 @@ func postVersionHandler(m KeyManager, principal knox.Principal, parameters map[s
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Write) {
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -428,7 +450,12 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Write) {
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
}

Expand All @@ -445,3 +472,25 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
return nil, errF(knox.InternalServerErrorCode, err.Error())
}
}

func authorizeRequest(key *knox.Key, principal knox.Principal, access knox.AccessType) (allow bool, err error) {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered from panic in access callback: %v", r)

err = fmt.Errorf("Recovered from panic in access callback: %v", r)
}
}()

allow = principal.CanAccess(key.ACL, access)

if !allow && accessCallback != nil {
allow, err = accessCallback(knox.AccessCallbackInput{
Key: *key,
Principals: principal.Raw(),
AccessType: access,
})
}

return
}
Loading
Loading