diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..b8556db
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,23 @@
+name: test
+
+on:
+  push:
+    branches:
+      - master
+      - main
+  pull_request:
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: erlef/setup-beam@v1
+        with:
+          otp-version: "26.0.2"
+          gleam-version: "1.4.0-rc1"
+          rebar3-version: "3"
+          # elixir-version: "1.15.4"
+      - run: gleam deps download
+      - run: gleam test
+      - run: gleam format --check src test
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..599be4e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.beam
+*.ez
+/build
+erl_crash.dump
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..442b532
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,37 @@
+# Contributing
+
+If you're reading this, thank you so much for trying to contribute to
+`squirrel`!
+I tried to do my best and comment the code as much as possible and make it easy
+to understand. Each module also starts with a small comment to explain what it
+does, so it should be easier to dive into the codebase.
+
+> 💡 If you feel like some pieces of code are not commented enough or are too
+> obscure, then that's a bug! Please do reach out, I'd love to hear your
+> feedback and make `squirrel` easier to contribute to!
+
+## Running the tests
+
+Most of the tests are snapshot tests that directly call the `postgres.main`
+function to let it type the queries. In order to do that `squirrel` will have to
+connect to a Postgres server that must be running during the tests.
+
+- In CI this is taken care of automatically
+- Locally you'll need a little bit of setup:
+  - There must be a user called `squirrel_test`
+  - It must be able to read and write to a database called `squirrel_test`
+  - It will use an empty password to connect to localhost on port 5432
+
+## Writing tests
+
+`squirrel` uses a lot of snapshot tests; to add new tests for the code
+generation bits you can have a look at and copy the existing ones.
+There are no hard requirements, but I have some suggestions for writing good snapshot
+tests:
+
+- Have at most one snapshot per test function
+- Try to keep the snapshots as small as possible.
+  Ideally one snapshot should assert a single property of the generated code so
+  that it is easier to focus on a specific aspect of the code when reviewing it
+- Use a long descriptive title for the snapshots: a title should describe what
+  one expects to see in the produced snapshot to guide the review process
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..5502c03
--- /dev/null
+++ b/README.md
@@ -0,0 +1,186 @@
+# 🐿️ squirrel - type safe SQL in Gleam
+
+[![Package Version](https://img.shields.io/hexpm/v/squirrel)](https://hex.pm/packages/squirrel)
+[![Hex Docs](https://img.shields.io/badge/hex-docs-ffaff3)](https://hexdocs.pm/squirrel/)
+
+## What's Squirrel?
+
+If you need to talk with a database in Gleam you'll have to write something like
+this:
+
+```gleam
+import gleam/pgo
+import decode
+
+pub type FindSquirrelRow {
+  FindSquirrelRow(name: String, owned_acorns: Int)
+}
+
+/// Find a squirrel and its owned acorns given its name.
+///
+pub fn find_squirrel(db: pgo.Connection, name: String) {
+  let squirrel_row_decoder =
+    decode.into({
+      use name <- decode.parameter
+      use owned_acorns <- decode.parameter
+      FindSquirrelRow(name: name, owned_acorns: owned_acorns)
+    })
+    |> decode.field(0, decode.string)
+    |> decode.field(1, decode.int)
+
+  "select name, owned_acorns
+   from squirrel
+   where name = $1"
+  |> pgo.execute(db, [pgo.text(name)], decode.from(squirrel_row_decoder, _))
+}
+```
+
+This is probably fine if you have a few small queries but it can become quite
+the burden when you have a lot of queries:
+
+- The SQL query you write is just a plain string: you do not get syntax
+  highlighting, auto formatting, suggestions... all the little niceties you
+  would otherwise get if you were writing a plain `*.sql` file.
+- This also means you lose the ability to run these queries on their own with
+  other external tools, inspect them and so on.
+- You have to manually keep the decoder in sync with the query's output.
+
+One might be tempted to hide all of this by reaching for something like an ORM.
+Squirrel proposes a different approach: instead of trying to hide the SQL it
+_embraces it and leaves you in control._
+You write the SQL queries in plain old `*.sql` files and Squirrel will take care
+of generating all the corresponding functions.
+
+A code snippet is worth a thousand words, so let's have a look at an example.
+Instead of the hand written example shown earlier you can just write the
+following query:
+
+```sql
+-- we're in file `src/squirrels/sql/find_squirrel.sql`
+-- Find a squirrel and its owned acorns given its name.
+select
+  name,
+  owned_acorns
+from
+  squirrel
+where
+  name = $1
+```
+
+And run `gleam run -m squirrel`. Just like magic you'll now have a type-safe
+function `find_squirrel` you can use just as you'd expect:
+
+```gleam
+import squirrels/sql
+
+pub fn main() {
+  let db = todo as "the pgo connection"
+  // And it just works as you'd expect:
+  let assert Ok(#(_rows_count, rows)) = sql.find_squirrel(db, "sandy")
+  let assert [FindSquirrelRow(name: "sandy", owned_acorns: 11_111)] = rows
+}
+```
+
+Behind the scenes Squirrel generates the decoders and functions you need; and
+it's pretty-printed, standard Gleam code (actually it's exactly like the hand
+written example I showed you earlier)!
+So now you get the best of both worlds:
+
+- You don't have to take care of keeping encoders and decoders in sync, Squirrel
+  does that for you.
+- And you're not compromising on type safety either: Squirrel is able to
+  understand the types of your query and produce a correct decoder.
+- You can stick to writing plain SQL in `*.sql` files. You'll have better
+  editor support, syntax highlighting and completions.
+- You can run each query on its own: need to `explain` a query?
+  No big deal, it's just a plain old `*.sql` file.
+
+## Usage
+
+First you'll need to add Squirrel to your project as a dev dependency:
+
+```sh
+gleam add squirrel --dev
+
+# Remember to add these packages if you haven't yet, they are needed by the
+# generated code to run and decode the read rows!
+gleam add gleam_pgo
+gleam add decode
+```
+
+Then you can ask it to generate code by running the `squirrel` module:
+
+```sh
+gleam run -m squirrel
+```
+
+And that's it! As long as you follow a couple of conventions Squirrel will just
+work:
+
+- Squirrel will look for all `*.sql` files in any `sql` directory under your
+  project's `src` directory.
+- Each `sql` directory will be turned into a single Gleam module containing a
+  function for each `*.sql` file inside it. The generated Gleam module is going
+  to be located in the same directory as the corresponding `sql` directory and
+  its name is `sql.gleam`.
+- Each `*.sql` file _must contain a single SQL query,_ and the name of the file
+  is going to be the name of the corresponding Gleam function to run that query.
+
+> Let's make an example. Imagine you have a Gleam project that looks like this
+>
+> ```txt
+> ├── src
+> │   ├── squirrels
+> │   │   └── sql
+> │   │       ├── find_squirrel.sql
+> │   │       └── list_squirrels.sql
+> │   └── squirrels.gleam
+> └── test
+>     └── squirrels_test.gleam
+> ```
+>
+> Running `gleam run -m squirrel` will create a `src/squirrels/sql.gleam` file
+> defining two functions `find_squirrel` and `list_squirrels` you can then
+> import and use in your code.
+
+### Talking to the database
+
+In order to understand the type of your queries, Squirrel needs to connect to
+the Postgres server where the database is defined. To connect, it will read the
+[Postgres env variables](https://www.postgresql.org/docs/current/libpq-envars.html)
+and use the following defaults if one is not set:
+
+- `PGHOST`: `"localhost"`
+- `PGPORT`: `5432`
+- `PGUSER`: `"root"`
+- `PGDATABASE`: `"database"`
+- `PGPASSWORD`: `""`
+
+## FAQ
+
+### What flavour of SQL does squirrel support?
+
+Squirrel only has support for Postgres.
+
+### Why isn't squirrel configurable in any way?
+
+By going the "convention over configuration" route, Squirrel enforces that all
+projects adopting it will always have the same structure.
+If you need to contribute to a project using Squirrel you'll immediately know
+which directories and modules to look for.
+
+This makes it easier to get started with a new project and cuts down on all the
+bike shedding: _"Where should I put my queries?",_
+_"How many queries should go in one file?",_ ...
+
+## References
+
+This package draws a lot of inspiration from the amazing
+[yesql](https://github.com/krisajenkins/yesql) and
+[sqlx](https://github.com/launchbadge/sqlx).
+
+## Contributing
+
+If you think there's any way to improve this package, or if you spot a bug don't
+be afraid to open PRs, issues or requests of any kind! Any contribution is
+welcome 💜
diff --git a/birdie_snapshots/array_decoding.accepted b/birdie_snapshots/array_decoding.accepted
new file mode 100644
index 0000000..43fd62d
--- /dev/null
+++ b/birdie_snapshots/array_decoding.accepted
@@ -0,0 +1,33 @@
+---
+version: 1.1.8
+title: array decoding
+file: ./test/squirrel_test.gleam
+test_name: array_decoding_test
+---
+/// A row you get from running the `query` query
+/// defined in `query.sql`.
+///
+/// > 🐿️ This type definition was generated automatically using v-test of the
+/// > [squirrel package](https://github.com/giacomocavalieri/squirrel).
+///
+pub type QueryRow {
+  QueryRow(res: List(Int))
+}
+
+/// Runs the `query` query
+/// defined in `query.sql`.
+///
+/// > 🐿️ This function was generated automatically using v-test of
+/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel).
+/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.list(decode.int)) + + "select array[1, 2, 3] as res" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/array_encoding.accepted b/birdie_snapshots/array_encoding.accepted new file mode 100644 index 0000000..fce70b9 --- /dev/null +++ b/birdie_snapshots/array_encoding.accepted @@ -0,0 +1,37 @@ +--- +version: 1.1.8 +title: array encoding +file: ./test/squirrel_test.gleam +test_name: array_encoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db, arg_1) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res where $1 = array[1, 2, 3]" + |> pgo.execute( + db, + [pgo.array(list.map(arg_1, fn(a) {pgo.int(a)}))], + decode.from(decoder, _), + ) +} diff --git a/birdie_snapshots/bool_decoding.accepted b/birdie_snapshots/bool_decoding.accepted new file mode 100644 index 0000000..c1ed746 --- /dev/null +++ b/birdie_snapshots/bool_decoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: bool decoding +file: ./test/squirrel_test.gleam +test_name: bool_decoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/bool_encoding.accepted b/birdie_snapshots/bool_encoding.accepted new file mode 100644 index 0000000..debef8e --- /dev/null +++ b/birdie_snapshots/bool_encoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: bool encoding +file: ./test/squirrel_test.gleam +test_name: bool_encoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+/// +pub fn query(db, arg_1) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res where $1 = true" + |> pgo.execute(db, [pgo.bool(arg_1)], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/float_decoding.accepted b/birdie_snapshots/float_decoding.accepted new file mode 100644 index 0000000..66c6e9f --- /dev/null +++ b/birdie_snapshots/float_decoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: float decoding +file: ./test/squirrel_test.gleam +test_name: float_decoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Float) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.float) + + "select 1.1 as res" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/float_encoding.accepted b/birdie_snapshots/float_encoding.accepted new file mode 100644 index 0000000..7df1d7f --- /dev/null +++ b/birdie_snapshots/float_encoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: float encoding +file: ./test/squirrel_test.gleam +test_name: float_encoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db, arg_1) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res where $1 = 1.1" + |> pgo.execute(db, [pgo.float(arg_1)], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/generated_type_fields_are_labelled_with_their_name_in_the_select_list.accepted b/birdie_snapshots/generated_type_fields_are_labelled_with_their_name_in_the_select_list.accepted new file mode 100644 index 0000000..bf567d7 --- /dev/null +++ b/birdie_snapshots/generated_type_fields_are_labelled_with_their_name_in_the_select_list.accepted @@ -0,0 +1,41 @@ +--- +version: 1.1.8 +title: generated type fields are labelled with their name in the select list +file: ./test/squirrel_test.gleam +test_name: generated_type_fields_are_labelled_with_their_name_in_the_select_list_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(acorns: Option(Int), squirrel_name: String) +} + +/// Runs the `query` query +/// defined in `query.sql`. 
+/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use acorns <- decode.parameter + use squirrel_name <- decode.parameter + QueryRow(acorns: acorns, squirrel_name: squirrel_name) + }) + |> decode.field(0, decode.optional(decode.int)) + |> decode.field(1, decode.string) + + " +select + acorns, + name as squirrel_name +from + squirrel +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/generated_type_has_the_same_name_as_the_function_but_in_pascal_case.accepted b/birdie_snapshots/generated_type_has_the_same_name_as_the_function_but_in_pascal_case.accepted new file mode 100644 index 0000000..c17fc2b --- /dev/null +++ b/birdie_snapshots/generated_type_has_the_same_name_as_the_function_but_in_pascal_case.accepted @@ -0,0 +1,35 @@ +--- +version: 1.1.8 +title: generated type has the same name as the function but in pascal case +file: ./test/squirrel_test.gleam +test_name: generated_type_has_the_same_name_as_the_function_but_in_pascal_case_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + " +select true as res +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/int_decoding.accepted b/birdie_snapshots/int_decoding.accepted new file mode 100644 index 0000000..f239d50 --- /dev/null +++ b/birdie_snapshots/int_decoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: int decoding +file: ./test/squirrel_test.gleam +test_name: int_decoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Int) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.int) + + "select 11 as res" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/int_encoding.accepted b/birdie_snapshots/int_encoding.accepted new file mode 100644 index 0000000..c6c1ed4 --- /dev/null +++ b/birdie_snapshots/int_encoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: int encoding +file: ./test/squirrel_test.gleam +test_name: int_encoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db, arg_1) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res where $1 = 11" + |> pgo.execute(db, [pgo.int(arg_1)], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/optional_decoding.accepted b/birdie_snapshots/optional_decoding.accepted new file mode 100644 index 0000000..bf70d51 --- /dev/null +++ b/birdie_snapshots/optional_decoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: optional decoding +file: ./test/squirrel_test.gleam +test_name: optional_decoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(acorns: Option(Int)) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use acorns <- decode.parameter + QueryRow(acorns: acorns) + }) + |> decode.field(0, decode.optional(decode.int)) + + "select acorns from squirrel" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/query_with_comment.accepted b/birdie_snapshots/query_with_comment.accepted new file mode 100644 index 0000000..5f241da --- /dev/null +++ b/birdie_snapshots/query_with_comment.accepted @@ -0,0 +1,35 @@ +--- +version: 1.1.8 +title: query with comment +file: ./test/squirrel_test.gleam +test_name: query_with_comment_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// This is a comment +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + " +-- This is a comment +select true as res +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/query_with_multiline_comment.accepted b/birdie_snapshots/query_with_multiline_comment.accepted new file mode 100644 index 0000000..edd1e49 --- /dev/null +++ b/birdie_snapshots/query_with_multiline_comment.accepted @@ -0,0 +1,37 @@ +--- +version: 1.1.8 +title: query with multiline comment +file: ./test/squirrel_test.gleam +test_name: query_with_multiline_comment_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// This is a comment +/// that goes over multiple lines! 
+/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + " +-- This is a comment +-- that goes over multiple lines! +select true as res +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/string_decoding.accepted b/birdie_snapshots/string_decoding.accepted new file mode 100644 index 0000000..1dc6810 --- /dev/null +++ b/birdie_snapshots/string_decoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: string decoding +file: ./test/squirrel_test.gleam +test_name: string_decoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: String) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub fn query(db) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.string) + + "select 'wibble' as res" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/birdie_snapshots/string_encoding.accepted b/birdie_snapshots/string_encoding.accepted new file mode 100644 index 0000000..fd45a41 --- /dev/null +++ b/birdie_snapshots/string_encoding.accepted @@ -0,0 +1,33 @@ +--- +version: 1.1.8 +title: string encoding +file: ./test/squirrel_test.gleam +test_name: string_encoding_test +--- +/// A row you get from running the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v-test of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type QueryRow { + QueryRow(res: Bool) +} + +/// Runs the `query` query +/// defined in `query.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v-test of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+/// +pub fn query(db, arg_1) { + let decoder = + decode.into({ + use res <- decode.parameter + QueryRow(res: res) + }) + |> decode.field(0, decode.bool) + + "select true as res where $1 = 'wibble'" + |> pgo.execute(db, [pgo.text(arg_1)], decode.from(decoder, _)) +} diff --git a/gleam.toml b/gleam.toml new file mode 100644 index 0000000..e17e7a9 --- /dev/null +++ b/gleam.toml @@ -0,0 +1,25 @@ +name = "squirrel" +version = "0.1.0" +description = "๐Ÿฟ๏ธ Type safe SQL in Gleam" +licences = ["Apache-2.0"] +repository = { type = "github", user = "giacomocavalieri", repo = "squirrel" } + +[dependencies] +gleam_stdlib = ">= 0.34.0 and < 2.0.0" +simplifile = ">= 2.0.1 and < 3.0.0" +eval = ">= 1.0.0 and < 2.0.0" +gleam_json = ">= 1.0.0 and < 2.0.0" +mug = ">= 1.1.0 and < 2.0.0" +glam = ">= 2.0.1 and < 3.0.0" +justin = ">= 1.0.1 and < 2.0.0" +filepath = ">= 1.0.0 and < 2.0.0" +gleam_community_ansi = ">= 1.4.0 and < 2.0.0" +term_size = ">= 1.0.1 and < 2.0.0" +argv = ">= 1.0.2 and < 2.0.0" +envoy = ">= 1.0.1 and < 2.0.0" + +[dev-dependencies] +gleeunit = ">= 1.0.0 and < 2.0.0" +birdie = ">= 1.1.8 and < 2.0.0" +temporary = ">= 1.0.0 and < 2.0.0" +gleam_pgo = ">= 0.13.0 and < 1.0.0" diff --git a/manifest.toml b/manifest.toml new file mode 100644 index 0000000..71b9f12 --- /dev/null +++ b/manifest.toml @@ -0,0 +1,53 @@ +# This file was generated by Gleam +# You typically do not need to edit this file + +packages = [ + { name = "argv", version = "1.0.2", build_tools = ["gleam"], requirements = [], otp_app = "argv", source = "hex", outer_checksum = "BA1FF0929525DEBA1CE67256E5ADF77A7CDDFE729E3E3F57A5BDCAA031DED09D" }, + { name = "backoff", version = "1.1.6", build_tools = ["rebar3"], requirements = [], otp_app = "backoff", source = "hex", outer_checksum = "CF0CFFF8995FB20562F822E5CC47D8CCF664C5ECDC26A684CBE85C225F9D7C39" }, + { name = "birdie", version = "1.1.8", build_tools = ["gleam"], requirements = ["argv", "filepath", "glance", "gleam_community_ansi", "gleam_erlang", "gleam_stdlib", "justin", "rank", "simplifile", "trie_again"], otp_app = "birdie", source = "hex", outer_checksum = "D225C0A3035FCD73A88402925A903AAD3567A1515C9EAE8364F11C17AD1805BB" }, + { name = "envoy", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "envoy", source = "hex", outer_checksum = "CFAACCCFC47654F7E8B75E614746ED924C65BD08B1DE21101548AC314A8B6A41" }, + { name = "eval", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "eval", source = "hex", outer_checksum = "264DAF4B49DF807F303CA4A4E4EBC012070429E40BE384C58FE094C4958F9BDA" }, + { name = "exception", version = "2.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "exception", source = "hex", outer_checksum = "F5580D584F16A20B7FCDCABF9E9BE9A2C1F6AC4F9176FA6DD0B63E3B20D450AA" }, + { name = "filepath", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "filepath", source = "hex", outer_checksum = "EFB6FF65C98B2A16378ABC3EE2B14124168C0CE5201553DE652E2644DCFDB594" }, + { name = "glam", version = "2.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glam", source = "hex", outer_checksum = "66EC3BCD632E51EED029678F8DF419659C1E57B1A93D874C5131FE220DFAD2B2" }, + { name = "glance", version = "0.11.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "glexer"], otp_app = "glance", source = "hex", outer_checksum = "8F3314D27773B7C3B9FB58D8C02C634290422CE531988C0394FA0DF8676B964D" }, + { name = "gleam_community_ansi", 
version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_community_colour", "gleam_stdlib"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "FE79E08BF97009729259B6357EC058315B6FBB916FAD1C2FF9355115FEB0D3A4" }, + { name = "gleam_community_colour", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_json", "gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "795964217EBEDB3DA656F5EB8F67D7AD22872EB95182042D3E7AFEF32D3FD2FE" }, + { name = "gleam_crypto", version = "1.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_crypto", source = "hex", outer_checksum = "ADD058DEDE8F0341F1ADE3AAC492A224F15700829D9A3A3F9ADF370F875C51B7" }, + { name = "gleam_erlang", version = "0.25.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "054D571A7092D2A9727B3E5D183B7507DAB0DA41556EC9133606F09C15497373" }, + { name = "gleam_json", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "9063D14D25406326C0255BDA0021541E797D8A7A12573D849462CAFED459F6EB" }, + { name = "gleam_pgo", version = "0.13.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "pgo"], otp_app = "gleam_pgo", source = "hex", outer_checksum = "6A1E7F3E717C077788254871E4EF4A8DFF58FEC07D7FA7C7702C2CCF66095AC8" }, + { name = "gleam_stdlib", version = "0.39.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "2D7DE885A6EA7F1D5015D1698920C9BAF7241102836CE0C3837A4F160128A9C4" }, + { name = "gleeunit", version = "1.2.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "F7A7228925D3EE7D0813C922E062BFD6D7E9310F0BEE585D3A42F3307E3CFD13" }, + { name = "glexer", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glexer", source = "hex", outer_checksum = "BD477AD657C2B637FEF75F2405FAEFFA533F277A74EF1A5E17B55B1178C228FB" }, + { name = "justin", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "justin", source = "hex", outer_checksum = "7FA0C6DB78640C6DC5FBFD59BF3456009F3F8B485BF6825E97E1EB44E9A1E2CD" }, + { name = "mug", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_erlang", "gleam_stdlib"], otp_app = "mug", source = "hex", outer_checksum = "85A61E67A7A8C25F4460D9CBEF1C09C68FC06ABBC6FF893B0A1F42AE01CBB546" }, + { name = "opentelemetry_api", version = "1.3.0", build_tools = ["rebar3", "mix"], requirements = ["opentelemetry_semantic_conventions"], otp_app = "opentelemetry_api", source = "hex", outer_checksum = "B9E5FF775FD064FA098DBA3C398490B77649A352B40B0B730A6B7DC0BDD68858" }, + { name = "opentelemetry_semantic_conventions", version = "0.2.0", build_tools = ["rebar3", "mix"], requirements = [], otp_app = "opentelemetry_semantic_conventions", source = "hex", outer_checksum = "D61FA1F5639EE8668D74B527E6806E0503EFC55A42DB7B5F39939D84C07D6895" }, + { name = "pg_types", version = "0.4.0", build_tools = ["rebar3"], requirements = [], otp_app = "pg_types", source = "hex", outer_checksum = "B02EFA785CAECECF9702C681C80A9CA12A39F9161A846CE17B01FB20AEEED7EB" }, + { name = "pgo", version = "0.14.0", build_tools = ["rebar3"], requirements = ["backoff", "opentelemetry_api", "pg_types"], otp_app = "pgo", source = "hex", outer_checksum = 
"71016C22599936E042DC0012EE4589D24C71427D266292F775EBF201D97DF9C9" }, + { name = "rank", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "rank", source = "hex", outer_checksum = "5660E361F0E49CBB714CC57CC4C89C63415D8986F05B2DA0C719D5642FAD91C9" }, + { name = "simplifile", version = "2.0.1", build_tools = ["gleam"], requirements = ["filepath", "gleam_stdlib"], otp_app = "simplifile", source = "hex", outer_checksum = "5FFEBD0CAB39BDD343C3E1CCA6438B2848847DC170BA2386DF9D7064F34DF000" }, + { name = "temporary", version = "1.0.0", build_tools = ["gleam"], requirements = ["envoy", "exception", "filepath", "gleam_crypto", "gleam_stdlib", "simplifile"], otp_app = "temporary", source = "hex", outer_checksum = "51C0FEF4D72CE7CA507BD188B21C1F00695B3D5B09D7DFE38240BFD3A8E1E9B3" }, + { name = "term_size", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "term_size", source = "hex", outer_checksum = "D00BD2BC8FB3EBB7E6AE076F3F1FF2AC9D5ED1805F004D0896C784D06C6645F1" }, + { name = "thoas", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "E38697EDFFD6E91BD12CEA41B155115282630075C2A727E7A6B2947F5408B86A" }, + { name = "trie_again", version = "1.1.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "trie_again", source = "hex", outer_checksum = "5B19176F52B1BD98831B57FDC97BD1F88C8A403D6D8C63471407E78598E27184" }, +] + +[requirements] +argv = { version = ">= 1.0.2 and < 2.0.0" } +birdie = { version = ">= 1.1.8 and < 2.0.0" } +envoy = { version = ">= 1.0.1 and < 2.0.0" } +eval = { version = ">= 1.0.0 and < 2.0.0" } +filepath = { version = ">= 1.0.0 and < 2.0.0" } +glam = { version = ">= 2.0.1 and < 3.0.0" } +gleam_community_ansi = { version = ">= 1.4.0 and < 2.0.0" } +gleam_json = { version = ">= 1.0.0 and < 2.0.0" } +gleam_pgo = { version = ">= 0.13.0 and < 1.0.0" } +gleam_stdlib = { version = ">= 0.34.0 and < 2.0.0" } +gleeunit = { version = ">= 1.0.0 and < 2.0.0" } +justin = { version = ">= 1.0.1 and < 2.0.0" } +mug = { version = ">= 1.1.0 and < 2.0.0" } +simplifile = { version = ">= 2.0.1 and < 3.0.0" } +temporary = { version = ">= 1.0.0 and < 2.0.0" } +term_size = { version = ">= 1.0.1 and < 2.0.0" } diff --git a/prova.gleam b/prova.gleam new file mode 100644 index 0000000..6229a26 --- /dev/null +++ b/prova.gleam @@ -0,0 +1,36 @@ +import gleam/pgo +import decode + +/// A row you get from running the `prova` query +/// defined in `prova.sql`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using v1.0.0 of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type ProvaRow { + ProvaRow(user_id: String, username: String, password: String) +} + +/// Runs the `prova` query +/// defined in `prova.sql`. +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using v1.0.0 of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+/// +pub fn prova(db) { + let decoder = + decode.into({ + use user_id <- decode.parameter + use username <- decode.parameter + use password <- decode.parameter + ProvaRow(user_id: user_id, username: username, password: password) + }) + |> decode.field(0, decode.string) + |> decode.field(1, decode.string) + |> decode.field(2, decode.string) + + "-- name: prova +select * from kira_user; +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/prova.sql b/prova.sql new file mode 100644 index 0000000..c2ca628 --- /dev/null +++ b/prova.sql @@ -0,0 +1,2 @@ +-- name: prova +select * from kira_user; diff --git a/src/squirrel.gleam b/src/squirrel.gleam new file mode 100644 index 0000000..a88bbfa --- /dev/null +++ b/src/squirrel.gleam @@ -0,0 +1,250 @@ +import envoy +import filepath +import glam/doc.{type Document} +import gleam/bool +import gleam/dict.{type Dict} +import gleam/int +import gleam/io +import gleam/list +import gleam/result +import gleam/string +import gleam_community/ansi +import simplifile +import squirrel/internal/database/postgres +import squirrel/internal/error.{type Error, CannotWriteToFile} +import squirrel/internal/query.{type TypedQuery} +import term_size + +const squirrel_version = "v1.0.0" + +/// ๐Ÿฟ๏ธ Performs code generation for your Gleam project. +/// +/// `squirrel` is not configurable and will discover the queries to generate +/// code for by relying on a conventional project's structure: +/// - `squirrel` first looks for all directories called `sql` under the `src` +/// directory of your Gleam project, and reads all the `*.sql` files in there +/// (in glob terms `src/**/sql/*.sql`). +/// - Each `*.sql` file _must contain a single query_ as it is turned into a +/// Gleam function with the same name. +/// - All functions coming from the same `sql` directory will be grouped under +/// a Gleam file called `sql.gleam` at the same level: given a `src/$PATH/sql` +/// directory, you'll end up with a generated `src/$PATH/sql.gleam` file. +/// +/// > โš ๏ธ In order to generate type safe code, `squirrel` has to connect +/// > to your Postgres database. To know what host, user, etc. values to use +/// > when connecting, it will read your +/// > [Postgres env variables.](https://www.postgresql.org/docs/current/libpq-envars.html) +/// > +/// > If a variable is not set it will go with the following defaults: +/// > - `PGHOST`: `"localhost"` +/// > - `PGPORT`: `5432` +/// > - `PGUSER`: `"root"` +/// > - `PGDATABASE`: `"database"` +/// > - `PGPASSWORD`: `""` +/// +/// > โš ๏ธ The generated code relies on the +/// > [`gleam_pgo`](https://hexdocs.pm/gleam_pgo/) and +/// > [`decode`](https://hexdocs.pm/decode/) packages to work, so make sure to +/// > add those as dependencies to your project. +/// +pub fn main() { + walk("src") + |> run(read_connection_options()) + |> pretty_report + |> io.println +} + +fn read_connection_options() -> postgres.ConnectionOptions { + let host = envoy.get("PGHOST") |> result.unwrap("localhost") + let user = envoy.get("PGUSER") |> result.unwrap("root") + let database = envoy.get("PGDATABASE") |> result.unwrap("database") + let password = envoy.get("PGPASSWORD") |> result.unwrap("") + let port = + envoy.get("PGPORT") + |> result.then(int.parse) + |> result.unwrap(5432) + + postgres.ConnectionOptions( + host: host, + port: port, + user: user, + password: password, + database: database, + timeout: 1000, + ) +} + +/// Finds all `from/**/sql` directories and lists the full paths of the `*.sql` +/// files inside each one. 
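+/// For example, with the project layout shown in the README, the returned dict
+/// would map the `src/squirrels/sql` directory to the paths of
+/// `find_squirrel.sql` and `list_squirrels.sql`.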
+/// +fn walk(from: String) -> Dict(String, List(String)) { + case filepath.base_name(from) { + "sql" -> { + let assert Ok(files) = simplifile.read_directory(from) + let files = { + use file <- list.filter_map(files) + use extension <- result.try(filepath.extension(file)) + use <- bool.guard(when: extension != "sql", return: Error(Nil)) + let file_name = filepath.join(from, file) + case simplifile.is_file(file_name) { + Ok(True) -> Ok(file_name) + Ok(False) | Error(_) -> Error(Nil) + } + } + dict.from_list([#(from, files)]) + } + + _ -> { + let assert Ok(files) = simplifile.read_directory(from) + let directories = { + use file <- list.filter_map(files) + let file_name = filepath.join(from, file) + case simplifile.is_directory(file_name) { + Ok(True) -> Ok(file_name) + Ok(False) | Error(_) -> Error(Nil) + } + } + + list.map(directories, walk) + |> list.fold(from: dict.new(), with: dict.merge) + } + } +} + +/// Given a dict of directories and their `*.sql` files, performs code +/// generation for each one, bundling all `*.sql` files under the same directory +/// into a single Gleam module. +/// +fn run( + directories: Dict(String, List(String)), + connection: postgres.ConnectionOptions, +) -> Dict(String, #(Int, List(Error))) { + use directory, files <- dict.map_values(directories) + + let #(queries, errors) = + list.map(files, query.from_file) + |> result.partition + + let #(queries, errors) = case postgres.main(queries, connection) { + Error(error) -> #([], [error, ..errors]) + Ok(#(queries, type_errors)) -> #(queries, list.append(errors, type_errors)) + } + + let output_file = + filepath.directory_name(directory) + |> filepath.join("sql.gleam") + + case write_queries(queries, to: output_file) { + Ok(n) -> #(n, errors) + Error(error) -> #(list.length(queries), [error, ..errors]) + } +} + +fn write_queries( + queries: List(TypedQuery), + to file: String, +) -> Result(Int, Error) { + use <- bool.guard(when: queries == [], return: Ok(0)) + + let directory = filepath.directory_name(file) + let _ = simplifile.create_directory_all(directory) + + // We need the top level imports. + let imports = "import gleam/pgo\nimport decode\n" + let #(count, code) = { + use #(count, code), query <- list.fold(queries, #(0, imports)) + #(count + 1, code <> "\n" <> query.generate_code(squirrel_version, query)) + } + + let try_write = + simplifile.write(code, to: file) + |> result.map_error(CannotWriteToFile(file, _)) + + use _ <- result.try(try_write) + Ok(count) +} + +// --- PRETTY REPORT PRINTING -------------------------------------------------- + +fn pretty_report(dirs: Dict(String, #(Int, List(Error)))) -> String { + let width = term_size.columns() |> result.unwrap(80) + let #(ok, errors) = { + use #(all_ok, all_errors), _, result <- dict.fold(dirs, #(0, [])) + let #(ok, errors) = result + #(all_ok + ok, errors |> list.append(all_errors)) + } + let errors_doc = + list.map(errors, error.to_doc) + |> doc.join(with: doc.lines(2)) + + case ok, errors { + 0, [_, ..] 
-> doc.to_string(errors_doc, width) + 0, [] -> + text_with_header( + "๐Ÿฟ๏ธ ", + "I couldn't find any `*.sql` file to generate queries from", + ) + |> doc.to_string(width) + |> ansi.yellow + + n, [] -> + text_with_header( + "๐Ÿฟ๏ธ ", + "Generated " + <> int.to_string(n) + <> " " + <> pluralise(n, "query", "queries"), + ) + |> doc.to_string(width) + |> ansi.green + |> string.append("\n") + |> string.append( + text_with_header( + "๐Ÿฅœ ", + "Don't forget to run `gleam add decode gleam_pgo` if you haven't yet!", + ) + |> doc.to_string(width) + |> ansi.cyan, + ) + + n, [_, ..] -> + [ + errors_doc, + doc.lines(2), + text_with_header( + "๐Ÿฅœ ", + "I could still generate " + <> int.to_string(n) + <> " " + <> pluralise(n, "query", "queries"), + ), + ] + |> doc.concat + |> doc.to_string(width) + } +} + +fn text_with_header(header: String, text: String) { + [ + doc.from_string(header), + flexible_string(text) + |> doc.nest(by: string.length(header)), + ] + |> doc.concat + |> doc.group +} + +fn pluralise(count: Int, singular: String, plural: String) -> String { + case count { + 1 -> singular + _ -> plural + } +} + +fn flexible_string(string: String) -> Document { + string.split(string, on: "\n") + |> list.flat_map(string.split(_, on: " ")) + |> list.map(doc.from_string) + |> doc.join(with: doc.flex_space) + |> doc.group +} diff --git a/src/squirrel/internal/database/postgres.gleam b/src/squirrel/internal/database/postgres.gleam new file mode 100644 index 0000000..54d9e98 --- /dev/null +++ b/src/squirrel/internal/database/postgres.gleam @@ -0,0 +1,899 @@ +//// In this module lies the core of `squirrel`. +//// It exposes a single public function called `main` that is used to turn a +//// list of untyped queries into typed ones. +//// +//// To do so, `squirrel` will try to connect to a database, have it parse the +//// queries and reply with the types it could infer. +//// Then it's as simple as (not that simple in practice ๐Ÿ˜†) converting the +//// Postgres types into Gleam types. +//// +//// > ๐Ÿ’ก I tried to do my best to comment everything as much as possible and +//// > make things easy to read. +//// > If you feel something is poorly commented or hard to understand, then +//// > that is a bug! Please do reach out, I'd love to hear your feedback. +//// + +import eval +import gleam/bit_array +import gleam/bool +import gleam/dict.{type Dict} +import gleam/dynamic.{type DecodeErrors, type Dynamic} as d +import gleam/int +import gleam/json +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/result +import gleam/set.{type Set} +import gleam/string +import squirrel/internal/database/postgres_protocol as pg +import squirrel/internal/error.{ + type Error, type Pointer, type ValueIdentifierError, ByteIndex, + CannotParseQuery, PgCannotAuthenticate, PgCannotDecodeReceivedMessage, + PgCannotDescribeQuery, PgCannotReceiveMessage, PgCannotSendMessage, Pointer, + QueryHasInvalidColumn, QueryHasUnsupportedType, +} +import squirrel/internal/eval_extra +import squirrel/internal/gleam +import squirrel/internal/query.{ + type TypedQuery, type UntypedQuery, TypedQuery, UntypedQuery, +} + +const find_postgres_type_query = " +select + -- The name of the type or, if the type is an array, the name of its + -- elements' type. + case + when elem.typname is null then type.typname + else elem.typname + end as type, + -- Tells us how to interpret the firs column: if this is true then the first + -- column is the type of the elements of the array type. 
+ -- Otherwise it means we've found a base type. + case + when elem.typname is null then false + else true + end as is_array +from + pg_type as type + left join pg_type as elem on type.typelem = elem.oid +where + type.oid = $1 +" + +const find_column_nullability_query = " +select + -- Whether the column has a not-null constraint. + attnotnull +from + pg_attribute +where + -- The oid of the table the column comes from. + attrelid = $1 + -- The index of the column we're looking for. + and attnum = $2 +" + +// --- TYPES ------------------------------------------------------------------- + +/// A Postgres type. +/// +/// > โš ๏ธPostgres has loads of types and this might not cover the more exotic +/// > ones but for now it feels more than enough. +/// +type PgType { + /// A base type, like `integer`, `text`, `char`, ... + /// + PBase(name: String) + + /// An array type like `int[]`, `text[]`, ... + /// + PArray(inner: PgType) + + /// A type that could also be `NULL`, this is particularly common for columns + /// that do not have a `not null` constraint; or for those coming from partial + /// joins. + /// + POption(inner: PgType) +} + +/// The context in which all database-related actions will take place. +/// +type Context { + Context( + /// A connection to the database. Squirrel does nothing fancy and just uses + /// a single connection to run all the queries. + /// + db: pg.Connection, + /// A cache from `oid` to corresponding Gleam type. + /// We use this to avoid having to reach to the database every time we need + /// to infer a type. + /// + /// > ๐Ÿ’ก An oid is an integer identifier that is used by Postgres to + /// > uniquely identify types (and a lot of other various objects, see + /// > [this documentation page](https://www.postgresql.org/docs/current/datatype-oid.html)). + /// + gleam_types: Dict(Int, gleam.Type), + /// A cache from table `oid` and column index to its nullability. + /// We use this to avoid having to reach to the database every time we need + /// to type a column. + /// + column_nullability: Dict(#(Int, Int), Nullability), + ) +} + +/// Information about a column's nullability. +/// If a column has a `not null` constraint then it will be `NotNullable`, in +/// all other cases it will be `Nullable`. +/// +/// > โš ๏ธ A column with a `not null` constraint might still be considered +/// > nullable if it comes from a left/right join! +/// +type Nullability { + Nullable + NotNullable +} + +/// A query plan produced by Postgres when we ask it to `explain` a query. +/// +type Plan { + Plan( + join_type: Option(JoinType), + parent_relation: Option(ParentRelation), + output: Option(List(String)), + plans: Option(List(Plan)), + ) +} + +type JoinType { + Full + Left + Right + Other +} + +type ParentRelation { + Inner + NotInner +} + +/// This is the type of a database-related action. +/// In order to be carried out it needs to have access to the database `Context` +/// and could fail with an `Error`. +/// +type Db(a) = + eval.Eval(a, Error, Context) + +/// The options used to establish a connection to the Postgres database. +/// +pub type ConnectionOptions { + ConnectionOptions( + host: String, + port: Int, + user: String, + password: String, + database: String, + timeout: Int, + ) +} + +// --- POSTGRES TO GLEAM TYPES CONVERSIONS ------------------------------------- + +/// This function turns a Postgres type into a Gleam one, returning an error +/// with the type name if it is not currently supported. 
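+/// For example `int4` becomes `Int`, `text[]` becomes `List(String)`, and a
+/// nullable `bool` becomes `Option(Bool)`.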
+/// +fn pg_to_gleam_type(type_: PgType) -> Result(gleam.Type, String) { + case type_ { + PArray(inner: inner) -> + pg_to_gleam_type(inner) + |> result.map(gleam.List) + |> result.map_error(fn(inner) { inner <> "[]" }) + + POption(inner: inner) -> + pg_to_gleam_type(inner) + |> result.map(gleam.Option) + |> result.map_error(fn(inner) { inner <> "?" }) + + PBase(name: name) -> + case name { + "bool" -> Ok(gleam.Bool) + "text" | "char" -> Ok(gleam.String) + "float4" | "float8" | "numeric" -> Ok(gleam.Float) + "int2" | "int4" | "int8" -> Ok(gleam.Int) + _ -> Error(name) + } + } +} + +// --- CLI ENTRY POINT --------------------------------------------------------- + +/// Connects to a Postgres database (using the given options) and types a list +/// of queries. +/// +/// This might fail with an `Error` if a database connection cannot be +/// established, making it impossible to type any of the queries. +/// Otherwise, it will try typing all the queries, retuning a list of typed ones +/// and a list of possible errors for the ones it couldn't type. +/// +pub fn main( + queries: List(UntypedQuery), + connection: ConnectionOptions, +) -> Result(#(List(TypedQuery), List(Error)), Error) { + let context = + Context( + db: pg.connect(connection.host, connection.port, connection.timeout), + gleam_types: dict.new(), + column_nullability: dict.new(), + ) + + // Once the server has confirmed that it is ready to accept query requests we + // can start gathering information about all the different queries. + // After each one we need to make sure the server is ready to go on with the + // next one. + // + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + // + let #(context, connection) = eval.step(authenticate(connection), context) + case connection { + Error(error) -> Error(error) + // After successfully authenticating we can try and type all the queries. + Ok(_) -> + list.map(queries, infer_types) + |> eval_extra.run_all(context) + |> result.partition + |> Ok + } +} + +fn authenticate(connection: ConnectionOptions) -> Db(Nil) { + let params = [#("user", connection.user), #("database", connection.database)] + use _ <- eval.try(send(pg.FeStartupMessage(params))) + + use msg <- eval.try(receive()) + use _ <- eval.try(case msg { + pg.BeAuthenticationOk -> eval.return(Nil) + _ -> unexpected_message(PgCannotAuthenticate, "AuthenticationOk", msg) + }) + + use _ <- eval.try(wait_until_ready()) + eval.return(Nil) +} + +/// Returns type information about a query. +/// +fn infer_types(query: UntypedQuery) -> Db(TypedQuery) { + // Postgres doesn't give us 100% accurate data regardin a query's type. + // We'll need to perform a couple of database interrogations and do some + // guessing. + // + // The big picture idea is the following: + // - We ask the server to prepare the query. + // - Postgres will reply with type information about the returned rows and the + // query's parameters. + let action = parameters_and_returns(query) + use #(parameters, returns) <- eval.try(action) + // - The parameters' types are just OIDs so we need to interrogate the + // database to learn the actual corresponding Gleam type. + use parameters <- eval.try(resolve_parameters(query, parameters)) + // - The returns' types are just OIDs so we have to do the same for those. + // - Here comes the tricky part: we can't know if a column is nullable just + // from the server's previous answer. 
+ // - For columns coming from a database table we'll look it up and see if + // the column is nullable or not + // - But this is not enough! If a returned column comes from a left/right + // join it will be nullable even if it is not in the original table. + // To work around this we'll have to inspect the query plan. + use plan <- eval.try(query_plan(query, list.length(parameters))) + let nullables = nullables_from_plan(plan) + use returns <- eval.try(resolve_returns(query, returns, nullables)) + + query + |> query.add_types(parameters, returns) + |> eval.return +} + +fn parameters_and_returns(query: UntypedQuery) -> Db(_) { + // We need to send three messages: + // - `Parse` with the query to parse + // - `Describe` to ask the server to reply with a description of the query's + // return types and parameter types. This is what we need to understand the + // SQL inferred types and generate the corresponding Gleam types. + // - `Sync` to ask the server to immediately reply with the results of parsing + // + use _ <- eval.try( + send_all([ + pg.FeParse("", query.content, []), + pg.FeDescribe(pg.PreparedStatement, ""), + pg.FeSync, + ]), + ) + + // Error builder used in the following steps in case the message sequence + // doesn't go as planned. + let cannot_describe = fn(expected, got) { + PgCannotDescribeQuery( + file: query.file, + query_name: gleam.identifier_to_string(query.name), + expected: expected, + got: got, + ) + } + + use msg <- eval.try(receive()) + case msg { + pg.BeErrorResponse(errors) -> + eval.throw(error_fields_to_parse_error(query, errors)) + pg.BeParseComplete -> { + use msg <- eval.try(receive()) + use parameters <- eval.try(case msg { + pg.BeParameterDescription(parameters) -> eval.return(parameters) + _ -> unexpected_message(cannot_describe, "ParameterDescription", msg) + }) + + use msg <- eval.try(receive()) + use rows <- eval.try(case msg { + pg.BeRowDescriptions(rows) -> eval.return(rows) + _ -> unexpected_message(cannot_describe, "RowDescriptions", msg) + }) + + use msg <- eval.try(receive()) + use _ <- eval.try(case msg { + pg.BeReadyForQuery(_) -> eval.return(Nil) + _ -> unexpected_message(cannot_describe, "ReadyForQuery", msg) + }) + + eval.return(#(parameters, rows)) + } + _ -> + unexpected_message(cannot_describe, "ParseComplete or ErrorResponse", msg) + } +} + +/// Given an untyped query and the error fields we got back from the database in +/// case it couldn't be parsed, produces an appropriate `Error`. +/// +fn error_fields_to_parse_error( + query: UntypedQuery, + errors: Set(pg.ErrorOrNoticeField), +) -> Error { + // We first look for the relevant errors in the set of errors the database + // returned. This way we can attach additional information explaining why the + // query failed. + let #(error_code, message, position, hint) = { + use acc, error_field <- set.fold(errors, from: #(None, None, None, None)) + let #(code, message, position, hint) = acc + case error_field { + pg.Code(code) -> #(Some(code), message, position, hint) + pg.Message(message) -> #(code, Some(message), position, hint) + pg.Hint(hint) -> #(code, message, position, Some(hint)) + pg.Position(position) -> + case int.parse(position) { + Ok(position) -> #(code, message, Some(position), hint) + Error(_) -> acc + } + _ -> acc + } + } + + // If we found both a `Message` and `Position` error messages then we can turn + // those into a pointer that will be shown in the error message. 
+ let pointer = case message, position { + Some(message), Some(position) -> + Some(Pointer(point_to: ByteIndex(position), message: message)) + _, _ -> None + } + + cannot_parse_error(query, error_code, hint, pointer) +} + +fn resolve_parameters( + query: UntypedQuery, + parameters: List(Int), +) -> Db(List(gleam.Type)) { + use oid <- eval_extra.try_map(parameters) + find_gleam_type(query, oid) +} + +/// Looks up a type with the given id in the Postgres registry. +/// +/// > โš ๏ธ This function assumes that the oid is present in the database and +/// > will crash otherwise. This should only be called with oids coming from +/// > a database interrogation. +/// +fn find_gleam_type(query: UntypedQuery, oid: Int) -> Db(gleam.Type) { + // We first look for the Gleam type corresponding to this id in the cache to + // avoid hammering the db with needless queries. + use <- with_cached_gleam_type(oid) + + // The only parameter to this query is the oid of the type to lookup: + // that's a 32bit integer (its oid needed to prepare the query is 23). + let params = [pg.Parameter(<>)] + let run_query = find_postgres_type_query |> run_query(params, [23]) + use res <- eval.try(run_query) + + // We know the output must only contain two values: the name and a boolean to + // check wether it is an array or not. + // It's safe to assert because this query is hard coded in our code and the + // output shape cannot change without us changing that query. + let assert [name, is_array] = res + + // We then decode the bitarrays we got as a result: + // - `name` is just a string + // - `is_array` is a pg boolean + let assert Ok(name) = bit_array.to_string(name) + let type_ = case bit_array_to_bool(is_array) { + True -> PArray(PBase(name)) + False -> PBase(name) + } + + pg_to_gleam_type(type_) + |> result.map_error(unsupported_type_error(query, _)) + |> eval.from_result +} + +/// Returns the query plan for a given query. +/// `parameters` is the number of parameter placeholders in the query. +/// +fn query_plan(query: UntypedQuery, parameters: Int) -> Db(Plan) { + // We ask postgres to give us the query plan. To do that we need to fill in + // all the possible holes in the user supplied query with null values; + // otherwise, the server would complain that it has arguments that are not + // bound. + let query = "explain (format json, verbose) " <> query.content + let params = list.repeat(pg.Null, parameters) + let run_query = run_query(query, params, []) + use res <- eval.try(run_query) + + // We know the output will only contain a single row that is the json string + // containing the query plan. + let assert [plan] = res + let assert Ok([plan, ..]) = json.decode_bits(plan, json_plans_decoder) + eval.return(plan) +} + +/// Given a query plan, returns a set with the indices of the output columns +/// that can contain null values. +/// +fn nullables_from_plan(plan: Plan) -> Set(Int) { + let outputs = case plan.output { + Some(outputs) -> list.index_fold(outputs, dict.new(), dict.insert) + None -> dict.new() + } + + do_nullables_from_plan(plan, outputs, set.new()) +} + +fn do_nullables_from_plan( + plan: Plan, + // A dict from "column name" to its position in the query output. 
+ query_outputs: Dict(String, Int), + nullables: Set(Int), +) -> Set(Int) { + let nullables = case plan.output, plan.join_type, plan.parent_relation { + // - All the outputs of a full join must be marked as nullable + // - All the outputs of an inner half join must be marked as nullable + Some(outputs), Some(Full), _ | Some(outputs), _, Some(Inner) -> { + use nullables, output <- list.fold(outputs, from: nullables) + case dict.get(query_outputs, output) { + Ok(i) -> set.insert(nullables, i) + Error(_) -> nullables + } + } + _, _, _ -> nullables + } + + case plan.plans, plan.join_type { + // If this is an inner half join we keep inspecting the children to mark + // their outputs as nullable. + Some(plans), Some(Left) | Some(plans), Some(Right) -> { + use nullables, plan <- list.fold(plans, from: nullables) + do_nullables_from_plan(plan, query_outputs, nullables) + } + _, _ -> nullables + } +} + +/// Given a list of `RowDescriptionFields` it turns those into Gleam fields with +/// a name and a type. +/// +/// This also uses nullability info coming from the query plan to figure out if +/// a column can be nullable or not: +/// - If the column name ends with `!` it will be forced to be not nullable +/// - If the column name ends with `?` it will be forced to be nullable +/// - If the column appears in the `nullables` set then it will be nullable +/// - Othwerwise we look for its metadata in the database and if it has a +/// not-null constraint it will be not nullable; otherwise it will be nullable +/// +fn resolve_returns( + query: UntypedQuery, + returns: List(pg.RowDescriptionField), + nullables: Set(Int), +) -> Db(List(gleam.Field)) { + use column, i <- eval_extra.try_index_map(returns) + let pg.RowDescriptionField( + data_type_oid: type_oid, + attr_number: column, + table_oid: table, + name: name, + .., + ) = column + + use type_ <- eval.try(find_gleam_type(query, type_oid)) + + let ends_with_exclamation_mark = string.ends_with(name, "!") + let ends_with_question_mark = string.ends_with(name, "?") + use nullability <- eval.try(case ends_with_exclamation_mark { + True -> eval.return(NotNullable) + False -> + case ends_with_question_mark { + True -> eval.return(Nullable) + False -> + case set.contains(nullables, i) { + True -> eval.return(Nullable) + False -> column_nullability(table: table, column: column) + } + } + }) + + let type_ = case nullability { + Nullable -> gleam.Option(type_) + NotNullable -> type_ + } + + let try_convert_name = + // If the name ends with a `?` or `!` we don't want that to be included in + // the gleam name or it would be invalid! + case ends_with_exclamation_mark || ends_with_question_mark { + True -> string.drop_right(name, 1) + False -> name + } + |> gleam.identifier + |> result.map_error(invalid_column_error(query, name, _)) + + use name <- eval.try(eval.from_result(try_convert_name)) + + let field = gleam.Field(label: name, type_: type_) + eval.return(field) +} + +fn column_nullability(table table: Int, column column: Int) -> Db(Nullability) { + // We first check if the table+column is cached to avoid making redundant + // queries to the database. + use <- with_cached_column(table: table, column: column) + + // If the table oid is 0 that means the column doesn't come from any table so + // we just assume it's not nullable. 
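+  // (That is the case for computed columns: something like `select 1 as one`
+  // has no source table, so the server reports its table oid as 0.)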
+  use <- bool.guard(when: table == 0, return: eval.return(NotNullable))
+
+  // This query has 2 parameters:
+  // - the oid of the table (a 32 bit integer, oid is 23)
+  // - the index of the column (a 32 bit integer, oid is 23)
+  let params = [pg.Parameter(<<table:32>>), pg.Parameter(<<column:32>>)]
+  let run_query = find_column_nullability_query |> run_query(params, [23, 23])
+  use res <- eval.try(run_query)
+
+  // We know the output will have only one column, that is the boolean
+  // telling us if the column has a not-null constraint.
+  let assert [has_non_null_constraint] = res
+  case bit_array_to_bool(has_non_null_constraint) {
+    True -> eval.return(NotNullable)
+    False -> eval.return(Nullable)
+  }
+}
+
+// --- DB ACTION HELPERS -------------------------------------------------------
+
+/// Runs a query against the database.
+/// - `parameters` are the parameter values that need to be supplied in place of
+///   the query placeholders
+/// - `parameters_object_ids` are the oids describing the type of each
+///   parameter.
+///
+/// > ⚠️ The `parameters_object_ids` list should have the same length as
+/// > `parameters` and correctly describe each parameter's type. This function
+/// > makes no attempt whatsoever to verify this assumption is correct so be
+/// > careful!
+///
+/// > ⚠️ This function makes the assumption that the query will only return a
+/// > single row. This is totally fine here because we only use this to run
+/// > specific hard coded queries that are guaranteed to return a single row.
+///
+fn run_query(
+  query: String,
+  parameters: List(pg.ParameterValue),
+  parameters_object_ids: List(Int),
+) -> Db(List(BitArray)) {
+  // The message exchange to run a query works as follows:
+  // - `Parse` we ask the server to parse the query, we do not give it a name
+  // - `Bind` we bind the query to the unnamed portal so that it is ready to
+  //   be executed.
+  // - `Execute` we ask the server to run the unnamed portal and return all
+  //   the rows (0 means return all rows).
+  // - `Close` we ask to close the unnamed query and the unnamed portal to free
+  //   their resources.
+  // - `Sync` this acts as a synchronization point that needs to go at the end
+  //   of the sequence before the next one.
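+  //
+  // (This Parse/Bind/Execute/Close/Sync sequence is the standard PostgreSQL
+  // "extended query" protocol flow.)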
+ // + // As you can see in the receiving part, each message we send corresponds to a + // specific answer from the server: + // - `ParseComplete` the query was parsed correctly + // - `BindComplete` the query was bound to a portal + // - `MessageDataRow` the result(s) coming from the query execution + // - `CommandComplete` when the result coming from the query is over + // - `CloseComplete` the portal/statement was closed + // - `ReadyForQuery` final reply to the sync message signalling we can go on + // making new requests + use _ <- eval.try( + send_all([ + pg.FeParse("", query, parameters_object_ids), + pg.FeBind( + portal: "", + statement_name: "", + parameter_format: pg.FormatAll(pg.Binary), + parameters:, + result_format: pg.FormatAll(pg.Binary), + ), + pg.FeExecute("", 0), + pg.FeClose(pg.PreparedStatement, ""), + pg.FeClose(pg.Portal, ""), + pg.FeSync, + ]), + ) + + use msg <- eval.try(receive()) + let assert pg.BeParseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeBindComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeMessageDataRow(res) = msg + use msg <- eval.try(receive()) + let assert pg.BeCommandComplete(_, _) = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeReadyForQuery(_) = msg + eval.return(res) +} + +/// Receive a single message from the database. +/// +fn receive() -> Db(pg.BackendMessage) { + use Context(db: db, ..) as context <- eval.from + case pg.receive(db) { + Ok(#(db, msg)) -> #(Context(..context, db: db), Ok(msg)) + Error(pg.ReadDecodeError(error)) -> #( + context, + Error(PgCannotDecodeReceivedMessage(string.inspect(error))), + ) + Error(pg.SocketError(error)) -> #( + context, + Error(PgCannotReceiveMessage(string.inspect(error))), + ) + } +} + +/// Send a single message to the database. +/// +fn send(message message: pg.FrontendMessage) -> Db(Nil) { + use Context(db: db, ..) as context <- eval.from + + let result = + message + |> pg.encode_frontend_message + |> pg.send(db, _) + + let #(db, result) = case result { + Ok(db) -> #(db, Ok(Nil)) + Error(error) -> #(db, Error(PgCannotSendMessage(string.inspect(error)))) + } + + #(Context(..context, db: db), result) +} + +/// Send many messages, one after the other. +/// +fn send_all(messages messages: List(pg.FrontendMessage)) -> Db(Nil) { + use acc, msg <- eval_extra.try_fold(messages, from: Nil) + use _ <- eval.try(send(msg)) + eval.return(acc) +} + +/// Start receiving and discarding messages until a `ReadyForQuery` message is +/// received. +/// +fn wait_until_ready() -> Db(Nil) { + use _ <- eval.try(send(pg.FeFlush)) + do_wait_until_ready() +} + +fn do_wait_until_ready() -> Db(Nil) { + use msg <- eval.try(receive()) + case msg { + pg.BeReadyForQuery(_) -> eval.return(Nil) + _ -> do_wait_until_ready() + } +} + +/// Throws an error built from an expected message and an unexpected message +/// that is turned into a string. +/// +fn unexpected_message( + builder: fn(String, String) -> Error, + expected expected: String, + got got: pg.BackendMessage, +) { + builder(expected, string.inspect(got)) |> eval.throw +} + +/// Looks up for a type with the given id in a global cache. +/// If the type is present it immediately returns it. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. 
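+/// (For example `find_gleam_type` above calls it as
+/// `use <- with_cached_gleam_type(oid)` so that the actual database lookup
+/// only runs on a cache miss.)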
+/// +fn with_cached_gleam_type( + lookup oid: Int, + otherwise do: fn() -> Db(gleam.Type), +) -> Db(gleam.Type) { + use context: Context <- eval.from + case dict.get(context.gleam_types, oid) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #(Context(gleam_types: gleam_types, ..) as context, Ok(type_)) -> { + let gleam_types = dict.insert(gleam_types, oid, type_) + let new_context = Context(..context, gleam_types: gleam_types) + #(new_context, Ok(type_)) + } + } + } +} + +/// Looks up for the nullability of table's column. +/// If the nullability is cached it is immediately returns it. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. +/// +fn with_cached_column( + table table_oid: Int, + column column: Int, + otherwise do: fn() -> Db(Nullability), +) -> Db(Nullability) { + use context: Context <- eval.from + let key = #(table_oid, column) + case dict.get(context.column_nullability, key) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #( + Context( + column_nullability: column_nullability, + .., + ) as context, + Ok(type_), + ) -> { + let column_nullability = dict.insert(column_nullability, key, type_) + let new_context = + Context(..context, column_nullability: column_nullability) + #(new_context, Ok(type_)) + } + } + } +} + +// --- HELPERS TO BUILD ERRORS ------------------------------------------------- + +fn unsupported_type_error(query: UntypedQuery, type_: String) -> Error { + let UntypedQuery( + content: content, + file: file, + name: name, + starting_line: starting_line, + comment: _, + ) = query + QueryHasUnsupportedType( + file: file, + name: gleam.identifier_to_string(name), + content: content, + type_: type_, + starting_line: starting_line, + ) +} + +fn cannot_parse_error( + query: UntypedQuery, + error_code: Option(String), + hint: Option(String), + pointer: Option(Pointer), +) -> Error { + let UntypedQuery( + content: content, + file: file, + name: name, + starting_line: starting_line, + comment: _, + ) = query + CannotParseQuery( + content: content, + file: file, + name: gleam.identifier_to_string(name), + error_code: error_code, + hint: hint, + pointer: pointer, + starting_line: starting_line, + ) +} + +fn invalid_column_error( + query: UntypedQuery, + column_name: String, + reason: ValueIdentifierError, +) -> Error { + let UntypedQuery( + name: _, + file: file, + content: content, + starting_line: starting_line, + comment: _, + ) = query + QueryHasInvalidColumn( + file: file, + column_name: column_name, + suggested_name: gleam.similar_identifier_string(column_name) + |> option.from_result, + content: content, + reason: reason, + starting_line: starting_line, + ) +} + +// --- DECODERS ---------------------------------------------------------------- + +fn json_plans_decoder(data: Dynamic) -> Result(List(Plan), DecodeErrors) { + d.list(d.field("Plan", plan_decoder))(data) +} + +fn plan_decoder(data: Dynamic) -> Result(Plan, DecodeErrors) { + d.decode4( + Plan, + d.optional_field("Join Type", join_type_decoder), + d.optional_field("Parent Relationship", parent_relation_decoder), + d.optional_field("Output", d.list(d.string)), + d.optional_field("Plans", d.list(plan_decoder)), + )(data) +} + +fn join_type_decoder(data: Dynamic) -> Result(JoinType, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Full" -> Full + "Left" -> Left + "Right" -> Right + _ -> 
Other + } +} + +fn parent_relation_decoder( + data: Dynamic, +) -> Result(ParentRelation, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Inner" -> Inner + _ -> NotInner + } +} + +// --- UTILS ------------------------------------------------------------------- + +/// Turns a bit array into a boolean value. +/// Returns `False` if the bit array is all `0`s or empty, `True` otherwise. +/// +fn bit_array_to_bool(bit_array: BitArray) -> Bool { + case bit_array { + <<0, rest:bits>> -> bit_array_to_bool(rest) + <<>> -> False + _ -> True + } +} diff --git a/src/squirrel/internal/database/postgres_protocol.gleam b/src/squirrel/internal/database/postgres_protocol.gleam new file mode 100644 index 0000000..9556bb1 --- /dev/null +++ b/src/squirrel/internal/database/postgres_protocol.gleam @@ -0,0 +1,1391 @@ +//// Vendored version of https://hex.pm/packages/postgresql_protocol. +//// Ideally this should not be touched unless some really special needs come +//// up. +//// +//// I had to make a little change to the existing library: +//// - do not fail with an unexpected Command if the outer message is ok +//// +//// This library parses and generates packages for the PostgreSQL Binary Protocol +//// It also provides a basic connection abstraction, but this hasn't been used +//// outside of tests. + +import gleam/bit_array +import gleam/bool +import gleam/dict +import gleam/int +import gleam/list +import gleam/result.{try} +import gleam/set +import mug + +pub type Connection { + Connection(socket: mug.Socket, buffer: BitArray, timeout: Int) +} + +pub fn connect(host, port, timeout) { + let assert Ok(socket) = + mug.connect(mug.ConnectionOptions(host: host, port: port, timeout: timeout)) + + Connection(socket: socket, buffer: <<>>, timeout: timeout) +} + +pub type StateInitial { + StateInitial(parameters: dict.Dict(String, String)) +} + +pub type State { + State( + process_id: Int, + secret_key: Int, + parameters: dict.Dict(String, String), + oids: dict.Dict(Int, fn(BitArray) -> Result(Int, Nil)), + ) +} + +fn default_oids() { + dict.new() + |> dict.insert(23, fn(col: BitArray) { + use converted <- try(bit_array.to_string(col)) + int.parse(converted) + }) +} + +pub fn start(conn, params) { + let assert Ok(conn) = + conn + |> send(encode_startup_message(params)) + + let assert Ok(#(conn, state)) = + conn + |> receive_startup_rec(StateInitial(dict.new())) + + #(conn, state) +} + +fn receive_startup_rec(conn: Connection, state: StateInitial) { + case receive(conn) { + Ok(#(conn, BeAuthenticationOk)) -> receive_startup_rec(conn, state) + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_rec( + conn, + StateInitial(parameters: dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeBackendKeyData(secret_key: secret_key, process_id: process_id))) -> + receive_startup_1( + conn, + State( + parameters: state.parameters, + secret_key: secret_key, + process_id: process_id, + oids: default_oids(), + ), + ) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } +} + +fn receive_startup_1(conn: Connection, state: State) { + case receive(conn) { + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_1( + conn, + State(..state, parameters: dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeReadyForQuery(_))) -> Ok(#(conn, state)) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } 
+} + +pub type StartupFailed { + StartupFailedWithUnexpectedMessage(BackendMessage) + StartupFailedWithError(ReadError) +} + +pub type ReadError { + SocketError(mug.Error) + ReadDecodeError(MessageDecodingError) +} + +pub type MessageDecodingError { + MessageDecodingError(String) + MessageIncomplete(BitArray) + UnknownMessage(data: BitArray) +} + +fn dec_err(desc: String, data: BitArray) -> Result(a, MessageDecodingError) { + Error(MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data))) +} + +fn msg_dec_err(desc: String, data: BitArray) -> MessageDecodingError { + MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data)) +} + +pub type Command { + Insert + Delete + Update + Merge + Select + Move + Fetch + Copy +} + +/// Messages originating from the PostgreSQL backend (server) +pub type BackendMessage { + BeBindComplete + BeCloseComplete + BeCommandComplete(Command, Int) + BeCopyData(data: BitArray) + BeCopyDone + BeAuthenticationOk + BeAuthenticationKerberosV5 + BeAuthenticationCleartextPassword + BeAuthenticationMD5Password(salt: BitArray) + BeAuthenticationGSS + BeAuthenticationGSSContinue(auth_data: BitArray) + BeAuthenticationSSPI + BeAuthenticationSASL(mechanisms: List(String)) + BeAuthenticationSASLContinue(data: BitArray) + BeAuthenticationSASLFinal(data: BitArray) + BeReadyForQuery(TransactionStatus) + BeRowDescriptions(List(RowDescriptionField)) + BeMessageDataRow(List(BitArray)) + BeBackendKeyData(process_id: Int, secret_key: Int) + BeParameterStatus(name: String, value: String) + BeCopyResponse( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), + ) + BeNegotiateProtocolVersion( + newest_minor: Int, + unrecognized_options: List(String), + ) + BeNoData + BeNoticeResponse(set.Set(ErrorOrNoticeField)) + BeNotificationResponse(process_id: Int, channel: String, payload: String) + BeParameterDescription(List(Int)) + BeParseComplete + BePortalSuspended + BeErrorResponse(set.Set(ErrorOrNoticeField)) +} + +/// Direction of a BeCopyResponse +pub type CopyDirection { + In + Out + Both +} + +/// Indicates encoding of column values +pub type Format { + Text + Binary +} + +/// Lists of parameters can have different encodings. 
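+/// For instance (see `encode_format_value` below) `FormatAllText` marks every
+/// value as text, `FormatAll(Binary)` applies a single format to all values,
+/// and `Formats([Text, Binary])` lists one format per value.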
+pub type FormatValue { + FormatAllText + FormatAll(Format) + Formats(List(Format)) +} + +fn encode_format_value(format) -> BitArray { + case format { + FormatAllText -> <> + FormatAll(Text) -> <<1:16, encode_format(Text):16>> + FormatAll(Binary) -> <<1:16, encode_format(Binary):16>> + Formats(formats) -> { + let size = list.length(formats) + list.fold(formats, <>, fn(sum, fmt) { + <> + }) + } + } +} + +pub type ParameterValues = + List(ParameterValue) + +pub type ParameterValue { + Parameter(BitArray) + Null +} + +fn parameters_to_bytes(parameters: ParameterValues) { + let mapped = + parameters + |> list.map(fn(parameter) { + case parameter { + Parameter(value) -> <> + Null -> <<-1:32>> + } + }) + + <> +} + +pub type FrontendMessage { + FeBind( + portal: String, + statement_name: String, + parameter_format: FormatValue, + parameters: ParameterValues, + result_format: FormatValue, + ) + FeCancelRequest(process_id: Int, secret_key: Int) + FeClose(what: What, name: String) + FeCopyData(data: BitArray) + FeCopyDone + FeCopyFail(error: String) + FeDescribe(what: What, name: String) + FeExecute(portal: String, return_row_count: Int) + FeFlush + FeFunctionCall( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, + ) + FeGssEncRequest + FeParse(name: String, query: String, parameter_object_ids: List(Int)) + FeQuery(query: String) + FeStartupMessage(params: List(#(String, String))) + FeSslRequest + FeTerminate + FeSync + FeAmbigous(FeAmbigous) +} + +// These all share the same message type +pub type FeAmbigous { + FeGssResponse(data: BitArray) + FeSaslInitialResponse(name: String, data: BitArray) + FeSaslResponse(data: BitArray) + FePasswordMessage(password: String) +} + +pub type What { + PreparedStatement + Portal +} + +fn wire_what(what) { + case what { + Portal -> <<"P":utf8>> + PreparedStatement -> <<"S":utf8>> + } +} + +fn decode_what(binary) { + case binary { + <<"P":utf8, rest:bytes>> -> Ok(#(Portal, rest)) + <<"S":utf8, rest:bytes>> -> Ok(#(PreparedStatement, rest)) + _ -> dec_err("only portal and prepared statement are allowed", binary) + } +} + +fn encode_message_data_row(columns) { + <> + |> list.fold( + columns, + _, + fn(sum, column) { + let len = bit_array.byte_size(column) + <> + }, + ) + |> encode("D", _) +} + +fn encode_error_response(fields: set.Set(ErrorOrNoticeField)) -> BitArray { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("E", _) +} + +fn encode_field(field) { + case field { + Severity(value) -> <<"S":utf8, value:utf8, 0>> + SeverityLocalized(value) -> <<"V":utf8, value:utf8, 0>> + Code(value) -> <<"C":utf8, value:utf8, 0>> + Message(value) -> <<"M":utf8, value:utf8, 0>> + Detail(value) -> <<"D":utf8, value:utf8, 0>> + Hint(value) -> <<"H":utf8, value:utf8, 0>> + Position(value) -> <<"P":utf8, value:utf8, 0>> + InternalPosition(value) -> <<"p":utf8, value:utf8, 0>> + InternalQuery(value) -> <<"q":utf8, value:utf8, 0>> + Where(value) -> <<"W":utf8, value:utf8, 0>> + Schema(value) -> <<"s":utf8, value:utf8, 0>> + Table(value) -> <<"t":utf8, value:utf8, 0>> + Column(value) -> <<"c":utf8, value:utf8, 0>> + DataType(value) -> <<"d":utf8, value:utf8, 0>> + Constraint(value) -> <<"n":utf8, value:utf8, 0>> + File(value) -> <<"F":utf8, value:utf8, 0>> + Line(value) -> <<"L":utf8, value:utf8, 0>> + Routine(value) -> <<"R":utf8, value:utf8, 0>> + Unknown(key, value) -> <> + } +} + +fn encode_authentication_sasl(mechanisms) -> BitArray { + list.fold(mechanisms, <<10:32>>, fn(sum, mechanism) { + <> + }) + |> 
encode("R", _) +} + +fn encode_command_complete(command, rows_num) { + let rows = int.to_string(rows_num) + + let data = + case command { + Insert -> "INSERT 0 " <> rows + Delete -> "DELETE " <> rows + Update -> "UPDATE " <> rows + Merge -> "MERGE " <> rows + Select -> "SELECT " <> rows + Move -> "MOVE " <> rows + Fetch -> "FETCH " <> rows + Copy -> "COPY " <> rows + } + |> encode_string() + + encode("C", data) +} + +fn encode_copy_response( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), +) { + let data = encode_copy_response_rec(overall_format, codes) + case direction { + In -> encode("G", data) + Out -> encode("H", data) + Both -> encode("W", data) + } +} + +// The format codes to be used for each column. Each must presently be zero +// (text) or one (binary). All must be zero if the overall copy format is +// textual. +fn encode_copy_response_rec(overall_format, codes) { + case overall_format { + Text -> + codes + |> list.fold( + <>, + fn(sum, _code) { <> }, + ) + Binary -> + codes + |> list.fold( + <>, + fn(sum, code) { <> }, + ) + } +} + +fn encode_parameter_status(name: String, value: String) -> BitArray { + encode("S", <>) +} + +fn encode_negotiate_protocol_version( + newest_minor: Int, + unrecognized_options: List(String), +) { + list.fold( + unrecognized_options, + <>, + fn(sum, option) { <> }, + ) + |> encode("v", _) +} + +fn encode_notice_response(fields: set.Set(ErrorOrNoticeField)) { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("N", _) +} + +fn encode_notification_response( + process_id: Int, + channel: String, + payload: String, +) { + encode("A", <>) +} + +fn encode_parameter_description(descriptions: List(Int)) { + descriptions + |> list.fold(<>, fn(sum, description) { + <> + }) + |> encode("t", _) +} + +fn encode_row_descriptions(fields: List(RowDescriptionField)) { + fields + |> list.fold(<>, fn(sum, field) { + << + sum:bits, + encode_string(field.name):bits, + field.table_oid:32, + field.attr_number:16, + field.data_type_oid:32, + field.data_type_size:16, + field.type_modifier:32, + field.format_code:16, + >> + }) + |> encode("T", _) +} + +pub fn encode_backend_message(message: BackendMessage) -> BitArray { + case message { + BeMessageDataRow(columns) -> encode_message_data_row(columns) + BeErrorResponse(fields) -> encode_error_response(fields) + BeAuthenticationOk -> encode("R", <<0:32>>) + BeAuthenticationKerberosV5 -> encode("R", <<2:32>>) + BeAuthenticationCleartextPassword -> encode("R", <<3:32>>) + BeAuthenticationMD5Password(salt: salt) -> encode("R", <<5:32, salt:bits>>) + BeAuthenticationGSS -> encode("R", <<7:32>>) + BeAuthenticationGSSContinue(data) -> encode("R", <<8:32, data:bits>>) + BeAuthenticationSSPI -> encode("R", <<9:32>>) + BeAuthenticationSASL(a) -> encode_authentication_sasl(a) + BeAuthenticationSASLContinue(data) -> encode("R", <<11:32, data:bits>>) + BeAuthenticationSASLFinal(data: data) -> encode("R", <<12:32, data:bits>>) + BeBackendKeyData(pid, sk) -> encode("K", <>) + BeBindComplete -> encode("2", <<>>) + BeCloseComplete -> encode("3", <<>>) + BeCommandComplete(a, b) -> encode_command_complete(a, b) + BeCopyData(data) -> encode("d", data) + BeCopyDone -> encode("c", <<>>) + BeCopyResponse(a, b, c) -> encode_copy_response(a, b, c) + BeParameterStatus(name, value) -> encode_parameter_status(name, value) + BeNegotiateProtocolVersion(a, b) -> encode_negotiate_protocol_version(a, b) + BeNoData -> encode("n", <<>>) + BeNoticeResponse(data) -> encode_notice_response(data) + BeNotificationResponse(a, 
b, c) -> encode_notification_response(a, b, c) + BeParameterDescription(a) -> encode_parameter_description(a) + BeParseComplete -> encode("1", <<>>) + BePortalSuspended -> encode("s", <<>>) + BeRowDescriptions(a) -> encode_row_descriptions(a) + BeReadyForQuery(TransactionStatusIdle) -> encode("Z", <<"I":utf8>>) + BeReadyForQuery(TransactionStatusInTransaction) -> encode("Z", <<"T":utf8>>) + BeReadyForQuery(TransactionStatusFailed) -> encode("Z", <<"E":utf8>>) + } +} + +pub fn encode_frontend_message(message: FrontendMessage) { + case message { + FeBind(a, b, c, d, e) -> encode_bind(a, b, c, d, e) + FeCancelRequest(process_id: process_id, secret_key: secret_key) -> << + 16:32, + 1234:16, + 5678:16, + process_id:32, + secret_key:32, + >> + FeClose(what, name) -> + encode("C", <>) + FeCopyData(data) -> encode("d", data) + FeCopyDone -> <<"c":utf8, 4:32>> + FeCopyFail(error) -> encode("f", encode_string(error)) + FeDescribe(what, name) -> + encode("D", <>) + FeExecute(portal, count) -> + encode("E", <>) + FeFlush -> <<"H":utf8, 4:32>> + FeFunctionCall(a, b, c, d) -> encode_function_call(a, b, c, d) + FeGssEncRequest -> <<8:32, 1234:16, 5680:16>> + FeAmbigous(FeGssResponse(data)) -> encode("p", data) + FeParse(a, b, c) -> encode_parse(a, b, c) + FeAmbigous(FePasswordMessage(password)) -> + encode("p", encode_string(password)) + FeQuery(query) -> encode("Q", encode_string(query)) + FeAmbigous(FeSaslInitialResponse(a, b)) -> + encode_sasl_initial_response(a, b) + FeAmbigous(FeSaslResponse(data)) -> encode("p", data) + FeStartupMessage(params) -> encode_startup_message(params) + FeSslRequest -> <<8:32, 1234:16, 5679:16>> + FeSync -> <<"S":utf8, 4:32>> + FeTerminate -> <<"X":utf8, 4:32>> + } +} + +fn encode(type_char, data) { + case data { + <<>> -> <> + _ -> { + let len = bit_array.byte_size(data) + 4 + <> + } + } +} + +fn encode_string(str) { + <> +} + +pub const protocol_version_major = <<3:16>> + +pub const protocol_version_minor = <<0:16>> + +pub const protocol_version = << + protocol_version_major:bits, + protocol_version_minor:bits, +>> + +fn encode_startup_message(params) { + let packet = + params + |> list.fold(<>, fn(builder, element) { + let #(key, value) = element + <> + }) + + let size = bit_array.byte_size(packet) + 5 + + <> +} + +fn encode_sasl_initial_response(name, data) { + let len = bit_array.byte_size(data) + encode("p", <>) +} + +fn encode_parse(name, query, parameter_object_ids) { + let oids = + list.fold(parameter_object_ids, <<>>, fn(sum, oid) { <> }) + let len = list.length(parameter_object_ids) + encode("P", << + encode_string(name):bits, + encode_string(query):bits, + len:16, + oids:bits, + >>) +} + +fn encode_bind( + portal, + statement_name, + parameter_format, + parameters, + result_format, +) { + encode("B", << + portal:utf8, + 0, + statement_name:utf8, + 0, + encode_format_value(parameter_format):bits, + parameters_to_bytes(parameters):bits, + encode_format_value(result_format):bits, + >>) +} + +fn encode_function_call( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, +) { + encode("F", << + object_id:32, + encode_format_value(argument_format):bits, + parameters_to_bytes(arguments):bits, + encode_format(result_format):16, + >>) +} + +/// Send a message to the database +pub fn send_builder(conn: Connection, message) { + case mug.send_builder(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Send a message to the database +pub fn send(conn: Connection, message) { + case 
mug.send(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Receive a single message from the backend +pub fn receive( + conn: Connection, +) -> Result(#(Connection, BackendMessage), ReadError) { + case decode_backend_packet(conn.buffer) { + Ok(#(message, rest)) -> Ok(#(with_buffer(conn, rest), message)) + Error(MessageIncomplete(_)) -> { + case mug.receive(conn.socket, conn.timeout) { + Ok(packet) -> + receive(with_buffer(conn, <>)) + Error(err) -> Error(SocketError(err)) + } + } + Error(err) -> Error(ReadDecodeError(err)) + } +} + +fn with_buffer(conn: Connection, buffer: BitArray) { + Connection(..conn, buffer: buffer) +} + +// decode a single message from the packet +pub fn decode_backend_packet( + packet: BitArray, +) -> Result(#(BackendMessage, BitArray), MessageDecodingError) { + case packet { + <> -> { + let len = length - 4 + case tail { + <> -> + decode_backend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + _ -> + case bit_array.byte_size(packet) < 5 { + True -> Error(MessageIncomplete(packet)) + False -> dec_err("packet size too small", packet) + } + } +} + +// decode a backend message +pub fn decode_backend_message(binary) { + case binary { + <<"D":utf8, count:16, data:bytes>> -> decode_message_data_row(count, data) + <<"E":utf8, data:bytes>> -> decode_error_response(data) + <<"R":utf8, 0:32>> -> Ok(BeAuthenticationOk) + <<"R":utf8, 2:32>> -> Ok(BeAuthenticationKerberosV5) + <<"R":utf8, 3:32>> -> Ok(BeAuthenticationCleartextPassword) + <<"R":utf8, 5:32, salt:bytes>> -> + Ok(BeAuthenticationMD5Password(salt: salt)) + <<"R":utf8, 7:32>> -> Ok(BeAuthenticationGSS) + <<"R":utf8, 8:32, auth_data:bytes>> -> + Ok(BeAuthenticationGSSContinue(auth_data: auth_data)) + <<"R":utf8, 9:32>> -> Ok(BeAuthenticationSSPI) + <<"R":utf8, 10:32, data:bytes>> -> decode_authentication_sasl(data) + <<"R":utf8, 11:32, data:bytes>> -> + Ok(BeAuthenticationSASLContinue(data: data)) + <<"R":utf8, 12:32, data:bytes>> -> Ok(BeAuthenticationSASLFinal(data: data)) + <<"K":utf8, pid:32, sk:32>> -> + Ok(BeBackendKeyData(process_id: pid, secret_key: sk)) + <<"2":utf8>> -> Ok(BeBindComplete) + <<"3":utf8>> -> Ok(BeCloseComplete) + <<"C":utf8, data:bytes>> -> decode_command_complete(data) + <<"d":utf8, data:bytes>> -> Ok(BeCopyData(data)) + <<"c":utf8>> -> Ok(BeCopyDone) + <<"G":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(In, format, count, data) + <<"H":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Out, format, count, data) + <<"W":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Both, format, count, data) + <<"S":utf8, data:bytes>> -> decode_parameter_status(data) + <<"v":utf8, version:32, count:32, data:bytes>> -> + decode_negotiate_protocol_version(version, count, data) + <<"n":utf8>> -> Ok(BeNoData) + <<"N":utf8, data:bytes>> -> decode_notice_response(data) + <<"A":utf8, process_id:32, data:bytes>> -> + decode_notification_response(process_id, data) + <<"t":utf8, count:16, data:bytes>> -> + decode_parameter_description(count, data, []) + <<"1":utf8>> -> Ok(BeParseComplete) + <<"s":utf8>> -> Ok(BePortalSuspended) + <<"T":utf8, count:16, data:bytes>> -> decode_row_descriptions(count, data) + <<"Z":utf8, "I":utf8>> -> Ok(BeReadyForQuery(TransactionStatusIdle)) + <<"Z":utf8, "T":utf8>> -> + Ok(BeReadyForQuery(TransactionStatusInTransaction)) + <<"Z":utf8, "E":utf8>> -> Ok(BeReadyForQuery(TransactionStatusFailed)) + _ -> Error(UnknownMessage(binary)) + } +} + +/// 
decode a single message from the packet buffer. +/// note that FeCancelRequest, FeSslRequest, FeGssEncRequest, and +/// FeStartupMessage messages can only be decoded here because they don't follow +/// the standard message format. +pub fn decode_frontend_packet( + packet: BitArray, +) -> Result(#(FrontendMessage, BitArray), MessageDecodingError) { + case packet { + <<16:32, 1234:16, 5678:16, process_id:32, secret_key:32, next:bytes>> -> + Ok(#( + FeCancelRequest(process_id: process_id, secret_key: secret_key), + next, + )) + <<8:32, 1234:16, 5679:16, next:bytes>> -> Ok(#(FeSslRequest, next)) + <<8:32, 1234:16, 5680:16, next:bytes>> -> Ok(#(FeGssEncRequest, next)) + // not sure if there's a way to use the `protocol_version` constant here + <> -> + decode_startup_message(next, length - 8, []) + <> -> { + let len = length - 4 + case tail { + <> -> + decode_frontend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + <<_:48>> -> Error(MessageIncomplete(packet)) + _ -> dec_err("invalid message", packet) + } +} + +/// decode a frontend message (also see decode_frontend_packet for messages that +/// can't be decoded here) +pub fn decode_frontend_message( + binary: BitArray, +) -> Result(FrontendMessage, MessageDecodingError) { + case binary { + <<"B":utf8, data:bytes>> -> decode_bind(data) + <<"C":utf8, data:bytes>> -> decode_close(data) + <<"d":utf8, data:bytes>> -> Ok(FeCopyData(data)) + <<"c":utf8>> -> Ok(FeCopyDone) + <<"f":utf8, data:bytes>> -> decode_copy_fail(data) + <<"D":utf8, data:bytes>> -> decode_describe(data) + <<"E":utf8, data:bytes>> -> decode_execute(data) + <<"H":utf8>> -> Ok(FeFlush) + <<"F":utf8, data:bytes>> -> decode_function_call(data) + <<"p":utf8, data:bytes>> -> Ok(FeAmbigous(FeGssResponse(data))) + <<"P":utf8, data:bytes>> -> decode_parse(data) + <<"Q":utf8, data:bytes>> -> decode_query(data) + <<"S":utf8>> -> Ok(FeSync) + <<"X":utf8>> -> Ok(FeTerminate) + _ -> Error(UnknownMessage(data: binary)) + } +} + +fn decode_startup_message(binary, size, result) { + case binary { + <> -> + decode_startup_message_pairs(data, []) + |> result.map(fn(r) { #(r, next) }) + _ -> dec_err("invalid startup message", binary) + } +} + +fn decode_startup_message_pairs(binary, result) { + case binary { + <<0>> -> Ok(FeStartupMessage(params: list.reverse(result))) + _ -> { + use #(key, binary) <- try(decode_string(binary)) + use #(value, binary) <- try(decode_string(binary)) + decode_startup_message_pairs(binary, [#(key, value), ..result]) + } + } +} + +fn decode_query(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(query, rest) <- try(decode_string(binary)) + case rest { + <<>> -> Ok(FeQuery(query)) + _ -> dec_err("Query message too long", binary) + } +} + +// FeQuery(query) -> encode("Q", encode_string(query)) + +fn decode_parse(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(name, binary) <- try(decode_string(binary)) + use #(query, binary) <- try(decode_string(binary)) + use parameter_object_ids <- try(decode_parameter_object_ids(binary)) + Ok(FeParse( + name: name, + query: query, + parameter_object_ids: parameter_object_ids, + )) +} + +fn decode_parameter_object_ids(binary) { + case binary { + <> -> decode_parameter_object_ids_rec(rest, count, []) + _ -> dec_err("expected object id count", binary) + } +} + +fn decode_parameter_object_ids_rec(binary, count, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <> -> + decode_parameter_object_ids_rec(rest, count - 1, [id, 
..result]) + _, _ -> dec_err("expected parameter object id", binary) + } +} + +fn decode_function_call(binary) -> Result(FrontendMessage, MessageDecodingError) { + case binary { + <> -> { + use #(argument_format, rest) <- try(read_parameter_format(rest)) + use #(arguments, rest) <- try(read_parameters(rest, argument_format)) + use #(result_format, rest) <- try(read_format(rest)) + case rest { + <<>> -> + Ok(FeFunctionCall( + argument_format: argument_format, + arguments: arguments, + object_id: object_id, + result_format: result_format, + )) + _ -> dec_err("invalid function call, data remains", rest) + } + } + _ -> dec_err("invalid function call, no object id found", binary) + } +} + +fn decode_execute(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + case binary { + <> -> Ok(FeExecute(portal, count)) + _ -> dec_err("no execute return_row_count found", binary) + } +} + +fn decode_describe(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeDescribe(what, name)) + _ -> dec_err("Describe message too long", binary) + } +} + +fn decode_copy_fail(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(error, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeCopyFail(error)) + _ -> dec_err("CopyFail message too long", binary) + } +} + +fn decode_close(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeClose(what, name)) + _ -> dec_err("Close message too long", binary) + } +} + +// FeClose(what, name) -> +// encode("C", <>) + +fn decode_bind(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + use #(statement_name, binary) <- try(decode_string(binary)) + use #(parameter_format, binary) <- try(read_parameter_format(binary)) + use #(parameters, binary) <- try(read_parameters(binary, parameter_format)) + use #(result_format, binary) <- try(read_parameter_format(binary)) + case binary { + <<>> -> + Ok(FeBind( + portal: portal, + statement_name: statement_name, + parameter_format: parameter_format, + parameters: parameters, + result_format: result_format, + )) + _ -> dec_err("Bind message too long", binary) + } +} + +fn decode_string( + binary: BitArray, +) -> Result(#(String, BitArray), MessageDecodingError) { + case binary_split(binary, <<0>>, []) { + [head, tail] -> + case bit_array.to_string(head) { + Ok(str) -> Ok(#(str, tail)) + Error(Nil) -> dec_err("invalid string encoding", head) + } + _ -> dec_err("invalid string", binary) + } +} + +fn read_parameter_format( + binary: BitArray, +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(FormatAllText, rest)) + <<1:16, 0:16, rest:bytes>> -> Ok(#(FormatAll(Text), rest)) + <<1:16, 1:16, rest:bytes>> -> Ok(#(FormatAll(Binary), rest)) + <> -> read_wire_formats(n, rest, []) + _ -> dec_err("invalid parameter format", binary) + } +} + +fn read_parameters(binary: BitArray, parameter_format: FormatValue) { + case binary { + <> -> { + case parameter_format { + FormatAllText -> list.repeat(Text, count) + FormatAll(format) -> list.repeat(format, count) + Formats(formats) -> formats + } + |> read_parameters_rec(count, rest, []) + } + _ -> dec_err("parameters without count", 
binary) + } +} + +fn read_parameters_rec( + formats: List(Format), + count: Int, + binary: BitArray, + result: List(ParameterValue), +) -> Result(#(List(ParameterValue), BitArray), MessageDecodingError) { + let actual = list.length(formats) + use <- bool.guard( + actual != count, + dec_err( + "expected " + <> int.to_string(count) + <> " parameters, but got " + <> int.to_string(actual), + binary, + ), + ) + + case count, formats, binary { + 0, _, rest -> Ok(#(list.reverse(result), rest)) + _, [_, ..rest_formats], <<-1:32-signed, rest:bytes>> -> { + read_parameters_rec(rest_formats, count - 1, rest, [Null, ..result]) + } + _, [format, ..rest_formats], <> -> + read_parameters_rec(rest_formats, count - 1, rest, [ + read_parameter(format, value), + ..result + ]) + _, _, rest -> dec_err("invalid parameter value", rest) + } +} + +fn read_parameter(format: Format, value: BitArray) { + case format { + Text -> Parameter(value) + Binary -> Parameter(value) + } +} + +fn read_wire_formats( + count: Int, + binary: BitArray, + result: List(Format), +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(#(Formats(list.reverse(result)), binary)) + _, <<0:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Text, ..result]) + _, <<1:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Binary, ..result]) + _, _ -> dec_err("unknown format", binary) + } +} + +fn decode_command_complete( + binary: BitArray, +) -> Result(BackendMessage, MessageDecodingError) { + { + use fine <- try(case binary { + <<"INSERT 0 ":utf8, rows:bytes>> -> Ok(#(Insert, rows)) + <<"DELETE ":utf8, rows:bytes>> -> Ok(#(Delete, rows)) + <<"UPDATE ":utf8, rows:bytes>> -> Ok(#(Update, rows)) + <<"MERGE ":utf8, rows:bytes>> -> Ok(#(Merge, rows)) + <<"SELECT ":utf8, rows:bytes>> -> Ok(#(Select, rows)) + <<"MOVE ":utf8, rows:bytes>> -> Ok(#(Move, rows)) + <<"FETCH ":utf8, rows:bytes>> -> Ok(#(Fetch, rows)) + <<"COPY ":utf8, rows:bytes>> -> Ok(#(Copy, rows)) + _ -> dec_err("invalid command", binary) + }) + + let #(command, rows_raw) = fine + let len = bit_array.byte_size(rows_raw) - 1 + + use rows_bits <- try(case rows_raw { + <> -> Ok(rows_bits) + _ -> dec_err("invalid command row count", binary) + }) + + use rows_string <- try( + bit_array.to_string(rows_bits) + |> result.replace_error(msg_dec_err( + "failed to convert row count to string", + rows_bits, + )), + ) + + use rows <- try( + int.parse(rows_string) + |> result.replace_error(msg_dec_err( + "failed to convert row count to int", + rows_bits, + )), + ) + + Ok(BeCommandComplete(command, rows)) + } + |> result.or(Ok(BeCommandComplete(Insert, -1))) +} + +pub type TransactionStatus { + TransactionStatusIdle + TransactionStatusInTransaction + TransactionStatusFailed +} + +fn decode_parameter_description(count, binary, results) { + case count, binary { + 0, <<>> -> Ok(BeParameterDescription(list.reverse(results))) + _, <> -> + decode_parameter_description(count - 1, tail, [value, ..results]) + _, _ -> dec_err("invalid parameter description", binary) + } +} + +fn decode_notification_response( + process_id, + binary, +) -> Result(BackendMessage, MessageDecodingError) { + use strings <- try(read_strings(binary, 2, [])) + case strings { + [channel, payload] -> + Ok(BeNotificationResponse( + process_id: process_id, + channel: channel, + payload: payload, + )) + _ -> dec_err("invalid notification response encoding", binary) + } +} + +fn decode_negotiate_protocol_version(version, count, binary) { + use options <- try(read_strings(binary, count, 
[])) + Ok(BeNegotiateProtocolVersion(version, options)) +} + +fn decode_row_descriptions(count, binary) { + use fields <- try(read_row_descriptions(count, binary, [])) + Ok(BeRowDescriptions(fields)) +} + +fn decode_parameter_status(binary) { + use strings <- try(decode_strings(binary)) + case strings { + [name, value] -> Ok(BeParameterStatus(name: name, value: value)) + _ -> dec_err("invalid parameter status", binary) + } +} + +fn decode_authentication_sasl(binary) { + use strings <- try(decode_strings(binary)) + Ok(BeAuthenticationSASL(strings)) +} + +fn decode_copy_response(direction, format_raw, count, rest) { + use overall_format <- try(decode_format(format_raw)) + use <- bool.guard( + bit_array.byte_size(rest) != count * 2, + dec_err("size must be count * 2", rest), + ) + + use codes <- try(decode_format_codes(rest, [])) + + case overall_format == Text { + False -> Ok(BeCopyResponse(direction, overall_format, codes)) + True -> + case list.all(codes, fn(code) { code == Text }) { + True -> Ok(BeCopyResponse(direction, overall_format, codes)) + False -> dec_err("invalid copy response format", rest) + } + } +} + +fn decode_format_codes( + binary: BitArray, + result: List(Format), +) -> Result(List(Format), MessageDecodingError) { + case binary { + <> -> + case decode_format(code) { + Ok(format) -> decode_format_codes(tail, [format, ..result]) + Error(err) -> Error(err) + } + <<>> -> Ok(list.reverse(result)) + _ -> dec_err("invalid format codes", binary) + } +} + +fn decode_format(num: Int) { + case num { + 0 -> Ok(Text) + 1 -> Ok(Binary) + _ -> dec_err("invalid format code: " <> int.to_string(num), <<>>) + } +} + +fn read_format(binary) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(Text, rest)) + <<1:16, rest:bytes>> -> Ok(#(Binary, rest)) + _ -> dec_err("invalid format code", binary) + } +} + +fn encode_format(format_raw) -> Int { + case format_raw { + Text -> 0 + Binary -> 1 + } +} + +pub type DataRow { + DataRow(List(BitArray)) +} + +fn decode_message_data_row( + count, + rest, +) -> Result(BackendMessage, MessageDecodingError) { + case decode_message_data_row_rec(rest, count, []) { + Ok(cols) -> + case list.length(cols) == count { + True -> Ok(BeMessageDataRow(cols)) + False -> dec_err("column count doesn't match", rest) + } + Error(err) -> Error(err) + } +} + +fn decode_message_data_row_rec( + binary, + count, + result, +) -> Result(List(BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(list.reverse(result)) + _, <> -> + decode_message_data_row_rec(rest, count - 1, [value, ..result]) + _, _ -> + dec_err( + "failed to parse data row at count " <> int.to_string(count), + binary, + ) + } +} + +pub type RowDescriptionField { + RowDescriptionField( + // The field name. + name: String, + // If the field can be identified as a column of a specific table, the + // object ID of the table; otherwise zero. + table_oid: Int, + // If the field can be identified as a column of a specific table, the + // attribute number of the column; otherwise zero. + attr_number: Int, + // The object ID of the field's data type. + data_type_oid: Int, + // The data type size (see pg_type.typlen). Note that negative values denote + // variable-width types. + data_type_size: Int, + // The type modifier (see pg_attribute.atttypmod). The meaning of the + // modifier is type-specific. + type_modifier: Int, + // The format code being used for the field. Currently will be zero (text) + // or one (binary). 
In a RowDescription returned from the statement variant + // of Describe, the format code is not yet known and will always be zero. + format_code: Int, + ) +} + +fn read_row_descriptions(count, binary, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <<>> -> dec_err("row description count mismatch", binary) + _, _ -> + case read_row_description_field(binary) { + Ok(#(field, tail)) -> + read_row_descriptions(count - 1, tail, [field, ..result]) + Error(err) -> Error(err) + } + } +} + +fn read_row_description_field(binary) { + case read_string(binary) { + Ok(#( + name, + << + table_oid:32, + attr_number:16, + data_type_oid:32, + data_type_size:16, + type_modifier:32, + format_code:16, + tail:bytes, + >>, + )) -> + Ok(#( + RowDescriptionField( + name: name, + table_oid: table_oid, + attr_number: attr_number, + data_type_oid: data_type_oid, + data_type_size: data_type_size, + type_modifier: type_modifier, + format_code: format_code, + ), + tail, + )) + Ok(#(_, tail)) -> dec_err("failed to parse row description field", tail) + Error(_) -> dec_err("failed to decode row description field name", binary) + } +} + +fn decode_strings(binary) { + let length = bit_array.byte_size(binary) - 1 + case binary { + <<>> -> Ok([]) + <> -> { + binary_split(head, <<0>>, [Global]) + |> list.map(bit_array.to_string) + |> result.all() + |> result.replace_error(msg_dec_err("invalid strings encoding", binary)) + } + _ -> dec_err("string size didn't match", binary) + } +} + +fn read_strings(binary, count, result) { + case count { + 0 -> Ok(list.reverse(result)) + _ -> { + case read_string(binary) { + Ok(#(value, rest)) -> read_strings(rest, count - 1, [value, ..result]) + Error(err) -> Error(err) + } + } + } +} + +fn read_string(binary) { + case binary_split(binary, <<0>>, []) { + [<<>>, <<>>] -> Ok(#("", <<>>)) + [head, tail] -> { + bit_array.to_string(head) + |> result.replace_error(msg_dec_err("invalid string encoding", head)) + |> result.map(fn(s) { #(s, tail) }) + } + _ -> dec_err("invalid string", binary) + } +} + +fn decode_notice_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeNoticeResponse(fields)) +} + +fn decode_error_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeErrorResponse(fields)) +} + +pub type ErrorOrNoticeField { + Code(String) + Detail(String) + File(String) + Hint(String) + Line(String) + Message(String) + Position(String) + Routine(String) + SeverityLocalized(String) + Severity(String) + Where(String) + Column(String) + DataType(String) + Constraint(String) + InternalPosition(String) + InternalQuery(String) + Schema(String) + Table(String) + Unknown(key: BitArray, value: String) +} + +fn decode_fields(binary) { + case decode_fields_rec(binary, []) { + Ok(fields) -> + fields + |> list.map(fn(key_value_raw) { + let #(key, value) = key_value_raw + case key { + <<"S":utf8>> -> Severity(value) + <<"V":utf8>> -> SeverityLocalized(value) + <<"C":utf8>> -> Code(value) + <<"M":utf8>> -> Message(value) + <<"D":utf8>> -> Detail(value) + <<"H":utf8>> -> Hint(value) + <<"P":utf8>> -> Position(value) + <<"p":utf8>> -> InternalPosition(value) + <<"q":utf8>> -> InternalQuery(value) + <<"W":utf8>> -> Where(value) + <<"s":utf8>> -> Schema(value) + <<"t":utf8>> -> Table(value) + <<"c":utf8>> -> Column(value) + <<"d":utf8>> -> DataType(value) + <<"n":utf8>> -> Constraint(value) + <<"F":utf8>> -> File(value) + <<"L":utf8>> -> Line(value) + <<"R":utf8>> -> Routine(value) + _ -> Unknown(key, value) + } + }) + |> set.from_list() + |> Ok + Error(err) 
-> Error(err) + } +} + +fn decode_fields_rec(binary, result) { + case binary { + <<0>> | <<>> -> Ok(result) + <> -> { + case binary_split(rest, <<0>>, []) { + [head, tail] -> { + case bit_array.to_string(head) { + Ok(value) -> + decode_fields_rec(tail, [#(field_type, value), ..result]) + Error(Nil) -> dec_err("invalid field encoding", binary) + } + } + _ -> dec_err("invalid field separator", binary) + } + } + _ -> dec_err("invalid field", binary) + } +} + +type BinarySplitOption { + Global +} + +@external(erlang, "binary", "split") +fn binary_split( + subject: BitArray, + pattern: BitArray, + options: List(BinarySplitOption), +) -> List(BitArray) diff --git a/src/squirrel/internal/error.gleam b/src/squirrel/internal/error.gleam new file mode 100644 index 0000000..595013e --- /dev/null +++ b/src/squirrel/internal/error.gleam @@ -0,0 +1,638 @@ +import glam/doc.{type Document} +import gleam/int +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/regex +import gleam/result +import gleam/string +import gleam_community/ansi +import simplifile + +pub type Error { + // --- POSTGRES RELATED ERRORS ----------------------------------------------- + /// When authentication workflow goes wrong. + /// TODO)) For now I only support no authentication so people might report + /// this issue. + /// + PgCannotAuthenticate(expected: String, got: String) + + /// When there's an error with the underlying socket and I cannot send + /// messages to the server. + /// + PgCannotSendMessage(reason: String) + + /// This comes from the `postgres_protocol` module, it happens if the server + /// sends back a malformed message or there's a type of messages decoding is + /// not implemented for. + /// This should never happen and warrants a bug report! + /// + PgCannotDecodeReceivedMessage(reason: String) + + /// When there's an error with the underlying socket and I cannot receive + /// messages to the server. + /// + PgCannotReceiveMessage(reason: String) + + /// When I cannot get a query description back from the postgres server. + /// + PgCannotDescribeQuery( + file: String, + query_name: String, + expected: String, + got: String, + ) + + // --- OTHER GENERIC ERRORS -------------------------------------------------- + /// When I cannot read a file containing queries. + /// + CannotReadFile(file: String, reason: simplifile.FileError) + + /// When the generated code cannot be written to a file. + /// + CannotWriteToFile(file: String, reason: simplifile.FileError) + + /// If an ".sql" file holding a query has a name that is not a valid Gleam + /// name. + /// Instead of trying to magically come up with a name we fail and report the + /// error. + /// + QueryFileHasInvalidName( + file: String, + suggested_name: Option(String), + reason: ValueIdentifierError, + ) + + /// If a query returns a column that is not a valid Gleam identifier. Instead + /// of trying to magically come up with a name we fail and report the error. + /// + QueryHasInvalidColumn( + file: String, + column_name: String, + suggested_name: Option(String), + content: String, + starting_line: Int, + reason: ValueIdentifierError, + ) + + /// When there's a param/return type that cannot be converted into a Gleam + /// type. + /// + QueryHasUnsupportedType( + file: String, + name: String, + content: String, + starting_line: Int, + type_: String, + ) + + /// If the query contains an error and cannot be parsed by the DBMS. 
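+  /// The optional `error_code` is the SQLSTATE code reported by the server
+  /// (for example `42601` for a syntax error), `pointer` marks the position
+  /// the error refers to, and `hint` is the server-provided hint, if any.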
+ /// + CannotParseQuery( + file: String, + name: String, + content: String, + starting_line: Int, + error_code: Option(String), + pointer: Option(Pointer), + hint: Option(String), + ) +} + +pub type ValueIdentifierError { + DoesntStartWithLowercaseLetter + ContainsInvalidGrapheme(at: Int, grapheme: String) + IsEmpty +} + +/// Used to literally point to a particular piece of a string and attach a +/// message to that point. +/// +pub type Pointer { + Pointer(point_to: PointerKind, message: String) +} + +/// A pointer could either point to a specific byte of a String or it could +/// point at a specific word (in that case it will point to the first occurrence +/// of such word). +/// +pub type PointerKind { + Name(name: String) + ByteIndex(position: Int) +} + +pub fn to_doc(error: Error) -> Document { + // Errors as they are, are not that easy to print. What we do here is turn + // each error into an easier-to-print data structure: a `PrintableError` + // using a nice declarative API. So we can ignore all the gory details of how + // that is actually printed and we do not have to make any effort to add and + // print new errors. + let printable_error = case error { + PgCannotSendMessage(reason: reason) -> + printable_error("Cannot send message") + |> add_paragraph( + "I ran into an unexpected error while trying to talk to the Postgres +database server.", + ) + |> report_bug(reason) + + PgCannotDecodeReceivedMessage(reason: reason) -> + printable_error("Cannot decode message") + |> add_paragraph( + "I ran into an unexpected error while trying to decode a message +received from the Postgres database server.", + ) + |> report_bug(reason) + + PgCannotReceiveMessage(reason: reason) -> + printable_error("Cannot receive message") + |> add_paragraph( + "I ran into an unexpected error while trying to listen to the Postgres +database server.", + ) + |> report_bug(reason) + + CannotReadFile(file: file, reason: reason) -> + printable_error("Cannot read file") + |> add_paragraph( + "I couldn't read " + <> style_file(file) + <> " because of the following error: " + <> simplifile.describe_error(reason), + ) + + CannotWriteToFile(file: file, reason: reason) -> + printable_error("Cannot write to file") + |> add_paragraph( + "I couldn't write to " + <> style_file(file) + <> " because of the following error: " + <> simplifile.describe_error(reason), + ) + + QueryFileHasInvalidName( + file: file, + suggested_name: suggested_name, + reason: _, + ) -> + printable_error("Query file with invalid name") + |> add_paragraph( + "File " <> style_file(file) <> " doesn't have a valid name. +The name of a file is used to generate a corresponding Gleam function, so it +should be a valid Gleam name.", + ) + |> hint("A file name must start with a lowercase letter and can only +contain lowercase letters, numbers and underscores." <> case suggested_name { + Some(name) -> + "\nMaybe try renaming it to " <> style_inline_code(name) <> "?" 
+ None -> "" + }) + + QueryHasInvalidColumn( + file: file, + column_name: column_name, + suggested_name: suggested_name, + content: content, + reason: reason, + starting_line: starting_line, + ) -> + case reason { + IsEmpty -> + printable_error("Column with empty name") + |> add_code_paragraph( + file: file, + content: content, + point: None, + starting_line: starting_line, + ) + |> add_paragraph( + "A column returned by this query has the empty string as a name, +all columns should have a valid Gleam name as name.", + ) + + _ -> + printable_error("Column with invalid name") + |> add_code_paragraph( + file: file, + content: content, + starting_line: starting_line, + point: Some( + Pointer(point_to: Name(column_name), message: case + suggested_name + { + None -> "This is not a valid Gleam name" + Some(suggestion) -> + "This is not a valid Gleam name, maybe try " + <> style_inline_code(suggestion) + <> "?" + }), + ), + ) + |> hint( + "A column name must start with a lowercase letter and can only +contain lowercase letters, numbers and underscores.", + ) + } + + QueryHasUnsupportedType( + file: file, + name: _, + content: content, + type_: type_, + starting_line: starting_line, + ) -> + printable_error("Unsupported type") + |> add_code_paragraph( + file: file, + content: content, + point: None, + starting_line: starting_line, + ) + |> add_paragraph( + "One of the rows returned by this query has type " + <> style_inline_code(type_) + <> " which I cannot currently generate code for.", + ) + |> call_to_action(for: "this type to be supported") + + CannotParseQuery( + file: file, + name: _name, + content: content, + starting_line: starting_line, + error_code: error_code, + hint: hint, + pointer: pointer, + ) -> + printable_error(case error_code { + Some(code) -> "Invalid query [" <> code <> "]" + None -> "Invalid query" + }) + |> add_code_paragraph( + file: file, + content: content, + point: pointer, + starting_line: starting_line, + ) + |> maybe_hint(hint) + + PgCannotAuthenticate(expected: expected, got: got) -> + printable_error("Cannot authenticate") + |> add_paragraph( + "I ran into an unexpected problem while trying to authenticate with the +Postgres server. This is most definitely a bug!", + ) + |> report_bug("Expected: " <> expected <> ", Got: " <> got) + + PgCannotDescribeQuery( + file: file, + query_name: query_name, + expected: expected, + got: got, + ) -> + printable_error("Cannot inspect query") + |> add_paragraph("I ran into an unexpected problem while trying to figure +out the types of query " <> style_inline_code(query_name) <> " +defined in " <> style_file(file) <> ". This is most definitely a bug!") + |> report_bug("Expected: " <> expected <> ", Got: " <> got) + } + + printable_error_to_doc(printable_error) +} + +fn style_file(file: String) -> String { + ansi.underline(file) +} + +fn style_inline_code(code: String) -> String { + "`" <> code <> "`" +} + +fn style_link(link: String) -> String { + ansi.underline(link) +} + +// --- ERROR PRETTY PRINTING --------------------------------------------------- + +const indent = 2 + +type PrintableError { + PrintableError( + title: String, + body: List(Paragraph), + report_bug: Option(String), + call_to_action: Option(String), + hint: Option(String), + ) +} + +type Paragraph { + Simple(String) + Code( + file: String, + content: String, + pointer: Option(Pointer), + starting_line: Int, + ) +} + +/// A default printable error with just a title. 
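+/// The other builders below add to it; for example `to_doc` above builds
+/// errors like
+///
+///   printable_error("Cannot send message")
+///   |> add_paragraph("...")
+///   |> report_bug(reason)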
+/// +fn printable_error(title: String) -> PrintableError { + PrintableError( + title: title, + body: [], + report_bug: None, + hint: None, + call_to_action: None, + ) +} + +fn add_paragraph(error: PrintableError, string: String) -> PrintableError { + PrintableError(..error, body: list.append(error.body, [Simple(string)])) +} + +fn add_code_paragraph( + error: PrintableError, + file file: String, + content content: String, + point point: Option(Pointer), + starting_line starting_line: Int, +) -> PrintableError { + PrintableError( + ..error, + body: list.append(error.body, [ + Code( + file: file, + content: content, + pointer: point, + starting_line: starting_line, + ), + ]), + ) +} + +/// Sets a call to action to report a specific bug. +/// +fn report_bug(error: PrintableError, report_bug: String) -> PrintableError { + PrintableError(..error, report_bug: Some(report_bug)) +} + +/// Sets a hint that will be displayed at the bottom of the error message. +/// +fn hint(error: PrintableError, hint: String) -> PrintableError { + PrintableError(..error, hint: Some(hint)) +} + +/// Sets a hint that will be displayed at the bottom of the error message. +/// +fn maybe_hint(error: PrintableError, hint: Option(String)) -> PrintableError { + PrintableError(..error, hint: hint) +} + +/// Given something a user might want to be added to the package it sets a +/// call to action message telling someone to open a ticket on the `squirrel` +/// repo. +/// +fn call_to_action(error: PrintableError, for wanted: String) -> PrintableError { + PrintableError(..error, call_to_action: Some(wanted)) +} + +fn printable_error_to_doc(error: PrintableError) -> Document { + // And now for the tricky bit... + let PrintableError( + title: title, + body: body, + report_bug: report_bug, + call_to_action: call_to_action, + hint: hint, + ) = error + + [ + title_doc(title), + body_doc(body), + option_to_doc(report_bug, report_bug_doc), + option_to_doc(call_to_action, call_to_action_doc), + option_to_doc(hint, hint_doc), + ] + |> list.filter(keeping: fn(doc) { doc != doc.empty }) + |> doc.join(with: doc.lines(2)) + |> doc.group +} + +fn title_doc(title: String) -> Document { + doc.from_string(ansi.red(ansi.bold("Error: ") <> title)) +} + +fn body_doc(body: List(Paragraph)) -> Document { + list.map(body, paragraph_doc) + |> doc.join(with: doc.line) + |> doc.group +} + +fn paragraph_doc(paragraph: Paragraph) -> Document { + case paragraph { + Simple(string) -> flexible_string(string) + Code( + file: file, + content: content, + pointer: pointer, + starting_line: starting_line, + ) -> + code_doc( + file: file, + content: content, + pointer: pointer, + starting_line: starting_line, + ) + } +} + +fn code_doc( + file file: String, + content content: String, + pointer pointer: Option(Pointer), + starting_line starting_line: Int, +) { + let pointer = + option.to_result(pointer, Nil) + |> result.then(pointer_doc(_, content)) + + let content = syntax_highlight(content) + let lines = string.split(content, on: "\n") + let lines_count = list.length(lines) + let assert Ok(digits) = int.digits(lines_count + starting_line, 10) + let max_digits = list.length(digits) + + let code_lines = { + use line, i <- list.index_map(lines) + let prefix = + int.to_string(i + starting_line) + |> string.pad_left(to: max_digits + 2, with: " ") + + case pointer { + Ok(#(pointer_line, from, pointer_doc)) if pointer_line == i -> [ + doc.from_string(ansi.dim(prefix <> " โ”‚ ")), + doc.from_string(line), + [doc.line, pointer_doc] + |> doc.concat + |> doc.nest(by: from 
+ max_digits + 5), + ] + + Ok(_) | Error(_) -> [ + doc.from_string(ansi.dim(prefix <> " โ”‚ ")), + doc.from_string(ansi.dim(line)), + ] + } + |> doc.concat + } + + let padding = string.repeat(" ", max_digits + 3) + [ + doc.from_string(padding <> ansi.dim("โ•ญโ”€ " <> file)), + case starting_line { + 1 -> doc.from_string(padding <> ansi.dim("โ”‚ ")) + _ -> doc.from_string(padding <> ansi.dim("โ”† ")) + }, + ..code_lines + ] + |> doc.join(with: doc.line) + |> doc.append(doc.line) + |> doc.append(doc.from_string(padding <> ansi.dim("โ”†"))) + |> doc.group +} + +fn pointer_doc( + pointer: Pointer, + content: String, +) -> Result(#(Int, Int, Document), Nil) { + let Pointer(kind, message) = pointer + use #(line, from, to) <- result.try(find_span(kind, content)) + let width = to - from + 1 + let doc = + [ + doc.zero_width_string("\u{001B}[31m"), + doc.from_string("โ”ฌ" <> string.repeat("โ”€", width - 1)), + doc.line, + doc.from_string("โ•ฐโ”€ "), + flexible_string(message) + |> doc.nest(by: 3), + doc.zero_width_string("\u{001B}[0m"), + ] + |> doc.concat + |> doc.group + + Ok(#(line, from, doc)) +} + +fn find_span(kind: PointerKind, string: String) -> Result(#(Int, Int, Int), Nil) { + case kind { + Name(name) -> find_name_span(name, string.length(name), string, 0, 0) + ByteIndex(n) -> find_byte_span(n - 1, string, 0, 0) + } +} + +fn find_name_span( + name: String, + name_len: Int, + string: String, + row: Int, + col: Int, +) -> Result(#(Int, Int, Int), Nil) { + case string.starts_with(string, name) { + True -> Ok(#(row, col, col + name_len - 1)) + False -> + case string.pop_grapheme(string) { + Ok(#("\n", rest)) -> find_name_span(name, name_len, rest, row + 1, 0) + Ok(#(_, rest)) -> find_name_span(name, name_len, rest, row, col + 1) + Error(_) -> Error(Nil) + } + } +} + +fn find_byte_span( + position: Int, + string: String, + row: Int, + col: Int, +) -> Result(#(Int, Int, Int), Nil) { + case position { + 0 -> Ok(#(row, col, col)) + n -> + case string.pop_grapheme(string) { + Ok(#("\n", rest)) -> find_byte_span(n - 1, rest, row + 1, 0) + Ok(#(_, rest)) -> find_byte_span(n - 1, rest, row, col + 1) + Error(_) -> Error(Nil) + } + } +} + +const keywords = [ + "and", "any", "as", "asc", "begin", "between", "by", "case", "count", "desc", + "distinct", "else", "end", "exists", "from", "full", "group", "having", "if", + "in", "inner", "insert", "into", "join", "key", "left", "like", "not", "null", + "on", "or", "order", "primary", "revert", "right", "select", "set", "table", + "top", "trigger", "union", "update", "use", "values", "view", "where", "with", +] + +fn syntax_highlight(content: String) -> String { + let keywords = string.join(keywords, with: "|") + let not_inside_string = "(?=(?:[^']*'[^']*')*[^']*$)" + + let assert Ok(keyword) = + { "\\b(" <> keywords <> ")\\b" <> not_inside_string } + |> regex.compile(regex.Options(True, False)) + let assert Ok(number) = + regex.from_string("(? 
not_inside_string)
+  let assert Ok(comment) = regex.from_string("(^\\s*--.*)")
+  let assert Ok(string) = regex.from_string("(\\'.*\\')")
+  let assert Ok(hole) = regex.from_string("(\\$\\d+)" <> not_inside_string)
+
+  content
+  |> regex.replace(each: comment, with: "\u{001B}[2m\\1\u{001B}[0m")
+  |> regex.replace(each: keyword, with: "\u{001B}[36m\\1\u{001B}[39m")
+  |> regex.replace(each: string, with: "\u{001B}[33m\\1\u{001B}[39m")
+  |> regex.replace(each: number, with: "\u{001B}[32m\\1\u{001B}[39m")
+  |> regex.replace(each: hole, with: "\u{001B}[35m\\1\u{001B}[39m")
+}
+
+fn report_bug_doc(additional_info: String) -> Document {
+  [
+    flexible_string(
+      "Please open an issue at "
+      <> style_link("https://github.com/giacomocavalieri/squirrel/issues/new")
+      <> " with some details about what you were doing, including the following message:",
+    ),
+    doc.line |> doc.nest(by: indent),
+    doc.from_string(additional_info),
+  ]
+  |> doc.concat
+  |> doc.group
+}
+
+fn call_to_action_doc(wanted: String) -> Document {
+  flexible_string(
+    "If you would like for "
+    <> wanted
+    <> ", please open an issue at "
+    <> style_link("https://github.com/giacomocavalieri/squirrel/issues/new"),
+  )
+}
+
+fn hint_doc(hint: String) -> Document {
+  flexible_string("Hint: " <> hint)
+}
+
+fn flexible_string(string: String) -> Document {
+  string.split(string, on: "\n")
+  |> list.flat_map(string.split(_, on: " "))
+  |> list.map(doc.from_string)
+  |> doc.join(with: doc.flex_space)
+  |> doc.group
+}
+
+fn option_to_doc(option: Option(a), fun: fn(a) -> Document) -> Document {
+  case option {
+    Some(a) -> fun(a)
+    None -> doc.empty
+  }
+}
diff --git a/src/squirrel/internal/eval_extra.gleam b/src/squirrel/internal/eval_extra.gleam
new file mode 100644
index 0000000..c5caca0
--- /dev/null
+++ b/src/squirrel/internal/eval_extra.gleam
@@ -0,0 +1,77 @@
+//// This package has some additional helpers to work with the `Eval` package.
+////
+
+import eval.{type Eval}
+import gleam/list
+
+pub fn try_map(
+  list: List(a),
+  fun: fn(a) -> Eval(b, _, _),
+) -> Eval(List(b), _, _) {
+  try_fold(list, [], fn(acc, item) {
+    use mapped_item <- eval.try(fun(item))
+    eval.return([mapped_item, ..acc])
+  })
+  |> eval.map(list.reverse)
+}
+
+/// Runs a list of `Eval` actions in sequence sharing the same context.
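+///
+/// A minimal usage sketch (hypothetical scripts, relying only on `eval.return`
+/// as used elsewhere in this module); each script is stepped with the context
+/// produced by the previous one:
+///
+/// ```gleam
+/// let scripts = [eval.return(1), eval.return(2)]
+/// run_all(scripts, Nil)
+/// // -> one `Result` per script
+/// ```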
+/// +pub fn run_all(list: List(Eval(a, b, c)), context: c) -> List(Result(a, b)) { + let acc = #([], context) + let #(results, _) = { + use #(results, context), script <- list.fold(list, acc) + let #(context, result) = eval.step(script, context) + #([result, ..results], context) + } + + results +} + +pub fn try_index_map( + list: List(a), + fun: fn(a, Int) -> Eval(b, _, _), +) -> Eval(List(b), _, _) { + try_index_fold(list, [], fn(acc, item, i) { + use mapped_item <- eval.try(fun(item, i)) + eval.return([mapped_item, ..acc]) + }) + |> eval.map(list.reverse) +} + +pub fn try_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first)) + try_fold(rest, acc, fun) + } + } +} + +fn try_index_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + do_try_index_fold(0, over: list, from: acc, with: fun) +} + +fn do_try_index_fold( + index: Int, + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first, index)) + do_try_index_fold(index + 1, rest, acc, fun) + } + } +} diff --git a/src/squirrel/internal/gleam.gleam b/src/squirrel/internal/gleam.gleam new file mode 100644 index 0000000..7019a25 --- /dev/null +++ b/src/squirrel/internal/gleam.gleam @@ -0,0 +1,224 @@ +import gleam/list +import gleam/string +import justin +import squirrel/internal/error.{ + type ValueIdentifierError, ContainsInvalidGrapheme, IsEmpty, +} + +/// A Gleam type. +/// +pub type Type { + List(Type) + Option(Type) + Int + Float + Bool + String +} + +/// The labelled field of a Gleam record. +/// +pub type Field { + Field(label: ValueIdentifier, type_: Type) +} + +/// A Gleam identifier, that is a string that starts with a lowercase letter, +/// is in snake_case and can only contain lowercase letters, numbers and +/// underscores. +/// +/// > ๐Ÿ’ก This can only be built using the `gleam.identifier` function that +/// > ensures that a string is a valid Gleam identifier. +/// +pub opaque type ValueIdentifier { + ValueIdentifier(String) +} + +/// Returns true if the given string is a valid Gleam identifier (that is not +/// a discard identifier, that is starting with an '_'). +/// +/// > ๐Ÿ’ก A valid identifier can be described by the following regex: +/// > `[a-z][a-z0-9_]*`. +pub fn identifier( + from name: String, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // A valid identifier needs to start with a lowercase letter. + // We do not accept _discard identifier as valid. + case name { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest -> to_identifier_rest(name, rest, 1) + _ -> + case string.pop_grapheme(name) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(0, g)) + Error(_) -> Error(IsEmpty) + } + } +} + +fn to_identifier_rest( + name: String, + rest: String, + position: Int, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // The rest of an identifier can only contain lowercase letters, _, numbers, + // or be empty. 
In all other cases it's not valid. + case rest { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest + | "_" <> rest + | "0" <> rest + | "1" <> rest + | "2" <> rest + | "3" <> rest + | "4" <> rest + | "5" <> rest + | "6" <> rest + | "7" <> rest + | "8" <> rest + | "9" <> rest -> to_identifier_rest(name, rest, position + 1) + "" -> Ok(ValueIdentifier(name)) + _ -> + case string.pop_grapheme(rest) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(position, g)) + Error(_) -> panic as "unreachable: empty identifier rest should be ok" + } + } +} + +/// Turns an identifier back into a String. +/// +pub fn identifier_to_string(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + name +} + +/// Turns a Gleam identifier into a type name. That is it strips it of all its +/// underscores and makes it PascalCase. +/// +pub fn identifier_to_type_name(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + + justin.pascal_case(name) + |> string.to_graphemes + // We want to remove any leftover "_" that might still be present after the + // conversion if the identifier had consecutive "_". + |> list.filter(keeping: fn(c) { c != "_" }) + |> string.join(with: "") +} + +/// Tries to suggest a valid Gleam identifier as similar as possible to a given +/// String. +/// +/// If it cannot come up with a suggestion, it returns `Error(Nil)`. +/// +pub fn similar_identifier_string(string: String) -> Result(String, Nil) { + let proposal = + string.trim(string) + |> justin.snake_case + |> string.to_graphemes + |> list.drop_while(fn(g) { g == "_" || is_digit(g) }) + |> list.filter(keeping: is_identifier_char) + |> string.join(with: "") + + case proposal { + "" -> Error(Nil) + _ -> Ok(proposal) + } +} + +fn is_digit(char: String) -> Bool { + case char { + "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" -> True + _ -> False + } +} + +fn is_identifier_char(char: String) -> Bool { + case char { + "a" + | "b" + | "c" + | "d" + | "e" + | "f" + | "g" + | "h" + | "i" + | "j" + | "k" + | "l" + | "m" + | "n" + | "o" + | "p" + | "q" + | "r" + | "s" + | "t" + | "u" + | "v" + | "w" + | "x" + | "y" + | "z" + | "_" + | "0" + | "1" + | "2" + | "3" + | "4" + | "5" + | "6" + | "7" + | "8" + | "9" -> True + _ -> False + } +} diff --git a/src/squirrel/internal/query.gleam b/src/squirrel/internal/query.gleam new file mode 100644 index 0000000..eb1649a --- /dev/null +++ b/src/squirrel/internal/query.gleam @@ -0,0 +1,418 @@ +import filepath +import glam/doc.{type Document} +import gleam/int +import gleam/list +import gleam/option.{type Option} +import gleam/result +import gleam/string +import simplifile +import squirrel/internal/error.{ + type Error, CannotReadFile, QueryFileHasInvalidName, +} +import squirrel/internal/gleam + +/// A query that still needs to go through the type checking process. +/// +pub type UntypedQuery { + UntypedQuery( + /// The file the query comes from. + /// + file: String, + /// The starting line in the source file where the query is defined. + /// + starting_line: Int, + /// The name of the query, it must be a valid Gleam identifier. 
+ /// + name: gleam.ValueIdentifier, + /// Any comment lines that were preceding the query in the file. + /// + comment: List(String), + /// The text of the query itself. + /// + content: String, + ) +} + +/// This is exactly the same as an untyped query with the difference that it +/// has also been annotated with the type of its parameters and returned values. +/// +pub type TypedQuery { + TypedQuery( + file: String, + starting_line: Int, + name: gleam.ValueIdentifier, + comment: List(String), + content: String, + params: List(gleam.Type), + returns: List(gleam.Field), + ) +} + +/// Turns an untyped query into a typed one. +/// +pub fn add_types( + to query: UntypedQuery, + params params: List(gleam.Type), + returns returns: List(gleam.Field), +) -> TypedQuery { + let UntypedQuery( + file: file, + name: name, + comment: comment, + content: content, + starting_line: starting_line, + ) = query + TypedQuery( + file: file, + name: name, + comment: comment, + content: content, + starting_line: starting_line, + params: params, + returns: returns, + ) +} + +// --- PARSING ----------------------------------------------------------------- + +/// Reads a query from a file. +/// This expects the user to follow the convention of having a single query per +/// file. +/// +pub fn from_file(file: String) -> Result(UntypedQuery, Error) { + let read_file = + simplifile.read(file) + |> result.map_error(CannotReadFile(file, _)) + + use content <- result.try(read_file) + + // A query always starts at the top of the file. + // If in the future I want to add support for many queries per file this + // field will be handy to properly show error messages. + let file_name = + filepath.base_name(file) + |> filepath.strip_extension + let name = + gleam.identifier(file_name) + |> result.map_error(QueryFileHasInvalidName( + file: file, + reason: _, + suggested_name: gleam.similar_identifier_string(file_name) + |> option.from_result, + )) + + use name <- result.try(name) + Ok(UntypedQuery( + file: file, + starting_line: 1, + name: name, + content: content, + comment: take_comment(content), + )) +} + +fn take_comment(query: String) -> List(String) { + do_take_comment(query, []) +} + +fn do_take_comment(query: String, lines: List(String)) -> List(String) { + case string.trim_left(query) { + "--" <> rest -> + case string.split_once(rest, on: "\n") { + Ok(#(line, rest)) -> do_take_comment(rest, [string.trim(line), ..lines]) + _ -> do_take_comment("", [string.trim(rest), ..lines]) + } + _ -> list.reverse(lines) + } +} + +// --- CODE GENERATION --------------------------------------------------------- + +pub fn generate_code(version: String, query: TypedQuery) -> String { + let TypedQuery( + file: file, + name: name, + content: content, + comment: comment, + params: params, + returns: returns, + starting_line: _, + ) = query + + let arg_name = fn(i) { "arg_" <> int.to_string(i + 1) } + let inputs = list.index_map(params, fn(_, i) { arg_name(i) }) + let inputs_encoders = + list.index_map(params, fn(p, i) { + gleam_type_to_encoder(p, arg_name(i)) |> doc.from_string + }) + + let function_name = gleam.identifier_to_string(name) + let constructor_name = gleam.identifier_to_type_name(name) <> "Row" + + let record_doc = + "/// A row you get from running the `" <> function_name <> "` query +/// defined in `" <> file <> "`. +/// +/// > ๐Ÿฟ๏ธ This type definition was generated automatically using " <> version <> " of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+///" + + let fun_doc = case comment { + [] -> "/// Runs the `" <> function_name <> "` query +/// defined in `" <> file <> "`." + [_, ..] -> + list.map(comment, string.append("/// ", _)) + |> string.join(with: "\n") + } + let fun_doc = fun_doc <> " +/// +/// > ๐Ÿฟ๏ธ This function was generated automatically using " <> version <> " of +/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel). +///" + + [ + doc.from_string(record_doc), + doc.line, + record(constructor_name, returns), + doc.lines(2), + doc.from_string(fun_doc), + doc.line, + fun(function_name, ["db", ..inputs], [ + var("decoder", decoder(constructor_name, returns)), + pipe_call("pgo.execute", string(content), [ + doc.from_string("db"), + list(inputs_encoders), + doc.from_string("decode.from(decoder, _)"), + ]), + ]), + ] + |> doc.concat + |> doc.to_string(80) +} + +fn gleam_type_to_decoder(type_: gleam.Type) -> String { + case type_ { + gleam.List(type_) -> "decode.list(" <> gleam_type_to_decoder(type_) <> ")" + gleam.Int -> "decode.int" + gleam.Float -> "decode.float" + gleam.Bool -> "decode.bool" + gleam.String -> "decode.string" + gleam.Option(type_) -> + "decode.optional(" <> gleam_type_to_decoder(type_) <> ")" + } +} + +fn gleam_type_to_encoder(type_: gleam.Type, name: String) { + case type_ { + gleam.List(type_) -> + "pgo.array(list.map(" + <> name + <> ", fn(a) {" + <> gleam_type_to_encoder(type_, "a") + <> "}))" + gleam.Option(type_) -> + "pgo.nullable(fn(a) {" + <> gleam_type_to_encoder(type_, "a") + <> "}, " + <> name + <> ")" + gleam.Int -> "pgo.int(" <> name <> ")" + gleam.Float -> "pgo.float(" <> name <> ")" + gleam.Bool -> "pgo.bool(" <> name <> ")" + gleam.String -> "pgo.text(" <> name <> ")" + } +} + +// --- CODE GENERATION PRETTY PRINTING ----------------------------------------- +// These are just a couple of handy helpers to make it easier to generate code +// for a query. +// +// It makes a best effort to also make the generated code look nice. +// Due to some missing features in `glam`, it doesn't reimplement 100% of +// Gleam's own pretty printer so it might have a different look in some places. +// + +const indent = 2 + +pub fn record(name: String, fields: List(gleam.Field)) -> Document { + let fields = + list.map(fields, fn(field) { + let label = gleam.identifier_to_string(field.label) + + [doc.from_string(label <> ": "), pretty_gleam_type(field.type_)] + |> doc.concat + |> doc.group + }) + + [ + doc.from_string("pub type " <> name <> " {"), + [doc.line, call(name, fields)] + |> doc.concat + |> doc.nest(by: indent), + doc.line, + doc.from_string("}"), + ] + |> doc.concat + |> doc.group +} + +fn pretty_gleam_type(type_: gleam.Type) -> Document { + case type_ { + gleam.List(type_) -> call("List", [pretty_gleam_type(type_)]) + gleam.Option(type_) -> call("Option", [pretty_gleam_type(type_)]) + gleam.Int -> doc.from_string("Int") + gleam.Float -> doc.from_string("Float") + gleam.Bool -> doc.from_string("Bool") + gleam.String -> doc.from_string("String") + } +} + +/// A pretty printed public function definition. +/// +pub fn fun(name: String, args: List(String), body: List(Document)) -> Document { + let args = list.map(args, doc.from_string) + + [ + doc.from_string("pub fn " <> name), + comma_list("(", args, ") "), + block([body |> doc.join(with: doc.lines(2))]), + doc.line, + ] + |> doc.concat + |> doc.group +} + +/// A pretty printed let assignment. 
+/// +pub fn var(name: String, body: Document) -> Document { + [ + doc.from_string("let " <> name <> " ="), + [doc.space, body] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + ] + |> doc.concat +} + +/// A pretty printed Gleam string. +/// +/// > โš ๏ธ This function escapes all `\` and `"` inside the original string to +/// > avoid generating invalid Gleam code. +/// +pub fn string(content: String) -> Document { + let escaped_string = + content + |> string.replace(each: "\\", with: "\\\\") + |> string.replace(each: "\"", with: "\\\"") + |> doc.from_string + + [doc.from_string("\""), escaped_string, doc.from_string("\"")] + |> doc.concat +} + +/// A pretty printed Gleam list. +/// +pub fn list(elems: List(Document)) -> Document { + comma_list("[", elems, "]") +} + +/// A pretty printed decoder that decodes an n-item dynamic tuple using the +/// `decode` package. +/// +pub fn decoder(constructor: String, returns: List(gleam.Field)) -> Document { + let parameters = + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string("use " <> label <> " <- decode.parameter") + }) + + let pipes = + list.index_map(returns, fn(field, i) { + let position = int.to_string(i) |> doc.from_string + let decoder = gleam_type_to_decoder(field.type_) |> doc.from_string + call("|> decode.field", [position, decoder]) + }) + + let labelled_names = + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string(label <> ": " <> label) + }) + + [ + call_block("decode.into", [ + doc.join(parameters, with: doc.line), + doc.line, + call(constructor, labelled_names), + ]), + doc.line, + doc.join(pipes, with: doc.line), + ] + |> doc.concat() + |> doc.group +} + +/// A pretty printed function call where the first argument is piped into +/// the function. +/// +pub fn pipe_call( + function: String, + first: Document, + rest: List(Document), +) -> Document { + [first, doc.line, call("|> " <> function, rest)] + |> doc.concat +} + +/// A pretty printed function call. +/// +fn call(function: String, args: List(Document)) -> Document { + [doc.from_string(function), comma_list("(", args, ")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed function call where the only argument is a single block. +/// +fn call_block(function: String, body: List(Document)) -> Document { + [doc.from_string(function <> "("), block(body), doc.from_string(")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed Gleam block. +/// +fn block(body: List(Document)) -> Document { + [ + doc.from_string("{"), + [doc.line, ..body] + |> doc.concat + |> doc.nest(by: indent), + doc.line, + doc.from_string("}"), + ] + |> doc.concat + |> doc.force_break +} + +/// A comma separated list of items with some given open and closed delimiters. +/// +fn comma_list(open: String, content: List(Document), close: String) -> Document { + [ + doc.from_string(open), + [ + // We want the first break to be nested + // in case the group is broken. 
+ doc.soft_break, + doc.join(content, doc.break(", ", ",")), + ] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + doc.break("", ","), + doc.from_string(close), + ] + |> doc.concat + |> doc.group +} diff --git a/test/squirrel_test.gleam b/test/squirrel_test.gleam new file mode 100644 index 0000000..c0b5c85 --- /dev/null +++ b/test/squirrel_test.gleam @@ -0,0 +1,216 @@ +import birdie +import filepath +import gleam/dynamic +import gleam/list +import gleam/pgo +import gleam/string +import gleeunit +import simplifile +import squirrel/internal/database/postgres +import squirrel/internal/error.{type Error} +import squirrel/internal/query.{type TypedQuery} +import temporary + +pub fn main() { + setup_database() + gleeunit.main() +} + +// --- TEST SETUP -------------------------------------------------------------- + +const host = "localhost" + +const user = "squirrel_test" + +const database = "squirrel_test" + +const port = 5432 + +fn setup_database() { + let config = + pgo.Config( + ..pgo.default_config(), + port: port, + user: user, + host: host, + database: database, + ) + let db = pgo.connect(config) + + let assert Ok(_) = + " +create table if not exists squirrel( + name text primary key, + acorns int +) +" + |> pgo.execute(db, [], dynamic.dynamic) + + pgo.disconnect(db) +} + +// --- ASSERTION HELPERS ------------------------------------------------------- + +fn should_codegen(query: String) -> String { + // We assert everything went smoothly and we have no errors in the query. + let assert Ok(#(queries, [])) = codegen_queries([#("query", query)]) + list.map(queries, query.generate_code("v-test", _)) + |> string.join(with: "\n\n") +} + +fn codegen_queries( + queries: List(#(String, String)), +) -> Result(#(List(TypedQuery), List(Error)), Error) { + // If there's any error with the temporary package we just fail the test, + // there's no reason to try and keep going. + let assert Ok(result) = { + use temp_dir <- temporary.create(temporary.directory()) + + // We parse all the queries. + let queries = { + use #(file, query) <- list.map(queries) + let out_file = filepath.join(temp_dir, file <> ".sql") + let assert Ok(_) = simplifile.write(to: out_file, contents: query) + let assert Ok(query) = query.from_file(out_file) + // We manually change the file name here: we do not want to use the full + // out_file name in tests because that will change across different runs, + // causing the snapshot tests to fail. + let query = query.UntypedQuery(..query, file: file <> ".sql") + query + } + + // We can then ask squirrel to type check all the queries. + postgres.main( + queries, + postgres.ConnectionOptions( + host: host, + port: port, + user: user, + database: database, + password: "", + timeout: 1000, + ), + ) + } + + result +} + +// --- ENCODING/DECODING CODEGEN TESTS ----------------------------------------- +// This is a group of tests to ensure the generated encoders/decoders are what +// we expect for all the supported data types. 
+// + +pub fn int_decoding_test() { + "select 11 as res" + |> should_codegen + |> birdie.snap(title: "int decoding") +} + +pub fn int_encoding_test() { + "select true as res where $1 = 11" + |> should_codegen + |> birdie.snap(title: "int encoding") +} + +pub fn float_decoding_test() { + "select 1.1 as res" + |> should_codegen + |> birdie.snap(title: "float decoding") +} + +pub fn float_encoding_test() { + "select true as res where $1 = 1.1" + |> should_codegen + |> birdie.snap(title: "float encoding") +} + +pub fn string_decoding_test() { + "select 'wibble' as res" + |> should_codegen + |> birdie.snap(title: "string decoding") +} + +pub fn string_encoding_test() { + "select true as res where $1 = 'wibble'" + |> should_codegen + |> birdie.snap(title: "string encoding") +} + +pub fn bool_decoding_test() { + "select true as res" + |> should_codegen + |> birdie.snap(title: "bool decoding") +} + +pub fn bool_encoding_test() { + "select true as res where $1 = true" + |> should_codegen + |> birdie.snap(title: "bool encoding") +} + +pub fn array_decoding_test() { + "select array[1, 2, 3] as res" + |> should_codegen + |> birdie.snap(title: "array decoding") +} + +pub fn array_encoding_test() { + "select true as res where $1 = array[1, 2, 3]" + |> should_codegen + |> birdie.snap(title: "array encoding") +} + +pub fn optional_decoding_test() { + "select acorns from squirrel" + |> should_codegen + |> birdie.snap(title: "optional decoding") +} + +// --- CODEGEN STRUCTURE TESTS ------------------------------------------------- +// This is a group of tests to ensure the generated code has some specific +// structure (e.g. the names and comments are what we expect...) +// + +pub fn query_with_comment_test() { + " +-- This is a comment +select true as res +" + |> should_codegen + |> birdie.snap(title: "query with comment") +} + +pub fn query_with_multiline_comment_test() { + " +-- This is a comment +-- that goes over multiple lines! +select true as res +" + |> should_codegen + |> birdie.snap(title: "query with multiline comment") +} + +pub fn generated_type_has_the_same_name_as_the_function_but_in_pascal_case_test() { + " +select true as res +" + |> should_codegen + |> birdie.snap( + title: "generated type has the same name as the function but in pascal case", + ) +} + +pub fn generated_type_fields_are_labelled_with_their_name_in_the_select_list_test() { + " +select + acorns, + name as squirrel_name +from + squirrel +" + |> should_codegen + |> birdie.snap( + title: "generated type fields are labelled with their name in the select list", + ) +}