Skip to content

Commit

Permalink
refactor(generator): use annotations for the model
Browse files Browse the repository at this point in the history
  • Loading branch information
coryan committed Jan 31, 2025
1 parent 3f5f0a1 commit 853c632
Show file tree
Hide file tree
Showing 24 changed files with 195 additions and 182 deletions.
11 changes: 11 additions & 0 deletions generator/internal/api/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ type API struct {
Messages []*Message
// Enums
Enums []*Enum
// Language specific annotations
Codec any

// State contains helpful information that can be used when generating
// clients.
Expand Down Expand Up @@ -92,6 +94,9 @@ type Service struct {
DefaultHost string
// The Protobuf package this service belongs to.
Package string
// The model this service belongs to, mustache templates use this field to
// navigate the data structure.
Model *API
// Language specific annotations
Codec any
}
Expand Down Expand Up @@ -121,6 +126,12 @@ type Method struct {
ServerSideStreaming bool
// For methods returning long-running operations
OperationInfo *OperationInfo
// The model this method belongs to, mustache templates use this field to
// navigate the data structure.
Model *API
// The service this method belongs to, mustache templates use this field to
// navigate the data structure.
Service *Service
// Language specific annotations
Codec any
}
Expand Down
7 changes: 7 additions & 0 deletions generator/internal/api/xref.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,12 @@ func CrossReference(model *API) error {
m.OperationInfo.Method = m
}
}
for _, s := range model.State.ServiceByID {
s.Model = model
for _, m := range s.Methods {
m.Model = model
m.Service = s
}
}
return nil
}
23 changes: 23 additions & 0 deletions generator/internal/api/xref_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,26 @@ func TestCrossReferenceMethod(t *testing.T) {
t.Errorf("mismatched output type, got=%v, want=%v", method.OutputType, response)
}
}

func TestCrossReferenceService(t *testing.T) {
service := &Service{
Name: "Service",
ID: ".test.Service",
}
mixin := &Service{
Name: "Mixin",
ID: ".external.Mixin",
}

model := NewTestAPI([]*Message{}, []*Enum{}, []*Service{service})
model.State.ServiceByID[mixin.ID] = mixin
if err := CrossReference(model); err != nil {
t.Fatal(err)
}
if service.Model != model {
t.Errorf("mismatched model, got=%v, want=%v", service.Model, model)
}
if mixin.Model != model {
t.Errorf("mismatched model, got=%v, want=%v", mixin.Model, model)
}
}
4 changes: 2 additions & 2 deletions generator/internal/golang/golang.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ type goImport struct {
}

func Generate(model *api.API, outdir string, options map[string]string) error {
data, err := newTemplateData(model, options)
_, err := annotateModel(model, options)
if err != nil {
return err
}
provider := templatesProvider()
return language.GenerateFromRoot(outdir, data, provider, generatedFiles())
return language.GenerateFromRoot(outdir, model, provider, generatedFiles())
}

func generatedFiles() []language.GeneratedFile {
Expand Down
30 changes: 6 additions & 24 deletions generator/internal/golang/gotemplate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,14 @@ import (
"github.com/iancoleman/strcase"
)

type templateData struct {
Name string
Title string
Description string
type modelAnnotations struct {
PackageName string
SourcePackageName string
HasServices bool
CopyrightYear string
BoilerPlate []string
Imports []string
DefaultHost string
Services []*api.Service
Messages []*api.Message
Enums []*api.Enum
GoPackage string
}

Expand Down Expand Up @@ -105,11 +99,11 @@ type enumValueAnnotation struct {
EnumType string
}

// newTemplateData creates a struct used as input for Mustache templates.
// annotateModel creates a struct used as input for Mustache templates.
// Fields and methods defined in this struct directly correspond to Mustache
// tags. For example, the Mustache tag {{#Services}} uses the
// [Template.Services] field.
func newTemplateData(model *api.API, options map[string]string) (*templateData, error) {
func annotateModel(model *api.API, options map[string]string) (*modelAnnotations, error) {
var (
sourceSpecificationPackageName string
packageNameOverride string
Expand Down Expand Up @@ -153,10 +147,7 @@ func newTemplateData(model *api.API, options map[string]string) (*templateData,
for _, s := range model.Services {
annotateService(s, model.State)
}
data := &templateData{
Name: model.Name,
Title: model.Title,
Description: model.Description,
ann := &modelAnnotations{
PackageName: modelPackageName(model, packageNameOverride),
SourcePackageName: sourceSpecificationPackageName,
HasServices: len(model.Services) > 0,
Expand All @@ -171,20 +162,11 @@ func newTemplateData(model *api.API, options map[string]string) (*templateData,
}
return ""
}(),
Services: model.Services,
Messages: model.Messages,
Enums: model.Enums,
GoPackage: packageName,
}

for _, s := range data.Services {
for _, method := range s.Methods {
if m, ok := model.State.MessageByID[method.InputTypeID]; ok {
method.InputType = m
}
}
}
return data, nil
model.Codec = ann
return ann, nil
}

func annotateService(s *api.Service, state *api.APIState) {
Expand Down
2 changes: 1 addition & 1 deletion generator/internal/golang/gotemplate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func Test_GoEnumAnnotations(t *testing.T) {

model := api.NewTestAPI(
[]*api.Message{}, []*api.Enum{enum}, []*api.Service{})
_, err := newTemplateData(model, map[string]string{})
_, err := annotateModel(model, map[string]string{})
if err != nil {
t.Fatal(err)
}
Expand Down
18 changes: 9 additions & 9 deletions generator/internal/golang/templates/client.go.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
}}
// Copyright {{CopyrightYear}} Google LLC
{{#BoilerPlate}}
// Copyright {{Codec.CopyrightYear}} Google LLC
{{#Codec.BoilerPlate}}
//{{{.}}}
{{/BoilerPlate}}
{{/Codec.BoilerPlate}}

package {{GoPackage}}
package {{Codec.GoPackage}}

import (
"bytes"
Expand All @@ -31,13 +31,13 @@ import (
"time"

"cloud.google.com/go/auth"
{{#Imports}}
{{#Codec.Imports}}
{{{.}}}
{{/Imports}}
{{/Codec.Imports}}
)

{{#HasServices}}
const defaultHost = "https://{{DefaultHost}}"
{{#Codec.HasServices}}
const defaultHost = "https://{{Codec.DefaultHost}}"

type Options struct {
Credentials *auth.Credentials
Expand Down Expand Up @@ -145,7 +145,7 @@ func doRequest(client *http.Client, req *http.Request) ([]byte, error){
}
return b, nil
}
{{/HasServices}}
{{/Codec.HasServices}}
{{#Messages}}
{{> message}}
{{/Messages}}
2 changes: 1 addition & 1 deletion generator/internal/golang/templates/go.mod.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
}}
module {{PackageName}}
module {{Codec.PackageName}}

go 1.23.2
11 changes: 5 additions & 6 deletions generator/internal/rust/rust.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,13 @@ func Generate(model *api.API, outdir string, options map[string]string) error {
if err != nil {
return err
}
data, err := newTemplateData(model, codec, outdir)
if err != nil {
if _, err := annotateModel(model, codec, outdir); err != nil {
return err
}
provider := templatesProvider()
hasServices := len(model.State.ServiceByID) > 0
generatedFiles := generatedFiles(codec.generateModule, hasServices)
return language.GenerateFromRoot(outdir, data, provider, generatedFiles)
return language.GenerateFromRoot(outdir, model, provider, generatedFiles)
}

func newCodec(options map[string]string) (*codec, error) {
Expand Down Expand Up @@ -1552,7 +1551,7 @@ func hasStreamingRPC(model *api.API) bool {
return false
}

func addStreamingFeature(data *templateData, api *api.API, extraPackages []*packagez) {
func addStreamingFeature(ann *modelAnnotations, api *api.API, extraPackages []*packagez) {
hasStreamingRPC := hasStreamingRPC(api)
if !hasStreamingRPC {
return
Expand All @@ -1575,8 +1574,8 @@ func addStreamingFeature(data *templateData, api *api.API, extraPackages []*pack
}
sort.Strings(deps)
features := fmt.Sprintf("unstable-stream = [%s]", strings.Join(deps, ", "))
data.HasFeatures = true
data.Features = append(data.Features, features)
ann.HasFeatures = true
ann.Features = append(ann.Features, features)
}

func generateMethod(m *api.Method) bool {
Expand Down
4 changes: 2 additions & 2 deletions generator/internal/rust/rust_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ func TestRust_NoStreamingFeature(t *testing.T) {
{Name: "CreateResource", IsPageableResponse: false},
}, []*api.Enum{}, []*api.Service{})
loadWellKnownTypes(model.State)
data := &templateData{}
data := &modelAnnotations{}
addStreamingFeature(data, model, c.extraPackages)
if data.HasFeatures {
t.Errorf("mismatch in data.HasFeatures, expected `HasFeatures: false`, got=%v", data)
Expand Down Expand Up @@ -471,7 +471,7 @@ func checkRustContext(t *testing.T, codec *codec, wantFeatures string) {
{Name: "ListResources", IsPageableResponse: true},
}, []*api.Enum{}, []*api.Service{})
loadWellKnownTypes(model.State)
data := &templateData{}
data := &modelAnnotations{}
addStreamingFeature(data, model, codec.extraPackages)
want := []string{wantFeatures}
if !data.HasFeatures {
Expand Down
Loading

0 comments on commit 853c632

Please sign in to comment.