diff --git a/common/client.go b/common/client.go index 2d33241..dc13b80 100644 --- a/common/client.go +++ b/common/client.go @@ -29,6 +29,7 @@ import ( "net/http" "net/url" "reflect" + "strings" "time" ) @@ -154,6 +155,82 @@ func (rc *RestClient) ListHosts() ([]Host, error) { return hostList, err } +// FindOne is a convenience function, which queries the appropriate service +// and retrieves one entity based on provided structure, and puts the results +// into the same structure. The provided argument, entity, should be a pointer +// to the desired structure, e.g., &common.Host{}. +func (rc *RestClient) FindOne(entity interface{}) error { + structType := reflect.TypeOf(entity).Elem() + entityName := structType.String() + entityDottedNames := strings.Split(entityName, ".") + if len(entityDottedNames) > 1 { + entityName = entityDottedNames[1] + } + entityName = strings.ToLower(entityName) + var serviceName string + switch entityName { + case "tenant": + serviceName = "tenant" + case "segment": + serviceName = "tenant" + case "host": + serviceName = "topology" + default: + return NewError("Do not know where to find entity '%s'", entityName) + } + + svcURL, err := rc.GetServiceUrl(serviceName) + if err != nil { + return err + } + if !strings.HasSuffix(svcURL, "/") { + svcURL += "/" + } + svcURL += "findOne/" + entityName + "s?" + queryString := "" + structValue := reflect.ValueOf(entity).Elem() + + for i := 0; i < structType.NumField(); i++ { + structField := structType.Field(i) + fieldTag := structField.Tag + fieldName := structField.Name + queryStringFieldName := strings.ToLower(fieldName) + omitEmpty := false + if fieldTag != "" { + jTag := fieldTag.Get("json") + if jTag != "" { + jTagElts := strings.Split(jTag, ",") + // This takes care of ",omitempty" + if len(jTagElts) > 1 { + queryStringFieldName = jTagElts[0] + for _, jTag2 := range jTagElts { + if jTag2 == "omitempty" { + omitEmpty = true + break + } // if jTag2 + } // for / jTagElts + } else { + queryStringFieldName = jTag + } + } // if jTag + } // if fieldTag + fieldValue := structValue.Field(i).Interface() + if omitEmpty && IsZeroValue(fieldValue) { + log.Printf("Skipping field %s: %v - empty", fieldName, fieldValue) + continue + } + + if queryString != "" { + queryString += "&" + } + + queryString += fmt.Sprintf("%s=%v", queryStringFieldName, fieldValue) + } + url := svcURL + queryString + log.Printf("Trying to find one %s at %s", entityName, url) + return rc.Get(url, entity) +} // func + // GetServiceUrl is a convenience function, which, given the root // service URL and name of desired service, returns the URL of that service. func (rc *RestClient) GetServiceUrl(name string) (string, error) { diff --git a/common/defs.go b/common/defs.go index d1cb622..4ee2358 100644 --- a/common/defs.go +++ b/common/defs.go @@ -28,7 +28,7 @@ import ( func String(i interface{}) string { j, e := json.Marshal(i) if e != nil { - return fmt.Sprintf("%#v", i) + return fmt.Sprintf("%+v", i) } return string(j) } @@ -92,8 +92,8 @@ type TokenMessage struct { // } // part of the response. type LinkResponse struct { - Href string - Rel string + Href string `json:"href,omitempty"` + Rel string `json:"rel,omitempty"` } // Type definitions @@ -105,11 +105,11 @@ type ServiceMessage string // about the host. type Host struct { ID uint64 `sql:"AUTO_INCREMENT" json:"id,omitempty"` - Name string `json:"name"` - Ip string `json:"ip"` - RomanaIp string `json:"romana_ip"` + Name string `json:"name,omitempty"` + Ip string `json:"ip,omitempty"` + RomanaIp string `json:"romana_ip,omitempty"` AgentPort uint64 `json:"agent_port,omitempty"` - Links Links `json:"links" sql:"-"` + Links Links `json:"links,omitempty" sql:"-"` } // Message to register with the root service the actual @@ -157,7 +157,7 @@ func (p PortRange) String() string { // 4. If Protocol specified is "icmp", Ports and PortRanges fields should be blank. // 5. If Protocol specified is not "icmp", Icmptype and IcmpCode should be unspecified. type Rule struct { - Protocol string `json:"protocol"` + Protocol string `json:"protocol,omitempty"` Ports []uint `json:"ports,omitempty"` PortRanges []PortRange `json:"port_ranges,omitempty"` // IcmpType only applies if Protocol value is ICMP and @@ -193,7 +193,7 @@ type Policy struct { Name string `json:"name"` // ID is Romana-generated unique (within Romana deployment) ID of this policy, // to be used in REST requests. It will be ignored when set by user. - ID uint64 `json:"id,omitempty",sql:"AUTO_INCREMENT"` + ID uint64 `json:"id,omitempty" sql:"AUTO_INCREMENT"` // ExternalID is an optional identifier of this policy in an external system working // with Romana in this deployment (e.g., Open Stack). ExternalID string `json:"external_id,omitempty",sql:"not null;unique"` @@ -211,7 +211,7 @@ func (p Policy) String() string { // isValidProto checks if the Protocol specified in Rule is valid. // The following protocols are recognized: -// - any -- wildcard +// - any -- see Wildcard // - tcp // - udp // - icmp @@ -419,14 +419,14 @@ type Datacenter struct { // We don't need to store this, but calculate and pass around Prefix uint64 `json:"prefix"` Cidr string `json:"cidr,omitempty"` - PrefixBits uint `json:"prefix_bits"` - PortBits uint `json:"port_bits"` - TenantBits uint `json:"tenant_bits"` - SegmentBits uint `json:"segment_bits"` + PrefixBits uint `json:"prefix_bits,omitempty"` + PortBits uint `json:"port_bits,omitempty"` + TenantBits uint `json:"tenant_bits,omitempty"` + SegmentBits uint `json:"segment_bits,omitempty"` // We don't need to store this, but calculate and pass around - EndpointBits uint `json:"endpoint_bits"` - EndpointSpaceBits uint `json:"endpoint_space_bits"` - Name string `json:"name"` + EndpointBits uint `json:"endpoint_bits,omitempty"` + EndpointSpaceBits uint `json:"endpoint_space_bits,omitempty"` + Name string `json:"name,omitempty"` } func (dc Datacenter) String() string { diff --git a/common/helpers.go b/common/helpers.go index 5916424..64b6194 100644 --- a/common/helpers.go +++ b/common/helpers.go @@ -21,6 +21,7 @@ import ( "fmt" "log" "os" + "reflect" "strings" "sync" ) @@ -46,6 +47,24 @@ func initEnviron() { } } +// IsZeroValue checks whether the provided value is equal to the +// zero value for the type. Zero values would be: +// - 0 for numeric types +// - "" for strings +// - uninitialized struct for a struct +// - zero-size for a slice or a map +func IsZeroValue(val interface{}) bool { + valType := reflect.TypeOf(val) + valKind := valType.Kind() + if valKind == reflect.Slice || valKind == reflect.Map { + valVal := reflect.ValueOf(val) + return valVal.Len() == 0 + } + zeroVal := reflect.Zero(valType).Interface() + log.Printf("Zero value of %+v (type %T, kind %s) is %+v", val, val, valKind, zeroVal) + return val == zeroVal +} + // CleanURL is similar to path.Clean() but to work on URLs func CleanURL(url string) (string, error) { elements := strings.Split(url, "/") diff --git a/common/middleware.go b/common/middleware.go index f521324..72a3f86 100644 --- a/common/middleware.go +++ b/common/middleware.go @@ -131,7 +131,7 @@ func write500(writer http.ResponseWriter, m Marshaller, err error) { httpErr := NewError500(err) // Should never error out - it's a struct we know. outData, _ := m.Marshal(httpErr) - log.Printf("Made\n\t%#v\n\tfrom\n\t%#v\n\t%s", httpErr, err, string(outData)) + log.Printf("Made\n\t%+v\n\tfrom\n\t%+v\n\t%s", httpErr, err, string(outData)) writer.Write(outData) } diff --git a/common/store.go b/common/store.go index a92d848..635f138 100644 --- a/common/store.go +++ b/common/store.go @@ -162,7 +162,7 @@ func (dbStore *DbStore) Find(query url.Values, entities interface{}, single bool log.Printf("For %s, query string field %s, struct field %s, DB field %s", t, queryStringField, fieldName, dbField) queryStringFieldToDbField[queryStringField] = dbField } - log.Printf("%#v", queryStringFieldToDbField) + log.Printf("%+v", queryStringFieldToDbField) whereMap := make(map[string]interface{}) for k, v := range query { @@ -177,7 +177,7 @@ func (dbStore *DbStore) Find(query url.Values, entities interface{}, single bool whereMap[dbFieldName] = v[0] } - log.Printf("Querying with %#v - %T", whereMap, entities) + log.Printf("Querying with %+v - %T", whereMap, entities) db := dbStore.Db.Where(whereMap).Find(entities) err := GetDbErrors(db) @@ -187,14 +187,14 @@ func (dbStore *DbStore) Find(query url.Values, entities interface{}, single bool rowCount := reflect.ValueOf(entities).Elem().Len() if rowCount == 0 { - return nil, NewError404(t.String(), fmt.Sprintf("%#v", whereMap)) + return nil, NewError404(t.String(), fmt.Sprintf("%+v", whereMap)) } if single { if rowCount == 1 { return reflect.ValueOf(entities).Elem().Index(0).Interface(), nil } else { - return nil, NewError500(fmt.Sprintf("Multiple results found for %#v", query)) + return nil, NewError500(fmt.Sprintf("Multiple results found for %+v", query)) } } diff --git a/ipam/ipam.go b/ipam/ipam.go index be3f83b..59fd67a 100644 --- a/ipam/ipam.go +++ b/ipam/ipam.go @@ -16,7 +16,6 @@ package ipam import ( - "errors" "fmt" "github.com/romana/core/common" "github.com/romana/core/tenant" @@ -77,35 +76,25 @@ func (ipam *IPAM) allocateIP(input interface{}, ctx common.RestContext) (interfa // 2. Kubernetes (CNI Plugin) // https://github.com/romana/kube/blob/master/CNI/romana#L134 // IP=$(curl -s "http://$ROMANA_MASTER_IP:9601/allocateIP?tenantName=${tenant}&segmentName=${segment}&hostName=${node}" | get_json_kv | get_ip) - - tenantParam := "" - tenantLookupField := "" - + ten := &tenant.Tenant{} if tenantID := ctx.QueryVariables.Get("tenantID"); tenantID != "" { // This is how IPAM plugin driver calls us. - tenantParam = tenantID - tenantLookupField = "external_id" + ten.ExternalID = tenantID } else if tenantName := ctx.QueryVariables.Get("tenantName"); tenantName != "" { - // This is how CNI plugin calls us. - tenantParam = tenantName - tenantLookupField = "name" + ten.Name = tenantName + } else { + return nil, common.NewError400("Either tenantID or tenantName must be specified.") } - // check for missing/empty required parameters - if tenantParam == "" { - err := errors.New("Missing or empty tenantName/tenantID parameter") - log.Printf("IPAM encountered an error: %v", err) - return nil, err - } segmentName := ctx.QueryVariables.Get("segmentName") if segmentName == "" { - err := errors.New("Missing or empty segmentName parameter") + err := common.NewError400("Missing or empty segmentName parameter") log.Printf("IPAM encountered an error: %v", err) return nil, err } hostName := ctx.QueryVariables.Get("hostName") if hostName == "" { - err := errors.New("Missing or empty hostName parameter") + err := common.NewError400("Missing or empty hostName parameter") log.Printf("IPAM encountered an error: %v", err) return nil, err } @@ -121,48 +110,31 @@ func (ipam *IPAM) allocateIP(input interface{}, ctx common.RestContext) (interfa log.Printf("IPAM encountered an error: %v", err) return nil, err } - // Get host info from topology service - topoUrl, err := client.GetServiceUrl("topology") - if err != nil { - log.Printf("IPAM encountered an error: %v", err) - return nil, err - } - hostsUrl := fmt.Sprintf("%s/findOne/hosts?name=%s", topoUrl, hostName) - host := common.Host{} - err = client.Get(hostsUrl, &host) + host := &common.Host{} + host.Name = hostName + err = client.FindOne(host) if err != nil { - log.Printf("IPAM encountered an error finding tenants: %v", err) + log.Printf("IPAM encountered an error finding host for name %s %v", hostName, err) return nil, err } endpoint.HostId = fmt.Sprintf("%d", host.ID) log.Printf("Host name %s has ID %s", hostName, endpoint.HostId) - tenantSvcUrl, err := client.GetServiceUrl("tenant") - if err != nil { - log.Printf("IPAM encountered an error: %v", err) - return nil, err - } - - // TODO follow links once tenant service supports it. For now... - tenantsUrl := fmt.Sprintf("%s/findOne/tenants?%s=%s", tenantSvcUrl, tenantLookupField, tenantParam) - ten := tenant.Tenant{} - err = client.Get(tenantsUrl, &ten) + err = client.FindOne(ten) if err != nil { - log.Printf("IPAM encountered an error finding tenants: %v", err) + log.Printf("IPAM encountered an error finding tenants %+v: %v", ten, err) return nil, err } endpoint.TenantID = fmt.Sprintf("%d", ten.ID) - log.Printf("IPAM: Tenant '%s' has ID %s, original %d", tenantParam, endpoint.TenantID, ten.ID) - - segmentsUrl := fmt.Sprintf("%s/findOne/segments?tenant_id=%s&name=%s", tenantSvcUrl, endpoint.TenantID, segmentName) - segment := tenant.Segment{} - err = client.Get(segmentsUrl, &segment) + seg := &tenant.Segment{Name: segmentName, TenantID: ten.ID} + err = client.FindOne(seg) if err != nil { - log.Printf("IPAM encountered an error finding segments: %v", err) + log.Printf("IPAM encountered an error finding segments: %+v: %v", seg, err) return nil, err } - endpoint.SegmentID = fmt.Sprintf("%d", segment.ID) + + endpoint.SegmentID = fmt.Sprintf("%d", seg.ID) log.Printf("Segment name %s has ID %s", segmentName, endpoint.SegmentID) return ipam.addEndpoint(&endpoint, ctx) } @@ -323,7 +295,7 @@ func (ipam *IPAM) Initialize() error { dcURL := index.Links.FindByRel("datacenter") dc := common.Datacenter{} - log.Printf("IPAM received datacenter information from topology service: %#v\n", dc) + log.Printf("IPAM received datacenter information from topology service: %+v\n", dc) err = client.Get(dcURL, &dc) if err != nil { return err diff --git a/policy/policy.go b/policy/policy.go index d004425..1e23500 100644 --- a/policy/policy.go +++ b/policy/policy.go @@ -142,7 +142,7 @@ func (policy *PolicySvc) augmentEndpoint(endpoint *common.Endpoint) error { } endpoint.SegmentNetworkID = &segment.Seq } else if endpoint.SegmentExternalID != "" || endpoint.SegmentName != "" { - segmentsUrl := fmt.Sprintf("%s/findOne/segments?tenant_id=%d&", tenantSvcUrl, *endpoint.TenantNetworkID) + segmentsUrl := fmt.Sprintf("%s/findOne/segments?tenant_id=%d&", tenantSvcUrl, ten.ID) if endpoint.SegmentExternalID != "" { segmentsUrl += "external_id=" + endpoint.TenantExternalID + "&" } @@ -223,7 +223,7 @@ func (policy *PolicySvc) distributePolicy(policyDoc *common.Policy) error { url := fmt.Sprintf("http://%s:%d/policies", host.Ip, host.AgentPort) log.Printf("Sending policy %s to agent at %s", policyDoc.Name, url) result := make(map[string]interface{}) - err = policy.client.Post(url, policyDoc, result) + err = policy.client.Post(url, policyDoc, &result) log.Printf("Agent at %s returned %v", host.Ip, result) if err != nil { errStr = append(errStr, fmt.Sprintf("Error applying policy %d to host %s: %v. ", policyDoc.ID, host.Ip, err)) diff --git a/romana/kubernetes/kubernetes_listener_test.go b/romana/kubernetes/kubernetes_listener_test.go index 3e35e71..6983565 100644 --- a/romana/kubernetes/kubernetes_listener_test.go +++ b/romana/kubernetes/kubernetes_listener_test.go @@ -116,7 +116,7 @@ func (s *mockSvc) Routes() common.Routes { case *common.Policy: return input, nil default: - return nil, common.NewError("Expected common.Policy, got %#v", input) + return nil, common.NewError("Expected common.Policy, got %+v", input) } }, MakeMessage: func() interface{} { return &common.Policy{} }, @@ -151,7 +151,7 @@ func (s *mockSvc) Routes() common.Routes { s.tenants[s.tenantCounter] = newTenant.ExternalID s.tenantsStr[newTenant.ExternalID] = s.tenantCounter newTenant.ID = s.tenantCounter - log.Printf("In tenantAddRoute\n\t%#v\n\t%#v", s.tenants, s.tenantsStr) + log.Printf("In tenantAddRoute\n\t%+v\n\t%+v", s.tenants, s.tenantsStr) return newTenant, nil }, @@ -162,7 +162,7 @@ func (s *mockSvc) Routes() common.Routes { Method: "GET", Pattern: "/tenants/{tenantID}", Handler: func(input interface{}, ctx common.RestContext) (interface{}, error) { - log.Printf("In tenantGetRoute\n\t%#v\n\t%#v", s.tenants, s.tenantsStr) + log.Printf("In tenantGetRoute\n\t%+v\n\t%+v", s.tenants, s.tenantsStr) idStr := ctx.PathVariables["tenantID"] id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { @@ -197,7 +197,7 @@ func (s *mockSvc) Routes() common.Routes { s.segments[s.segmentCounter] = newSegment.ExternalID s.segmentsStr[newSegment.ExternalID] = s.segmentCounter newSegment.ID = s.segmentCounter - log.Printf("In segmentAddRoute\n\t%#v\n\t%#v", s.segments, s.segmentsStr) + log.Printf("In segmentAddRoute\n\t%+v\n\t%+v", s.segments, s.segmentsStr) return newSegment, nil }, MakeMessage: func() interface{} { return &tenant.Segment{} }, @@ -207,7 +207,7 @@ func (s *mockSvc) Routes() common.Routes { Method: "GET", Pattern: "/tenants/{tenantID}/segments/{segmentID}", Handler: func(input interface{}, ctx common.RestContext) (interface{}, error) { - log.Printf("In segmentGetRoute\n\t%#v\n\t%#v", s.segments, s.segmentsStr) + log.Printf("In segmentGetRoute\n\t%+v\n\t%+v", s.segments, s.segmentsStr) idStr := ctx.PathVariables["segmentID"] id, err := strconv.ParseUint(idStr, 10, 64) if err != nil { @@ -263,7 +263,7 @@ func (s *mockSvc) Routes() common.Routes { Method: "POST", Pattern: "/config/kubernetes-listener/port", Handler: func(input interface{}, ctx common.RestContext) (interface{}, error) { - log.Printf("Received %#v", input) + log.Printf("Received %+v", input) return "OK", nil }, } @@ -278,7 +278,7 @@ func (s *mockSvc) Routes() common.Routes { kubeListenerConfigRoute, registerPortRoute, } - log.Printf("mockService: Set up routes: %#v", routes) + log.Printf("mockService: Set up routes: %+v", routes) return routes } @@ -299,7 +299,7 @@ func (s *MySuite) getKubeListenerServiceConfig() *common.ServiceConfig { kubeListenerConfig["segment_label_name"] = "tier" svcConfig := common.ServiceConfig{Common: commonConfig, ServiceSpecific: kubeListenerConfig} - log.Printf("Test: Returning KubernetesListener config %#v", svcConfig.ServiceSpecific) + log.Printf("Test: Returning KubernetesListener config %+v", svcConfig.ServiceSpecific) return &svcConfig } @@ -340,7 +340,7 @@ func (s *MySuite) TestListener(c *check.C) { log.Printf("Test: Kubernetes listening on %s (%s)", s.kubeURL, svcInfo.Address) cfg := &common.ServiceConfig{Common: common.CommonConfig{Api: &common.Api{Port: 0, RestTimeoutMillis: 100}}} - log.Printf("Test: Mock service config:\n\t%#v\n\t%#v\n", cfg.Common.Api, cfg.ServiceSpecific) + log.Printf("Test: Mock service config:\n\t%+v\n\t%+v\n", cfg.Common.Api, cfg.ServiceSpecific) svc := &mockSvc{mySuite: s} svc.tenants = make(map[uint64]string) svc.tenantsStr = make(map[string]uint64) diff --git a/romana/kubernetes/listener.go b/romana/kubernetes/listener.go index 42cd8ee..455acc2 100644 --- a/romana/kubernetes/listener.go +++ b/romana/kubernetes/listener.go @@ -133,22 +133,36 @@ func Run(rootServiceURL string, cred *common.Credential) (*common.RestServiceInf // getOrAddSegment finds a segment (based on segment selector). // If not found, it adds one. -func (l *kubeListener) getOrAddSegment(tenantServiceURL string, namespace string, kubeSegmentID string) (*tenant.Segment, error) { - segment := &tenant.Segment{} - segmentsURL := fmt.Sprintf("%s/tenants/%s/segments", tenantServiceURL, namespace) - err := l.restClient.Get(fmt.Sprintf("%s/%s", segmentsURL, kubeSegmentID), segment) +func (l *kubeListener) getOrAddSegment(tenantServiceURL string, namespace string, kubeSegmentName string) (*tenant.Segment, error) { + ten := &tenant.Tenant{} + ten.Name = namespace + err := l.restClient.FindOne(ten) + if err != nil { + return nil, err + } + + seg := &tenant.Segment{} + seg.Name = kubeSegmentName + seg.TenantID = ten.ID + err = l.restClient.FindOne(ten) if err == nil { - return segment, nil + return seg, nil } + switch err := err.(type) { case common.HttpError: if err.StatusCode == http.StatusNotFound { // Not found, so let's create a segment. - segreq := tenant.Segment{Name: kubeSegmentID, ExternalID: kubeSegmentID} - err2 := l.restClient.Post(segmentsURL, segreq, segment) + segreq := tenant.Segment{Name: kubeSegmentName, TenantID: ten.ID} + segURL, err2 := l.restClient.GetServiceUrl("tenant") + if err2 != nil { + return nil, err2 + } + segURL = fmt.Sprintf("%s/tenants/%d/segments", segURL, ten.ID) + err2 = l.restClient.Post(segURL, segreq, seg) if err2 == nil { // Successful creation. - return segment, nil + return seg, nil } // Creation of non-existing segment gave an error. switch err2 := err2.(type) { @@ -207,7 +221,7 @@ func (l *kubeListener) translateNetworkPolicy(kubePolicy *KubeObject) (common.Po ns := kubePolicy.Metadata.Namespace // TODO actually look up tenant K8S ID. t, tenantURL, err := l.resolveTenantByName(ns) - log.Printf("translateNetworkPolicy(): For namespace %s got %#v / %#v", ns, t, err) + log.Printf("translateNetworkPolicy(): For namespace %s got %+v / %+v", ns, t, err) if err != nil { return *romanaPolicy, err } diff --git a/romana/kubernetes/resources.go b/romana/kubernetes/resources.go index b8042d7..cfe7897 100644 --- a/romana/kubernetes/resources.go +++ b/romana/kubernetes/resources.go @@ -135,18 +135,18 @@ func (e Event) handleNamespaceEvent(l *kubeListener) { log.Printf("Processing namespace event == %v and phase %v", e.Type, e.Object.Status) if e.Type == KubeEventAdded { - tenantReq := tenant.Tenant{Name: e.Object.Metadata.Name, ExternalID: e.Object.Metadata.Name} + tenantReq := tenant.Tenant{Name: e.Object.Metadata.Name, ExternalID: e.Object.Metadata.Uid} tenantResp := tenant.Tenant{} - log.Printf("processor: Posting to /tenants: %#v", tenantReq) + log.Printf("processor: Posting to /tenants: %+v", tenantReq) tenantUrl, err := l.restClient.GetServiceUrl("tenant") if err != nil { - log.Printf("Error adding tenant %s: %#v", tenantReq.Name, err) + log.Printf("Error adding tenant %s: %+v", tenantReq.Name, err) } else { err := l.restClient.Post(fmt.Sprintf("%s/tenants", tenantUrl), tenantReq, &tenantResp) if err != nil { - log.Printf("Error adding tenant %s: %#v", tenantReq.Name, err) + log.Printf("Error adding tenant %s: %+v", tenantReq.Name, err) } else { - log.Printf("Added tenant: %#v", tenantResp) + log.Printf("Added tenant: %+v", tenantResp) } } } else { @@ -155,9 +155,9 @@ func (e Event) handleNamespaceEvent(l *kubeListener) { // tenantResp := tenant.Tenant{} // err = client.Delete("/tenants", tenantReq, &tenantResp) // if err != nil { - // log.Printf("Error adding tenant %s: %#v", tenantReq.Name, err) + // log.Printf("Error adding tenant %s: %+v", tenantReq.Name, err) // } else { - // log.Printf("Added tenant: %#v", tenantResp) + // log.Printf("Added tenant: %+v", tenantResp) // } } @@ -197,8 +197,8 @@ func CreateDefaultPolicy(o KubeObject, l *kubeListener) { Direction: common.PolicyDirectionIngress, Name: policyName, AppliedTo: []common.Endpoint{{TenantNetworkID: &tenant.Seq}}, - Peers: []common.Endpoint{{Peer: "any"}}, - Rules: []common.Rule{{Protocol: "any"}}, + Peers: []common.Endpoint{{Peer: common.Wildcard}}, + Rules: []common.Rule{{Protocol: common.Wildcard}}, } log.Printf("In CreateDefaultPolicy with policy %v\n", romanaPolicy) diff --git a/tenant/store.go b/tenant/store.go index 3d75aaf..4e482ec 100644 --- a/tenant/store.go +++ b/tenant/store.go @@ -16,6 +16,7 @@ package tenant import ( + "fmt" "github.com/romana/core/common" "log" ) @@ -37,19 +38,19 @@ func (tenantStore *tenantStore) Entities() []interface{} { } type Tenant struct { - ID uint64 `sql:"AUTO_INCREMENT"` - ExternalID string `sql:"not null;unique" json:"external_id,omitempty" gorm:"COLUMN:external_id"` - Name string `json:"name"` - Segments []Segment - Seq uint64 + ID uint64 `sql:"AUTO_INCREMENT" json:"id,omitempty"` + ExternalID string `sql:"not null;unique" json:"external_id,omitempty" gorm:"COLUMN:external_id"` + Name string `json:"name,omitempty"` + Segments []Segment `json:"segments,omitempty"` + Seq uint64 `json:"seq,omitempty"` } type Segment struct { - ID uint64 `sql:"AUTO_INCREMENT"` + ID uint64 `sql:"AUTO_INCREMENT" json:"id,omitempty"` ExternalID string `sql:"not null;" json:"external_id,omitempty" gorm:"COLUMN:external_id"` - TenantID uint64 `gorm:"COLUMN:tenant_id" json:"tenant_id"` - Name string `json:"name"` - Seq uint64 + TenantID uint64 `gorm:"COLUMN:tenant_id" json:"tenant_id,omitempty"` + Name string `json:"name,omitempty"` + Seq uint64 `json:"seq,omitempty"` } func (tenantStore *tenantStore) listTenants() ([]Tenant, error) { @@ -69,7 +70,7 @@ func (tenantStore *tenantStore) listTenants() ([]Tenant, error) { func (tenantStore *tenantStore) listSegments(tenantId string) ([]Segment, error) { var segments []Segment db := tenantStore.DbStore.Db.Joins("JOIN tenants ON segments.tenant_id = tenants.id"). - Where("tenants.id = ? OR tenants.external_id = ?", tenantId, tenantId). + Where("tenants.id = ?", tenantId, tenantId). Find(&segments) err := common.MakeMultiError(db.GetErrors()) log.Printf("In listSegments(): %v, %v", segments, err) @@ -86,100 +87,77 @@ func (tenantStore *tenantStore) addTenant(tenant *Tenant) error { log.Println("In tenantStore addTenant().") var tenants []Tenant - tenantStore.DbStore.Db.Find(&tenants) + tx := tenantStore.DbStore.Db.Begin() + tx.Find(&tenants) tenant.Seq = uint64(len(tenants)) - db := tenantStore.DbStore.Db - tenantStore.DbStore.Db.Create(tenant) - if db.Error != nil { - return db.Error - } - tenantStore.DbStore.Db.NewRecord(*tenant) - err := common.MakeMultiError(db.GetErrors()) + + tx.Create(tenant) + err := common.GetDbErrors(tx) if err != nil { + tx.Rollback() return err } - if db.Error != nil { - return db.Error - } + tx.Commit() return nil } -func (tenantStore *tenantStore) findTenants(id string) ([]Tenant, error) { - var tenants []Tenant - log.Println("In findTenant()") - db := tenantStore.DbStore.Db.Where("id = ? OR external_id = ?", id, id).Find(&tenants) - err := common.MakeMultiError(db.GetErrors()) - if err != nil { - return nil, err - } - if db.Error != nil { - return nil, db.Error - } - return tenants, nil -} - -func (tenantStore *tenantStore) findTenantsByName(name string) ([]Tenant, error) { - var tenants []Tenant - log.Println("In findTenant()") - db := tenantStore.DbStore.Db.Find(&tenants).Where("name = ?", name) - err := common.MakeMultiError(db.GetErrors()) - if err != nil { - return nil, err - } - if db.Error != nil { - return nil, db.Error - } - if len(tenants) == 0 { - return nil, common.NewError404("tenant", name) - } - return tenants, nil -} - func (tenantStore *tenantStore) addSegment(tenantId uint64, segment *Segment) error { var err error - + tx := tenantStore.DbStore.Db.Begin() // TODO(gg): better way of getting sequence var segments []Segment - db := tenantStore.DbStore.Db.Where("tenant_id = ?", tenantId).Find(&segments) - err = common.MakeMultiError(db.GetErrors()) + db := tx.Where("tenant_id = ?", tenantId).Find(&segments) + err = common.GetDbErrors(tx) if err != nil { + tx.Rollback() return err } - if db.Error != nil { - return db.Error - } segment.Seq = uint64(len(segments)) + segment.TenantID = tenantId if segment.ExternalID == "" { segment.ExternalID = segment.Name } - - tenantStore.DbStore.Db.NewRecord(*segment) - err = common.MakeMultiError(db.GetErrors()) + tx = tx.Create(segment) + err = common.GetDbErrors(db) if err != nil { + tx.Rollback() return err } + tx.Commit() + return nil +} - if db.Error != nil { - return db.Error +func (tenantStore *tenantStore) getTenant(id string) (Tenant, error) { + ten := Tenant{} + var count int + log.Println("In getTenant()") + db := tenantStore.DbStore.Db.Where("id = ?", id).First(&ten).Count(&count) + err := common.GetDbErrors(db) + if err != nil { + return ten, err } - - segment.TenantID = tenantId - db = tenantStore.DbStore.Db.Create(segment) - if db.Error != nil { - return db.Error + if count == 0 { + return ten, common.NewError404("tenant", id) } + return ten, nil +} - err = common.MakeMultiError(db.GetErrors()) +func (tenantStore *tenantStore) getSegment(tenantId string, segmentId string) (Segment, error) { + seg := Segment{} + var count int + db := tenantStore.DbStore.Db.Where("tenant_id = ? AND id = ?", tenantId, segmentId). + First(&seg).Count(&count) + + err := common.GetDbErrors(db) if err != nil { - return err + return seg, err } - - if db.Error != nil { - return db.Error + if count == 0 { + return seg, common.NewError404("segment/tenant", fmt.Sprintf("%s/%s", tenantId, segmentId)) } - return nil + return seg, nil } // CreateSchemaPostProcess implements CreateSchemaPostProcess method of @@ -194,21 +172,3 @@ func (tenantStore *tenantStore) CreateSchemaPostProcess() error { } return nil } - -func (tenantStore *tenantStore) findSegments(tenantId string, segmentId string) ([]Segment, error) { - var segments []Segment - log.Println("In findSegment()") - // TODO should internal ID take precedence? - db := tenantStore.DbStore.Db.Joins("JOIN tenants ON segments.tenant_id = tenants.id"). - Where("(tenants.id = ? OR tenants.external_id = ?) AND (segments.id = ? OR segments.external_id = ?)", tenantId, tenantId, segmentId, segmentId). - Find(&segments) - err := common.MakeMultiError(db.GetErrors()) - if err != nil { - return nil, err - } - - if db.Error != nil { - return nil, db.Error - } - return segments, nil -} diff --git a/tenant/tenant.go b/tenant/tenant.go index c78aa5d..a53d32b 100644 --- a/tenant/tenant.go +++ b/tenant/tenant.go @@ -16,7 +16,6 @@ package tenant import ( - "fmt" "log" "strconv" @@ -121,32 +120,7 @@ func (tsvc *TenantSvc) listSegments(input interface{}, ctx common.RestContext) ( func (tsvc *TenantSvc) getTenant(input interface{}, ctx common.RestContext) (interface{}, error) { idStr := ctx.PathVariables["tenantId"] log.Printf("In findTenant(%s)\n", idStr) - tenants, err := tsvc.store.findTenants(idStr) - if err != nil { - return nil, err - } - if len(tenants) == 0 { - return nil, common.NewError404("tenant", idStr) - } - if len(tenants) > 1 { - return nil, common.NewError500(fmt.Sprintf("More than one tenant matches %s: %v", idStr, tenants)) - } - return tenants[0], nil -} - -func (tsvc *TenantSvc) findTenantsByName(input interface{}, ctx common.RestContext) (interface{}, error) { - nameArr := ctx.QueryVariables[tenantNameQueryVar] - if len(nameArr) != 1 { - return nil, common.NewError("Expected exactly one value in %s, got %v", tenantNameQueryVar, nameArr) - } - nameStr := nameArr[0] - log.Printf("In findTenant(%s)\n", nameStr) - - tenants, err := tsvc.store.findTenantsByName(nameStr) - if err != nil { - return nil, err - } - return tenants, nil + return tsvc.store.getTenant(idStr) } func (tsvc *TenantSvc) addSegment(input interface{}, ctx common.RestContext) (interface{}, error) { @@ -170,17 +144,7 @@ func (tsvc *TenantSvc) getSegment(input interface{}, ctx common.RestContext) (in tenantIdStr := ctx.PathVariables["tenantId"] segmentIdStr := ctx.PathVariables["segmentId"] - segments, err := tsvc.store.findSegments(tenantIdStr, segmentIdStr) - if err != nil { - return nil, err - } - if len(segments) == 0 { - return nil, common.NewError404("segment", segmentIdStr) - } - if len(segments) > 1 { - return nil, common.NewError500(fmt.Sprintf("More than one segment matches %s: %v", segmentIdStr, segments)) - } - return segments[0], nil + return tsvc.store.getSegment(tenantIdStr, segmentIdStr) } // SetConfig implements SetConfig function of the Service interface. diff --git a/topology/topology.go b/topology/topology.go index 7027b19..2b477ce 100644 --- a/topology/topology.go +++ b/topology/topology.go @@ -204,7 +204,7 @@ func (topology *TopologySvc) SetConfig(config common.ServiceConfig) error { // if err != nil { // return err // } - log.Printf("Datacenter information: was %s, decoded to %#v\n", dcMap, dc) + log.Printf("Datacenter information: was %s, decoded to %+v\n", dcMap, dc) topology.datacenter = &dc storeConfig := config.ServiceSpecific["store"].(map[string]interface{}) topology.store = topoStore{} diff --git a/topology/topology_test.go b/topology/topology_test.go index 0aef33f..12dda53 100644 --- a/topology/topology_test.go +++ b/topology/topology_test.go @@ -99,7 +99,7 @@ func (s *MySuite) TestHostMarshaling(c *check.C) { json, _ := m.Marshal(host) marshaledJSONStr := string(json) myLog(c, "Marshaled ", host, "to", marshaledJSONStr) - expectedJSONStr := "{\"id\":1,\"name\":\"host1\",\"ip\":\"10.1.1.1\",\"romana_ip\":\"192.168.0.1/16\",\"agent_port\":9999,\"links\":null}" + expectedJSONStr := "{\"id\":1,\"name\":\"host1\",\"ip\":\"10.1.1.1\",\"romana_ip\":\"192.168.0.1/16\",\"agent_port\":9999}" c.Assert(marshaledJSONStr, check.Equals, expectedJSONStr) host2 := common.Host{} err := m.Unmarshal([]byte(expectedJSONStr), &host2)