Skip to content

Commit

Permalink
fix: ensure we are fully backwards compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinvalk committed Feb 21, 2024
1 parent 4a7f0df commit 3e51313
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 108 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
bin
*.wasm

# Devenv
.envrc
Expand Down
2 changes: 1 addition & 1 deletion internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package python
type Config struct {
EmitAsync bool `json:"emit_async"` // Emits async code instead of sync
EmitExactTableNames bool `json:"emit_exact_table_names"`
EmitGenerators bool `json:"emit_generators"` // Will we use generators or lists, defaults to true
EmitGenerators bool `json:"emit_generators"` // Will we use generators or lists, defaults to false
EmitModule bool `json:"emit_module"` // If true emits functions in module, else wraps in a class.
EmitPydanticModels bool `json:"emit_pydantic_models"`
EmitSyncQuerier bool `json:"emit_sync_querier"` // DEPRECATED ALIAS FOR: emit_type = 'class', emit_generators = True
Expand Down
3 changes: 2 additions & 1 deletion internal/endtoend/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ func TestGenerate(t *testing.T) {
cmd := exec.Command(sqlc, "diff")
cmd.Dir = dir
got, err := cmd.CombinedOutput()
// TODO: We are diffing patches! Does this make sense and what should we provide to the end user?
if diff := cmp.Diff(string(want), string(got)); diff != "" {
t.Errorf("sqlc diff mismatch (-want +got):\n%s", diff)
t.Errorf("sqlc diff mismatch (-want +got):\n%s", string(got))
}
if len(want) == 0 && err != nil {
t.Error(err)
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/exec_result/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
2 changes: 1 addition & 1 deletion internal/endtoend/testdata/exec_rows/sqlc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ plugins:
- name: py
wasm:
url: file://../../../../bin/sqlc-gen-python.wasm
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
sql:
- schema: schema.sql
queries: query.sql
Expand Down
212 changes: 119 additions & 93 deletions internal/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
Arg: "self",
},
{
Arg: "connection",
Arg: "conn",
Annotation: connectionAnnotation,
},
},
Expand All @@ -855,9 +855,9 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
Node: &pyast.Node_Assign{
Assign: &pyast.Assign{
Targets: []*pyast.Node{
poet.Attribute(poet.Name("self"), "_connection"),
poet.Attribute(poet.Name("self"), "_conn"),
},
Value: poet.Name("connection"),
Value: poet.Name("conn"),
},
},
},
Expand All @@ -869,80 +869,12 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
}
}

func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
mod := moduleNode(ctx.SqlcVersion, source)
std, pkg := i.queryImportSpecs(source)
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
mod.Body = append(mod.Body, &pyast.Node{
Node: &pyast.Node_ImportGroup{
ImportGroup: &pyast.ImportGroup{
Imports: []*pyast.Node{
{
Node: &pyast.Node_ImportFrom{
ImportFrom: &pyast.ImportFrom{
Module: ctx.C.Package,
Names: []*pyast.Node{
poet.Alias("models"),
},
},
},
},
},
},
},
})

for _, q := range ctx.Queries {
if !ctx.OutputQuery(q.SourceName) {
continue
}
queryText := fmt.Sprintf("-- name: %s \\\\%s\n%s\n", q.MethodName, q.Cmd, q.SQL)
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))

// Generate params structures
for _, arg := range q.Args {
if arg.EmitStruct() {
var def *pyast.ClassDef
if ctx.C.EmitPydanticModels {
def = pydanticNode(arg.Struct.Name)
} else {
def = dataclassNode(arg.Struct.Name)
}

// We need a copy as we want to make sure that nullable params are at the end of the dataclass
fields := make([]Field, len(arg.Struct.Fields))
copy(fields, arg.Struct.Fields)

// Place all nullable fields at the end and try to keep the original order as much as possible
sort.SliceStable(fields, func(i int, j int) bool {
return (fields[j].Type.IsNull && fields[i].Type.IsNull != fields[j].Type.IsNull) || i < j
})

for _, f := range fields {
def.Body = append(def.Body, fieldNode(f, true))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}
if q.Ret.EmitStruct() {
var def *pyast.ClassDef
if ctx.C.EmitPydanticModels {
def = pydanticNode(q.Ret.Struct.Name)
} else {
def = dataclassNode(q.Ret.Struct.Name)
}
for _, f := range q.Ret.Struct.Fields {
def.Body = append(def.Body, fieldNode(f, false))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}

func buildQuerierClass(ctx *pyTmplCtx, isAsync bool) []*pyast.Node {
functions := make([]*pyast.Node, 0, 10)

// Define some reused types based on async or sync code
var connectionAnnotation *pyast.Node
if ctx.C.EmitAsync {
if isAsync {
connectionAnnotation = typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection")
} else {
connectionAnnotation = typeRefNode("sqlalchemy", "engine", "Connection")
Expand All @@ -951,9 +883,9 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
// We need to figure out how to access the SQLAlchemy connectionVar object
var connectionVar *pyast.Node
if ctx.C.EmitModule {
connectionVar = poet.Name("connection")
connectionVar = poet.Name("conn")
} else {
connectionVar = poet.Attribute(poet.Name("self"), "_connection")
connectionVar = poet.Attribute(poet.Name("self"), "_conn")
}

// We loop through all queries and build our query functions
Expand All @@ -968,7 +900,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {

if ctx.C.EmitModule {
f.Args.Args = append(f.Args.Args, &pyast.Arg{
Arg: "connection",
Arg: "conn",
Annotation: connectionAnnotation,
})
} else {
Expand All @@ -980,7 +912,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
q.AddArgs(f.Args)

exec := poet.Expr(connMethodNode(poet.Attribute(connectionVar, "execute"), q.ConstantName, q.ArgDictNode()))
if ctx.C.EmitAsync {
if isAsync {
exec = poet.Await(exec)
}

Expand Down Expand Up @@ -1017,7 +949,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
f.Returns = subscriptNode("Optional", q.Ret.Annotation())
case ":many":
if ctx.C.EmitGenerators {
if ctx.C.EmitAsync {
if isAsync {
// If we are using generators and async, we are switching to stream implementation
exec = poet.Await(connMethodNode(poet.Attribute(connectionVar, "stream"), q.ConstantName, q.ArgDictNode()))

Expand Down Expand Up @@ -1094,8 +1026,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
panic("unknown cmd " + q.Cmd)
}

// If we are emitting async code, we have to swap our sync func for an async one and fix the connection annotation.
if ctx.C.EmitAsync {
// If we are emitting async code, we have to swap our sync func for an async one and fix the conn annotation.
if isAsync {
functions = append(functions, poet.Node(&pyast.AsyncFunctionDef{
Name: f.Name,
Args: f.Args,
Expand All @@ -1107,13 +1039,115 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
}
}

// Lets see how to add all functions
return functions
}

func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
mod := moduleNode(ctx.SqlcVersion, source)
std, pkg := i.queryImportSpecs(source)
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
mod.Body = append(mod.Body, &pyast.Node{
Node: &pyast.Node_ImportGroup{
ImportGroup: &pyast.ImportGroup{
Imports: []*pyast.Node{
{
Node: &pyast.Node_ImportFrom{
ImportFrom: &pyast.ImportFrom{
Module: ctx.C.Package,
Names: []*pyast.Node{
poet.Alias("models"),
},
},
},
},
},
},
},
})

for _, q := range ctx.Queries {
if !ctx.OutputQuery(q.SourceName) {
continue
}
queryText := fmt.Sprintf("-- name: %s \\\\%s\n%s\n", q.MethodName, q.Cmd, q.SQL)
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))

// Generate params structures
for _, arg := range q.Args {
if arg.EmitStruct() {
var def *pyast.ClassDef
if ctx.C.EmitPydanticModels {
def = pydanticNode(arg.Struct.Name)
} else {
def = dataclassNode(arg.Struct.Name)
}

// We need a copy as we want to make sure that nullable params are at the end of the dataclass
fields := make([]Field, len(arg.Struct.Fields))
copy(fields, arg.Struct.Fields)

// Place all nullable fields at the end and try to keep the original order as much as possible
sort.SliceStable(fields, func(i int, j int) bool {
return (fields[j].Type.IsNull && fields[i].Type.IsNull != fields[j].Type.IsNull) || i < j
})

for _, f := range fields {
def.Body = append(def.Body, fieldNode(f, true))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}
if q.Ret.EmitStruct() {
var def *pyast.ClassDef
if ctx.C.EmitPydanticModels {
def = pydanticNode(q.Ret.Struct.Name)
} else {
def = dataclassNode(q.Ret.Struct.Name)
}
for _, f := range q.Ret.Struct.Fields {
def.Body = append(def.Body, fieldNode(f, false))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}

// Lets see how to add all functions, we can either add them to the module directly or from within a class.
if ctx.C.EmitModule {
mod.Body = append(mod.Body, functions...)
mod.Body = append(mod.Body, buildQuerierClass(ctx, ctx.C.EmitAsync)...)
} else {
cls := querierClassDef("Querier", connectionAnnotation)
cls.Body = append(cls.Body, functions...)
mod.Body = append(mod.Body, poet.Node(cls))
asyncConnectionAnnotation := typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection")
syncConnectionAnnotation := typeRefNode("sqlalchemy", "engine", "Connection")

// NOTE: For backwards compatibility we support generating multiple classes, but this is definitely suboptimal.
// It is much better to use the `emit_async: bool` config to select what type to emit
if ctx.C.EmitAsyncQuerier || ctx.C.EmitSyncQuerier {

// When using these backwards compatible settings we force behavior!
ctx.C.EmitModule = false
ctx.C.EmitGenerators = true

if ctx.C.EmitSyncQuerier {
cls := querierClassDef("Querier", syncConnectionAnnotation)
cls.Body = append(cls.Body, buildQuerierClass(ctx, false)...)
mod.Body = append(mod.Body, poet.Node(cls))
}
if ctx.C.EmitAsyncQuerier {
cls := querierClassDef("AsyncQuerier", asyncConnectionAnnotation)
cls.Body = append(cls.Body, buildQuerierClass(ctx, true)...)
mod.Body = append(mod.Body, poet.Node(cls))
}
} else {
var connectionAnnotation *pyast.Node
if ctx.C.EmitAsync {
connectionAnnotation = asyncConnectionAnnotation
} else {
connectionAnnotation = syncConnectionAnnotation
}

cls := querierClassDef("Querier", connectionAnnotation)
cls.Body = append(cls.Body, buildQuerierClass(ctx, ctx.C.EmitAsync)...)
mod.Body = append(mod.Body, poet.Node(cls))
}
}

return poet.Node(mod)
Expand Down Expand Up @@ -1150,14 +1184,6 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
}
}

// TODO: Remove when when we drop support for deprecated EmitSyncQuerier and EmitAsyncQuerier options
if conf.EmitAsyncQuerier || conf.EmitSyncQuerier {
conf.EmitModule = false
conf.EmitGenerators = true
conf.EmitAsync = conf.EmitAsyncQuerier
// TODO/NOTE: We now have a breaking change because we emit only one flavor. What do we want to do?
}

enums := buildEnums(req)
models := buildModels(conf, req)
queries, err := buildQueries(conf, req, models)
Expand Down
Loading

0 comments on commit 3e51313

Please sign in to comment.