Skip to content

Commit

Permalink
feat: default runtimeclass webhook
Browse files Browse the repository at this point in the history
Signed-off-by: Dario Tranchitella <[email protected]>
  • Loading branch information
prometherion committed Aug 17, 2024
1 parent 2ed12d2 commit 4f3e20c
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 40 deletions.
2 changes: 1 addition & 1 deletion api/v1beta2/tenant_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type TenantSpec struct {
// Specifies the allowed RuntimeClasses assigned to the Tenant.
// Capsule assures that all Pods resources created in the Tenant can use only one of the allowed RuntimeClasses.
// Optional.
RuntimeClasses *api.SelectorAllowedListSpec `json:"runtimeClasses,omitempty"`
RuntimeClasses *api.DefaultAllowedListSpec `json:"runtimeClasses,omitempty"`
// Specifies the allowed priorityClasses assigned to the Tenant.
// Capsule assures that all Pods resources created in the Tenant can use only one of the allowed PriorityClasses.
// A default value can be specified, and all the Pod resources created will inherit the declared class.
Expand Down
2 changes: 1 addition & 1 deletion api/v1beta2/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions pkg/webhook/defaults/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,19 @@ func NewPriorityClassError(class string, msg error) error {
func (e PriorityClassError) Error() string {
return fmt.Sprintf("Failed to resolve Priority Class %s: %s", e.priorityClass, e.msg)
}

type RuntimeClassError struct {
runtimeClass string
defaultClass string
}

func NewRuntimeClassError(defaultClass, usedClass string) error {
return &RuntimeClassError{
runtimeClass: usedClass,
defaultClass: defaultClass,
}
}

func (e RuntimeClassError) Error() string {
return fmt.Sprintf("The Runtime Class %s is not allowed, leave an empty value or specify the default one %s", e.runtimeClass, e.defaultClass)
}
105 changes: 74 additions & 31 deletions pkg/webhook/defaults/pods.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,79 +11,122 @@ import (
corev1 "k8s.io/api/core/v1"
schedulev1 "k8s.io/api/scheduling/v1"
"k8s.io/client-go/tools/record"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

capsulev1beta2 "github.com/projectcapsule/capsule/api/v1beta2"
"github.com/projectcapsule/capsule/pkg/api"
"github.com/projectcapsule/capsule/pkg/webhook/utils"
)

func mutatePodDefaults(ctx context.Context, req admission.Request, c client.Client, decoder admission.Decoder, recorder record.EventRecorder, namespace string) *admission.Response {
var err error

pod := &corev1.Pod{}
if err = decoder.Decode(req, pod); err != nil {
var pod corev1.Pod
if err := decoder.Decode(req, &pod); err != nil {
return utils.ErroredResponse(err)
}

pod.SetNamespace(namespace)

var tnt *capsulev1beta2.Tenant
tnt, tErr := utils.TenantByStatusNamespace(ctx, c, pod.Namespace)
if tErr != nil {
return utils.ErroredResponse(tErr)
} else if tnt == nil {
return nil
}

tnt, err = utils.TenantByStatusNamespace(ctx, c, pod.Namespace)
if err != nil {
return utils.ErroredResponse(err)
var err error

pcMutated, pcErr := handlePriorityClassDefault(ctx, c, tnt.Spec.PriorityClasses, &pod)
if pcErr != nil {
return utils.ErroredResponse(pcErr)
} else if pcMutated {
defer func() {
if err == nil {
recorder.Eventf(tnt, corev1.EventTypeNormal, "TenantDefault", "Assigned Tenant default Priority Class %s to %s/%s", tnt.Spec.PriorityClasses.Default, pod.Namespace, pod.Name)
}
}()
}

rcMutated, rcErr := handleRuntimeClassDefault(tnt.Spec.RuntimeClasses, &pod)
if rcErr != nil {
return utils.ErroredResponse(rcErr)
} else if rcMutated {
defer func() {
if err == nil {
recorder.Eventf(tnt, corev1.EventTypeNormal, "TenantDefault", "Assigned Tenant default Runtime Class %s to %s/%s", tnt.Spec.RuntimeClasses.Default, pod.Namespace, pod.Name)
}
}()
}

if tnt == nil {
if !rcMutated && !pcMutated {
return nil
}

allowed := tnt.Spec.PriorityClasses
var marshaled []byte

if marshaled, err = json.Marshal(pod); err != nil {
return utils.ErroredResponse(err)
}

return ptr.To(admission.PatchResponseFromRaw(req.Object.Raw, marshaled))
}

func handleRuntimeClassDefault(allowed *api.DefaultAllowedListSpec, pod *corev1.Pod) (mutated bool, err error) {
if allowed == nil || allowed.Default == "" {
return nil
return false, nil
}

priorityClassPod := pod.Spec.PriorityClassName
runtimeClass := pod.Spec.RuntimeClassName

var mutate bool
if allowed.Default == "" && runtimeClass == nil {
return false, nil
}

if allowed.Default != "" && runtimeClass != nil && *runtimeClass == allowed.Default {
return false, nil
}

if allowed.Default != "" && runtimeClass != nil && *runtimeClass != allowed.Default {
// Should not happen, validation must be happened before
return false, NewRuntimeClassError(allowed.Default, *runtimeClass)
}

pod.Spec.RuntimeClassName = &allowed.Default

return true, nil
}

func handlePriorityClassDefault(ctx context.Context, c client.Client, allowed *api.DefaultAllowedListSpec, pod *corev1.Pod) (mutated bool, err error) {
if allowed == nil || allowed.Default == "" {
return false, nil
}

priorityClassPod := pod.Spec.PriorityClassName

var cpc *schedulev1.PriorityClass
// PriorityClass name is empty, if no GlobalDefault is set and no PriorityClass was given on pod
if len(priorityClassPod) > 0 && priorityClassPod != allowed.Default {
cpc, err = utils.GetPriorityClassByName(ctx, c, priorityClassPod)
// Should not happen, since API already checks if PC present
if err != nil {
response := admission.Denied(NewPriorityClassError(priorityClassPod, err).Error())

return &response
return false, NewPriorityClassError(priorityClassPod, err)
}
} else {
mutate = true
mutated = true
}

if mutate = mutate || (utils.IsDefaultPriorityClass(cpc) && cpc.GetName() != allowed.Default); !mutate {
return nil
if mutated = mutated || (utils.IsDefaultPriorityClass(cpc) && cpc.GetName() != allowed.Default); !mutated {
return false, nil
}

pc, err := utils.GetPriorityClassByName(ctx, c, allowed.Default)
if err != nil {
return utils.ErroredResponse(fmt.Errorf("failed to assign tenant default Priority Class: %w", err))
return false, fmt.Errorf("failed to assign tenant default Priority Class: %w", err)
}

pod.Spec.PreemptionPolicy = pc.PreemptionPolicy
pod.Spec.Priority = &pc.Value
pod.Spec.PriorityClassName = pc.Name
// Marshal Pod
marshaled, err := json.Marshal(pod)
if err != nil {
return utils.ErroredResponse(err)
}

recorder.Eventf(tnt, corev1.EventTypeNormal, "TenantDefault", "Assigned Tenant default Priority Class %s to %s/%s", allowed.Default, pod.Namespace, pod.Name)

response := admission.PatchResponseFromRaw(req.Object.Raw, marshaled)

return &response
return true, nil
}
6 changes: 3 additions & 3 deletions pkg/webhook/pod/runtimeclass_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ import (

type podRuntimeClassForbiddenError struct {
runtimeClassName string
spec api.SelectorAllowedListSpec
spec api.DefaultAllowedListSpec
}

func NewPodRuntimeClassForbidden(runtimeClassName string, spec api.SelectorAllowedListSpec) error {
func NewPodRuntimeClassForbidden(runtimeClassName string, spec api.DefaultAllowedListSpec) error {
return &podRuntimeClassForbiddenError{
runtimeClassName: runtimeClassName,
spec: spec,
Expand All @@ -25,5 +25,5 @@ func NewPodRuntimeClassForbidden(runtimeClassName string, spec api.SelectorAllow
func (f podRuntimeClassForbiddenError) Error() (err string) {
err = fmt.Sprintf("Pod Runtime Class %s is forbidden for the current Tenant: ", f.runtimeClassName)

return utils.AllowedValuesErrorMessage(f.spec, err)
return utils.DefaultAllowedValuesErrorMessage(f.spec, err)
}
4 changes: 0 additions & 4 deletions pkg/webhook/utils/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ func ErroredResponse(err error) *admission.Response {
}

func DefaultAllowedValuesErrorMessage(allowed api.DefaultAllowedListSpec, err string) string {
return AllowedValuesErrorMessage(allowed.SelectorAllowedListSpec, err)
}

func AllowedValuesErrorMessage(allowed api.SelectorAllowedListSpec, err string) string {
var extra []string
if len(allowed.Exact) > 0 {
extra = append(extra, fmt.Sprintf("use one from the following list (%s)", strings.Join(allowed.Exact, ", ")))
Expand Down

0 comments on commit 4f3e20c

Please sign in to comment.