diff --git a/treatment-service/models/storage.go b/treatment-service/models/storage.go index 52d0529..90ad674 100644 --- a/treatment-service/models/storage.go +++ b/treatment-service/models/storage.go @@ -254,21 +254,22 @@ func (i *ExperimentIndex) checkSegmentHasWeakMatch(segmentName string) bool { } func (s *LocalStorage) InsertProjectSettings(projectSettings *pubsub.ProjectSettings) error { - s.Lock() - defer s.Unlock() - // check that settings with the same Id doesn't exist - for _, existingSettings := range s.ProjectSettings { - if existingSettings.GetProjectId() == projectSettings.GetProjectId() { - return nil - } + existingProjectSettings := s.findProjectSettingsById(ProjectId(projectSettings.GetProjectId())) + if existingProjectSettings != nil { + return nil } - s.ProjectSettings = append(s.ProjectSettings, projectSettings) // Update project segmenters on creation - if err := s.initProjectSegmenters([]*pubsub.ProjectSettings{projectSettings}); err != nil { + newSegmenters, err := s.fetchProjectSegmenters([]*pubsub.ProjectSettings{projectSettings}) + if err != nil { return err } + + s.Lock() + defer s.Unlock() + s.ProjectSegmenters = newSegmenters + s.ProjectSettings = append(s.ProjectSettings, projectSettings) return nil } @@ -284,6 +285,21 @@ func (s *LocalStorage) UpdateProjectSettings(updatedProjectSettings *pubsub.Proj } func (s *LocalStorage) FindProjectSettingsWithId(projectId ProjectId) *pubsub.ProjectSettings { + projectSettings := s.findSubscribedProjectSettingsById(projectId) + if projectSettings != nil { + return projectSettings + } + + // In case new project was just created and we are subscribed to its ID + // we'll try to retrieve it from management service + projectSettings, err := s.fetchProjectSettingsWithId(projectId) + if err != nil { + return nil + } + return projectSettings +} + +func (s *LocalStorage) findSubscribedProjectSettingsById(projectId ProjectId) *pubsub.ProjectSettings { s.RLock() defer s.RUnlock() @@ -291,23 +307,33 @@ func (s *LocalStorage) FindProjectSettingsWithId(projectId ProjectId) *pubsub.Pr return nil } + return s.findProjectSettingsById(projectId) +} + +func (s *LocalStorage) findProjectSettingsById(projectId ProjectId) *pubsub.ProjectSettings { + s.RLock() + defer s.RUnlock() + for _, settings := range s.ProjectSettings { if ProjectId(settings.ProjectId) == projectId { return settings } } + return nil +} - // In case new project was just created and we are subscribed to its ID - // we'll try to retrieve it from management service +func (s *LocalStorage) fetchProjectSettingsWithId(projectId ProjectId) (*pubsub.ProjectSettings, error) { projectSettingsResponse, err := s.managementClient.GetProjectSettingsWithResponse( context.Background(), int64(projectId)) if err != nil { - return nil + return nil, err } project := OpenAPIProjectSettingsSpecToProtobuf(projectSettingsResponse.JSON200.Data) + s.Lock() + defer s.Unlock() s.ProjectSettings = append(s.ProjectSettings, project) - return project + return project, nil } func (s *LocalStorage) GetSegmentersTypeMapping(projectId ProjectId) (map[string]schema.SegmenterType, error) { @@ -322,10 +348,10 @@ func (s *LocalStorage) GetSegmentersTypeMapping(projectId ProjectId) (map[string } func (s *LocalStorage) FindExperiments(projectId ProjectId, filters []SegmentFilter) []*ExperimentMatch { - experiments := s.Experiments[projectId] s.RLock() defer s.RUnlock() + experiments := s.Experiments[projectId] var matched = make([]*ExperimentMatch, 0) for _, item := range experiments { @@ -356,10 +382,10 @@ func (s *LocalStorage) FindExperiments(projectId ProjectId, filters []SegmentFil } func (s *LocalStorage) FindExperimentWithId(projectId ProjectId, experimentId int64) *pubsub.Experiment { - currentExperiments, settingsExist := s.Experiments[projectId] - s.RLock() defer s.RUnlock() + + currentExperiments, settingsExist := s.Experiments[projectId] if !settingsExist { return nil } @@ -487,9 +513,6 @@ func (s *LocalStorage) DumpExperiments(filepath string) error { } func (s *LocalStorage) Init() error { - s.Lock() - defer s.Unlock() - var subscribedProjectSettings []*pubsub.ProjectSettings var err error if len(s.subscribedProjectIds) > 0 { @@ -504,18 +527,23 @@ func (s *LocalStorage) Init() error { if len(s.subscribedProjectIds) > 0 && len(subscribedProjectSettings) != len(s.subscribedProjectIds) { return errors.New("not all subscribed project ids are found") } - s.ProjectSettings = subscribedProjectSettings - err = s.initProjectSegmenters(subscribedProjectSettings) + newSegmenters, err := s.fetchProjectSegmenters(subscribedProjectSettings) if err != nil { return err } - err = s.initExperiments(subscribedProjectSettings) + newExperiments, err := s.fetchExperiments(subscribedProjectSettings, newSegmenters) if err != nil { return err } + s.Lock() + defer s.Unlock() + s.ProjectSegmenters = newSegmenters + s.Experiments = newExperiments + s.ProjectSettings = subscribedProjectSettings + return nil } @@ -597,7 +625,10 @@ func NewLocalStorage( return &s, err } -func (s *LocalStorage) initExperiments(subscribedProjectSettings []*pubsub.ProjectSettings) error { +func (s *LocalStorage) fetchExperiments( + subscribedProjectSettings []*pubsub.ProjectSettings, + projectSegmenters map[ProjectId]map[string]schema.SegmenterType, +) (map[ProjectId][]*ExperimentIndex, error) { log.Println("retrieving project experiments...") index := make(map[ProjectId][]*ExperimentIndex) for _, projectSettings := range subscribedProjectSettings { @@ -607,14 +638,14 @@ func (s *LocalStorage) initExperiments(subscribedProjectSettings []*pubsub.Proje endTime := time.Now().Add(855360 * time.Hour) activeStatus := schema.ExperimentStatusActive - segmentersType := s.ProjectSegmenters[projectId] + segmentersType := projectSegmenters[projectId] resp, err := s.managementClient.ListExperimentsWithResponse( context.TODO(), projectSettings.ProjectId, &managementClient.ListExperimentsParams{StartTime: &startTime, EndTime: &endTime, Status: &activeStatus}, ) if err != nil { - return err + return nil, err } if resp.StatusCode() == 200 { @@ -622,7 +653,7 @@ func (s *LocalStorage) initExperiments(subscribedProjectSettings []*pubsub.Proje index[projectId] = make([]*ExperimentIndex, 0) index, err = flattenProjectExperiments(projectId, index, projectExperiments, segmentersType) if err != nil { - return err + return nil, err } var pages int @@ -637,24 +668,24 @@ func (s *LocalStorage) initExperiments(subscribedProjectSettings []*pubsub.Proje &managementClient.ListExperimentsParams{Page: &page, StartTime: &startTime, EndTime: &endTime, Status: &activeStatus}, ) if err != nil { - return err + return nil, err } if resp.StatusCode() == 200 { projectExperiments := resp.JSON200.Data index, err = flattenProjectExperiments(projectId, index, projectExperiments, segmentersType) if err != nil { - return err + return nil, err } } } } } - s.Experiments = index - return nil + return index, nil } -func (s *LocalStorage) initProjectSegmenters(settings []*pubsub.ProjectSettings) error { +func (s *LocalStorage) fetchProjectSegmenters(settings []*pubsub.ProjectSettings) (map[ProjectId]map[string]schema.SegmenterType, error) { + projectSegmenters := make(map[uint32]map[string]schema.SegmenterType) for _, projectSettings := range settings { log.Printf("retrieving project segmenters for %d", projectSettings.ProjectId) segmentersResp, err := s.managementClient.ListSegmentersWithResponse( @@ -663,15 +694,16 @@ func (s *LocalStorage) initProjectSegmenters(settings []*pubsub.ProjectSettings) &managementClient.ListSegmentersParams{}, ) if err != nil { - return err + return nil, err } segmenters := map[string]schema.SegmenterType{} for _, v := range segmentersResp.JSON200.Data { segmenters[v.Name] = schema.SegmenterType(strings.ToLower(string(v.Type))) } - s.ProjectSegmenters[ProjectId(projectSettings.ProjectId)] = segmenters + projectSegmenters[ProjectId(projectSettings.ProjectId)] = segmenters } - return nil + + return projectSegmenters, nil } func (s *LocalStorage) UpdateProjectSegmenters(segmenter *_segmenters.SegmenterConfiguration, projectId int64) {