Skip to content

Commit

Permalink
fix modulitos' findings (part 1)
Browse files Browse the repository at this point in the history
  • Loading branch information
roehrijn committed Aug 21, 2024
1 parent e278706 commit e3fe4ce
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 54 deletions.
94 changes: 47 additions & 47 deletions pkg/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ type Response struct {
type ServiceAccountCache interface {
Start(stop chan struct{})
Get(name, namespace string) Response
GetOrNotify(name, namespace string, handler chan any) Response
GetOrNotify(name, namespace string, handler chan struct{}) Response
GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64)
// ToJSON returns cache contents as JSON string
ToJSON() string
Expand All @@ -70,7 +70,7 @@ type serviceAccountCache struct {
composeRoleArn ComposeRoleArn
defaultTokenExpiration int64
webhookUsage prometheus.Gauge
notificationHandlers map[string]chan any // type of channel doesn't matter. It's just for being notified
notificationHandlers map[string]chan struct{}
}

type ComposeRoleArn struct {
Expand Down Expand Up @@ -106,32 +106,32 @@ func (c *serviceAccountCache) Get(name, namespace string) Response {
// It will first look at the set of ServiceAccounts configured using annotations. If none is found, it will register
// handler to be notified as soon as a ServiceAccount with given key is populated to the cache. Afterwards it will check
// for a ServiceAccount configured through the pod-identity-webhook ConfigMap.
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan any) Response {
func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan struct{}) Response {
result := Response{
TokenExpiration: pkg.DefaultTokenExpiration,
}
klog.V(5).Infof("Fetching sa %s/%s from cache", namespace, name)
{
resp := c.getSAorNotify(name, namespace, handler)
if resp != nil {
entry := c.getSAorNotify(name, namespace, handler)
if entry != nil {
result.FoundInSACache = true
}
if resp != nil && resp.RoleARN != "" {
result.RoleARN = resp.RoleARN
result.Audience = resp.Audience
result.UseRegionalSTS = resp.UseRegionalSTS
result.TokenExpiration = resp.TokenExpiration
if entry != nil && entry.RoleARN != "" {
result.RoleARN = entry.RoleARN
result.Audience = entry.Audience
result.UseRegionalSTS = entry.UseRegionalSTS
result.TokenExpiration = entry.TokenExpiration
return result
}
}
{
resp := c.getCM(name, namespace)
if resp != nil {
entry := c.getCM(name, namespace)
if entry != nil {
result.FoundInCMCache = true
result.RoleARN = resp.RoleARN
result.Audience = resp.Audience
result.UseRegionalSTS = resp.UseRegionalSTS
result.TokenExpiration = resp.TokenExpiration
result.RoleARN = entry.RoleARN
result.Audience = entry.Audience
result.UseRegionalSTS = entry.UseRegionalSTS
result.TokenExpiration = entry.TokenExpiration
return result
}
}
Expand All @@ -143,34 +143,34 @@ func (c *serviceAccountCache) GetOrNotify(name, namespace string, handler chan a
// The config file for the container credentials does not contain "TokenExpiration" or "UseRegionalSTS". For backward compatibility,
// Use these fields if they are set in the sa annotations or config map.
func (c *serviceAccountCache) GetCommonConfigurations(name, namespace string) (useRegionalSTS bool, tokenExpiration int64) {
if resp := c.getSAorNotify(name, namespace, nil); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
} else if resp := c.getCM(name, namespace); resp != nil {
return resp.UseRegionalSTS, resp.TokenExpiration
if entry := c.getSAorNotify(name, namespace, nil); entry != nil {
return entry.UseRegionalSTS, entry.TokenExpiration
} else if entry := c.getCM(name, namespace); entry != nil {
return entry.UseRegionalSTS, entry.TokenExpiration
}
return false, pkg.DefaultTokenExpiration
}

func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan any) *Entry {
func (c *serviceAccountCache) getSAorNotify(name, namespace string, handler chan struct{}) *Entry {
c.mu.RLock()
defer c.mu.RUnlock()
resp, ok := c.saCache[namespace+"/"+name]
entry, ok := c.saCache[namespace+"/"+name]
if !ok && handler != nil {
klog.V(5).Infof("Service Account %s/%s not found in cache, adding notification handler", namespace, name)
c.notificationHandlers[namespace+"/"+name] = handler
return nil
}
return resp
return entry
}

func (c *serviceAccountCache) getCM(name, namespace string) *Entry {
c.mu.RLock()
defer c.mu.RUnlock()
resp, ok := c.cmCache[namespace+"/"+name]
entry, ok := c.cmCache[namespace+"/"+name]
if !ok {
return nil
}
return resp
return entry
}

func (c *serviceAccountCache) popSA(name, namespace string) {
Expand Down Expand Up @@ -200,7 +200,7 @@ func (c *serviceAccountCache) ToJSON() string {
}

func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
resp := &Entry{}
entry := &Entry{}

arn, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.RoleARNAnnotation]
if ok {
Expand All @@ -214,57 +214,57 @@ func (c *serviceAccountCache) addSA(sa *v1.ServiceAccount) {
} else if !matched {
klog.Warningf("arn is invalid: %s", arn)
}
resp.RoleARN = arn
entry.RoleARN = arn
}

resp.Audience = c.defaultAudience
entry.Audience = c.defaultAudience
if audience, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.AudienceAnnotation]; ok {
resp.Audience = audience
entry.Audience = audience
}

resp.UseRegionalSTS = c.defaultRegionalSTS
entry.UseRegionalSTS = c.defaultRegionalSTS
if useRegionalStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.UseRegionalSTSAnnotation]; ok {
useRegional, err := strconv.ParseBool(useRegionalStr)
if err != nil {
klog.V(4).Infof("Ignoring service account %s/%s invalid value for disable-regional-sts annotation", sa.Namespace, sa.Name)
} else {
resp.UseRegionalSTS = useRegional
entry.UseRegionalSTS = useRegional
}
}

resp.TokenExpiration = c.defaultTokenExpiration
entry.TokenExpiration = c.defaultTokenExpiration
if tokenExpirationStr, ok := sa.Annotations[c.annotationPrefix+"/"+pkg.TokenExpirationAnnotation]; ok {
if tokenExpiration, err := strconv.ParseInt(tokenExpirationStr, 10, 64); err != nil {
klog.V(4).Infof("Found invalid value for token expiration, using %d seconds as default: %v", resp.TokenExpiration, err)
klog.V(4).Infof("Found invalid value for token expiration, using %d seconds as default: %v", entry.TokenExpiration, err)
} else {
resp.TokenExpiration = pkg.ValidateMinTokenExpiration(tokenExpiration)
entry.TokenExpiration = pkg.ValidateMinTokenExpiration(tokenExpiration)
}
}
c.webhookUsage.Set(1)

c.setSA(sa.Name, sa.Namespace, resp)
c.setSA(sa.Name, sa.Namespace, entry)
}

func (c *serviceAccountCache) setSA(name, namespace string, resp *Entry) {
func (c *serviceAccountCache) setSA(name, namespace string, entry *Entry) {
c.mu.Lock()
defer c.mu.Unlock()

key := namespace + "/" + name
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, resp)
c.saCache[namespace+"/"+name] = resp
klog.V(5).Infof("Adding SA %q to SA cache: %+v", key, entry)
c.saCache[key] = entry

if handler, found := c.notificationHandlers[key]; found {
klog.V(5).Infof("Notifying handler for %q", key)
handler <- 1
handler <- struct{}{}
delete(c.notificationHandlers, key)
}
}

func (c *serviceAccountCache) setCM(name, namespace string, resp *Entry) {
func (c *serviceAccountCache) setCM(name, namespace string, entry *Entry) {
c.mu.Lock()
defer c.mu.Unlock()
klog.V(5).Infof("Adding SA %s/%s to CM cache: %+v", namespace, name, resp)
c.cmCache[namespace+"/"+name] = resp
klog.V(5).Infof("Adding SA %s/%s to CM cache: %+v", namespace, name, entry)
c.cmCache[namespace+"/"+name] = entry
}

func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenExpiration int64, saInformer coreinformers.ServiceAccountInformer, cmInformer coreinformers.ConfigMapInformer, composeRoleArn ComposeRoleArn) ServiceAccountCache {
Expand All @@ -286,7 +286,7 @@ func New(defaultAudience, prefix string, defaultRegionalSTS bool, defaultTokenEx
defaultTokenExpiration: defaultTokenExpiration,
hasSynced: hasSynced,
webhookUsage: webhookUsage,
notificationHandlers: map[string]chan any{},
notificationHandlers: map[string]chan struct{}{},
}

saInformer.Informer().AddEventHandler(
Expand Down Expand Up @@ -348,12 +348,12 @@ func (c *serviceAccountCache) populateCacheFromCM(oldCM, newCM *v1.ConfigMap) er
if err != nil {
return fmt.Errorf("failed to unmarshal new config %q: %v", newConfig, err)
}
for key, resp := range sas {
for key, entry := range sas {
parts := strings.Split(key, "/")
if resp.TokenExpiration == 0 {
resp.TokenExpiration = c.defaultTokenExpiration
if entry.TokenExpiration == 0 {
entry.TokenExpiration = c.defaultTokenExpiration
}
c.setCM(parts[1], parts[0], resp)
c.setCM(parts[1], parts[0], entry)
}

if oldCM != nil {
Expand Down
1 change: 0 additions & 1 deletion pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ func TestNonRegionalSTS(t *testing.T) {
t.Fatalf("cache never called addSA: %v", err)
}

//gotRoleArn, gotAudience, useRegionalSTS, gotTokenExpiration, found := cache.Get("default", "default")
resp := cache.Get("default", "default")
assert.True(t, resp.FoundInSACache, "Expected cache entry to be found")
if resp.RoleARN != roleArn {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cache/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (f *FakeServiceAccountCache) Get(name, namespace string) Response {
}

// GetOrNotify gets a service account from the cache
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan any) Response {
func (f *FakeServiceAccountCache) GetOrNotify(name, namespace string, handler chan struct{}) Response {
f.mu.RLock()
defer f.mu.RUnlock()
resp, ok := f.cache[namespace+"/"+name]
Expand Down
2 changes: 1 addition & 1 deletion pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func (m *Modifier) buildPodPatchConfig(pod *corev1.Pod) *podPatchConfig {
}

// Use the STS WebIdentity method if set
handler := make(chan any, 1)
handler := make(chan struct{}, 1)
response := m.Cache.GetOrNotify(pod.Spec.ServiceAccountName, pod.Namespace, handler)
key := pod.Namespace + "/" + pod.Spec.ServiceAccountName
if !response.FoundInSACache && !response.FoundInCMCache && m.saLookupGraceTime > 0 {
Expand Down
8 changes: 4 additions & 4 deletions pkg/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,22 @@ func TestMutatePod(t *testing.T) {
want, _ := json.MarshalIndent(c.response, "", " ")
t.Errorf("Unexpected response. Got \n%s\n wanted \n%s", string(got), string(want))
}
var expectedPatchOps, actualPatchOps []byte
if len(response.Patch) > 0 {
patchOps := make([]patchOperation, 0)
if err := json.Unmarshal(response.Patch, &patchOps); err != nil {
t.Errorf("Failed to unmarshal patch: %v", err)
}
indentedPatchOps, _ := json.MarshalIndent(patchOps, "", " ")
t.Logf("got patch operations: %s", string(indentedPatchOps))
actualPatchOps, _ = json.MarshalIndent(patchOps, "", " ")
}
if len(c.response.Patch) > 0 {
patchOps := make([]patchOperation, 0)
if err := json.Unmarshal(c.response.Patch, &patchOps); err != nil {
t.Errorf("Failed to unmarshal patch: %v", err)
}
indentedPatchOps, _ := json.MarshalIndent(patchOps, "", " ")
t.Logf("wanted patch operations: %s", string(indentedPatchOps))
expectedPatchOps, _ = json.MarshalIndent(patchOps, "", " ")
}
assert.Equal(t, string(expectedPatchOps), string(actualPatchOps))
})
}
}
Expand Down

0 comments on commit e3fe4ce

Please sign in to comment.