diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..b8556db --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: test + +on: + push: + branches: + - master + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: erlef/setup-beam@v1 + with: + otp-version: "26.0.2" + gleam-version: "1.4.0-rc1" + rebar3-version: "3" + # elixir-version: "1.15.4" + - run: gleam deps download + - run: gleam test + - run: gleam format --check src test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..599be4e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.beam +*.ez +/build +erl_crash.dump diff --git a/README.md b/README.md new file mode 100644 index 0000000..a64df9d --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# squirrel + +[![Package Version](https://img.shields.io/hexpm/v/squirrel)](https://hex.pm/packages/squirrel) +[![Hex Docs](https://img.shields.io/badge/hex-docs-ffaff3)](https://hexdocs.pm/squirrel/) + +```gleam +import squirrel + +pub fn main() { + // TODO: An example of the project in use +} +``` + +Further documentation can be found at <https://hexdocs.pm/squirrel>. 
+ +## Development + +```sh +gleam run # Run the project +gleam test # Run the tests +``` diff --git a/gleam.toml b/gleam.toml new file mode 100644 index 0000000..cf0a138 --- /dev/null +++ b/gleam.toml @@ -0,0 +1,23 @@ +name = "squirrel" +version = "0.1.0" +description = "🐿️ Type safe SQL in Gleam" +licences = ["Apache-2.0"] +repository = { type = "github", user = "giacomocavalieri", repo = "squirrel" } + +[dependencies] +gleam_stdlib = ">= 0.34.0 and < 2.0.0" +simplifile = ">= 2.0.1 and < 3.0.0" +eval = ">= 1.0.0 and < 2.0.0" +gleam_json = ">= 1.0.0 and < 2.0.0" +mug = ">= 1.1.0 and < 2.0.0" +glam = ">= 2.0.1 and < 3.0.0" +justin = ">= 1.0.1 and < 2.0.0" +filepath = ">= 1.0.0 and < 2.0.0" +gleam_community_ansi = ">= 1.4.0 and < 2.0.0" +term_size = ">= 1.0.1 and < 2.0.0" +gleam_community_colour = ">= 1.4.0 and < 2.0.0" +argv = ">= 1.0.2 and < 2.0.0" + +[dev-dependencies] +gleeunit = ">= 1.0.0 and < 2.0.0" +birdie = ">= 1.1.8 and < 2.0.0" diff --git a/manifest.toml b/manifest.toml new file mode 100644 index 0000000..cfbd47c --- /dev/null +++ b/manifest.toml @@ -0,0 +1,41 @@ +# This file was generated by Gleam +# You typically do not need to edit this file + +packages = [ + { name = "argv", version = "1.0.2", build_tools = ["gleam"], requirements = [], otp_app = "argv", source = "hex", outer_checksum = "BA1FF0929525DEBA1CE67256E5ADF77A7CDDFE729E3E3F57A5BDCAA031DED09D" }, + { name = "birdie", version = "1.1.8", build_tools = ["gleam"], requirements = ["argv", "filepath", "glance", "gleam_community_ansi", "gleam_erlang", "gleam_stdlib", "justin", "rank", "simplifile", "trie_again"], otp_app = "birdie", source = "hex", outer_checksum = "D225C0A3035FCD73A88402925A903AAD3567A1515C9EAE8364F11C17AD1805BB" }, + { name = "eval", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "eval", source = "hex", outer_checksum = "264DAF4B49DF807F303CA4A4E4EBC012070429E40BE384C58FE094C4958F9BDA" }, + { name = "filepath", version = "1.0.0", build_tools 
= ["gleam"], requirements = ["gleam_stdlib"], otp_app = "filepath", source = "hex", outer_checksum = "EFB6FF65C98B2A16378ABC3EE2B14124168C0CE5201553DE652E2644DCFDB594" }, + { name = "glam", version = "2.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glam", source = "hex", outer_checksum = "66EC3BCD632E51EED029678F8DF419659C1E57B1A93D874C5131FE220DFAD2B2" }, + { name = "glance", version = "0.11.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "glexer"], otp_app = "glance", source = "hex", outer_checksum = "8F3314D27773B7C3B9FB58D8C02C634290422CE531988C0394FA0DF8676B964D" }, + { name = "gleam_community_ansi", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_community_colour", "gleam_stdlib"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "FE79E08BF97009729259B6357EC058315B6FBB916FAD1C2FF9355115FEB0D3A4" }, + { name = "gleam_community_colour", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_json", "gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "795964217EBEDB3DA656F5EB8F67D7AD22872EB95182042D3E7AFEF32D3FD2FE" }, + { name = "gleam_erlang", version = "0.25.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "054D571A7092D2A9727B3E5D183B7507DAB0DA41556EC9133606F09C15497373" }, + { name = "gleam_json", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "9063D14D25406326C0255BDA0021541E797D8A7A12573D849462CAFED459F6EB" }, + { name = "gleam_stdlib", version = "0.39.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "2D7DE885A6EA7F1D5015D1698920C9BAF7241102836CE0C3837A4F160128A9C4" }, + { name = "gleeunit", version = "1.2.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", 
outer_checksum = "F7A7228925D3EE7D0813C922E062BFD6D7E9310F0BEE585D3A42F3307E3CFD13" }, + { name = "glexer", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glexer", source = "hex", outer_checksum = "BD477AD657C2B637FEF75F2405FAEFFA533F277A74EF1A5E17B55B1178C228FB" }, + { name = "justin", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "justin", source = "hex", outer_checksum = "7FA0C6DB78640C6DC5FBFD59BF3456009F3F8B485BF6825E97E1EB44E9A1E2CD" }, + { name = "mug", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_erlang", "gleam_stdlib"], otp_app = "mug", source = "hex", outer_checksum = "85A61E67A7A8C25F4460D9CBEF1C09C68FC06ABBC6FF893B0A1F42AE01CBB546" }, + { name = "rank", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "rank", source = "hex", outer_checksum = "5660E361F0E49CBB714CC57CC4C89C63415D8986F05B2DA0C719D5642FAD91C9" }, + { name = "simplifile", version = "2.0.1", build_tools = ["gleam"], requirements = ["filepath", "gleam_stdlib"], otp_app = "simplifile", source = "hex", outer_checksum = "5FFEBD0CAB39BDD343C3E1CCA6438B2848847DC170BA2386DF9D7064F34DF000" }, + { name = "term_size", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "term_size", source = "hex", outer_checksum = "D00BD2BC8FB3EBB7E6AE076F3F1FF2AC9D5ED1805F004D0896C784D06C6645F1" }, + { name = "thoas", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "E38697EDFFD6E91BD12CEA41B155115282630075C2A727E7A6B2947F5408B86A" }, + { name = "trie_again", version = "1.1.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "trie_again", source = "hex", outer_checksum = "5B19176F52B1BD98831B57FDC97BD1F88C8A403D6D8C63471407E78598E27184" }, +] + +[requirements] +argv = { version = ">= 1.0.2 and < 2.0.0" } +birdie = { version = ">= 1.1.8 and < 2.0.0" } +eval 
= { version = ">= 1.0.0 and < 2.0.0" } +filepath = { version = ">= 1.0.0 and < 2.0.0" } +glam = { version = ">= 2.0.1 and < 3.0.0" } +gleam_community_ansi = { version = ">= 1.4.0 and < 2.0.0" } +gleam_community_colour = { version = ">= 1.4.0 and < 2.0.0" } +gleam_json = { version = ">= 1.0.0 and < 2.0.0" } +gleam_stdlib = { version = ">= 0.34.0 and < 2.0.0" } +gleeunit = { version = ">= 1.0.0 and < 2.0.0" } +justin = { version = ">= 1.0.1 and < 2.0.0" } +mug = { version = ">= 1.1.0 and < 2.0.0" } +simplifile = { version = ">= 2.0.1 and < 3.0.0" } +term_size = { version = ">= 1.0.1 and < 2.0.0" } diff --git a/prova.gleam b/prova.gleam new file mode 100644 index 0000000..6229a26 --- /dev/null +++ b/prova.gleam @@ -0,0 +1,36 @@ +import gleam/pgo +import decode + +/// A row you get from running the `prova` query +/// defined in `prova.sql`. +/// +/// > 🐿️ This type definition was generated automatically using v1.0.0 of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +/// +pub type ProvaRow { + ProvaRow(user_id: String, username: String, password: String) +} + +/// Runs the `prova` query +/// defined in `prova.sql`. +/// +/// > 🐿️ This function was generated automatically using v1.0.0 of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). 
+/// +pub fn prova(db) { + let decoder = + decode.into({ + use user_id <- decode.parameter + use username <- decode.parameter + use password <- decode.parameter + ProvaRow(user_id: user_id, username: username, password: password) + }) + |> decode.field(0, decode.string) + |> decode.field(1, decode.string) + |> decode.field(2, decode.string) + + "-- name: prova +select * from kira_user; +" + |> pgo.execute(db, [], decode.from(decoder, _)) +} diff --git a/prova.sql b/prova.sql new file mode 100644 index 0000000..c2ca628 --- /dev/null +++ b/prova.sql @@ -0,0 +1,2 @@ +-- name: prova +select * from kira_user; diff --git a/src/squirrel.gleam b/src/squirrel.gleam new file mode 100644 index 0000000..23fa9fa --- /dev/null +++ b/src/squirrel.gleam @@ -0,0 +1,198 @@ +import argv +import filepath +import glam/doc +import gleam/int +import gleam/io +import gleam/list +import gleam/result +import gleam_community/ansi +import simplifile +import squirrel/internal/cli +import squirrel/internal/database/postgres +import squirrel/internal/error.{type Error, CannotWriteToFile, FileWithNoQueries} +import squirrel/internal/query.{type TypedQuery} +import term_size + +// TODO LIST: +// - [ ] How does one decide where the query goes? +// - [ ] How do I deal with duplicate names then? +// - [ ] Somehow it treats int[][] as int[] + +const squirrel_version = "v1.0.0" + +// const clear_screen = "\u{001B}[H\u{001B}[J" + +/// Entry point for the `squirrel` CLI. 
+/// +pub fn main() { + let width = term_size.columns() |> result.unwrap(80) + use option <- cli.run(squirrel_cli(), argv.load().arguments) + case option { + Base(version: False) -> cli.PrintFullHelp + Base(version: True) -> { + io.println(squirrel_version) + cli.Done + } + + Postgres(input: input, output: output, connection: connection) -> { + run(input, output, postgres.main(_, connection)) + |> pretty_result + |> doc.to_string(width) + |> io.println + + cli.Done + } + } +} + +fn run(input_file: String, output_file: String, fun: _) { + use queries <- result.try(query.from_file(input_file)) + use queries <- result.try(fun(queries)) + use _ <- result.try(case queries { + [] -> Error(FileWithNoQueries(input_file)) + _ -> Ok(Nil) + }) + write_queries(queries, to: output_file) +} + +fn write_queries( + queries: List(TypedQuery), + to file: String, +) -> Result(Int, Error) { + let directory = filepath.directory_name(file) + let _ = simplifile.create_directory_all(directory) + + let imports = "import gleam/pgo\nimport decode\n" + let #(count, code) = { + use #(count, code), query <- list.fold(queries, #(0, imports)) + #(count + 1, code <> "\n" <> query.generate_code(squirrel_version, query)) + } + + let try_write = + simplifile.write(code, to: file) + |> result.map_error(CannotWriteToFile(file, _)) + + use _ <- result.try(try_write) + Ok(count) +} + +fn pretty_result(result: Result(Int, Error)) -> doc.Document { + case result { + Error(error) -> error.to_doc(error) + Ok(n) -> + { + "🐿️ Generated code for " + <> int.to_string(n) + <> " " + <> pluralise(n, "query", "queries") + } + |> ansi.green + |> doc.from_string + } +} + +fn pluralise(count: Int, singular: String, plural: String) -> String { + case count { + 1 -> singular + _ -> plural + } +} + +// --- CLI --------------------------------------------------------------------- + +type SquirrelArgs { + Base(version: Bool) + Postgres( + input: String, + output: String, + connection: postgres.ConnectionOptions, + ) +} + 
+fn squirrel_cli() -> cli.Cli(SquirrelArgs) { + cli.new("gleam run -m squirrel", base()) + |> cli.explain("🐿️ squirrel - " <> squirrel_version) + |> cli.add_subcommand( + cli.new("postgres", postgres()) + |> cli.summarise("generate code for Postgres") + |> cli.explain( + "Perform code generation for a Postgres database.\n" + <> "The generated code will use the `gleam_pgo` package to connect to a " + <> "Postgres database and the `decode` package to decode the read rows.", + ), + ) +} + +fn base() -> cli.Command(SquirrelArgs) { + cli.command({ + use version <- cli.next + Base(version: version) + }) + |> cli.add(cli.flag("version", "print version")) +} + +fn postgres() -> cli.Command(SquirrelArgs) { + cli.command({ + use input <- cli.next + use output <- cli.next + use database <- cli.next + use port <- cli.next + use user <- cli.next + Postgres( + input: input, + output: output, + connection: postgres.ConnectionOptions( + ..postgres.default_connection(), + port: port, + database: database, + user: user, + ), + ) + }) + |> cli.add(input_arg()) + |> cli.add(output_arg()) + |> cli.add(database_arg()) + |> cli.add(port_option(5432)) + |> cli.add(user_option()) +} + +fn input_arg() { + cli.positional( + name: "input", + description: "the input file to read SQL queries from", + ) +} + +fn output_arg() { + cli.positional( + name: "output", + description: "the output file to write generated Gleam to", + ) +} + +fn database_arg() { + cli.positional( + name: "database", + description: "the name of the database to connect to", + ) +} + +fn port_option(default: Int) { + cli.labelled( + long: "port", + description: "the port used to connect to the database", + ) + |> cli.validate_if_present(fn(n) { + int.parse(n) + |> result.replace_error("port should be a number") + }) + |> cli.default(default) +} + +fn user_option() { + cli.labelled( + long: "user", + description: "the username used to connect to the database", + ) + |> cli.default("root") +} diff --git 
a/src/squirrel/internal/cli.gleam b/src/squirrel/internal/cli.gleam new file mode 100644 index 0000000..1f5f861 --- /dev/null +++ b/src/squirrel/internal/cli.gleam @@ -0,0 +1,701 @@ +import glam/doc.{type Document} +import gleam/dict.{type Dict} +import gleam/int +import gleam/io +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/result +import gleam/set.{type Set} +import gleam/string +import gleam_community/ansi +import term_size + +// --- TYPES ------------------------------------------------------------------- + +pub opaque type Cli(a) { + Cli( + command_name: String, + summary: String, + explanation: String, + command: Command(a), + commands: Dict(String, Cli(a)), + ) +} + +pub type Error { + MissingPositionalArgument(name: String) + LeftoverPositionals + + FlagWithValue(name: String) + OptionWithNoValue(name: String) + + UnknownOption(value: String) + CannotParseFlag(name: String, reason: String) + CannotParseLabelled(name: String, reason: String) + CannotParsePositional(name: String, reason: String) +} + +pub type CannotHaveDefault + +pub type CanHaveDefault + +pub type AlreadyHasDefault + +pub opaque type CliOption(a, default) { + Flag(description: String, parse: fn(Bool) -> Result(a, String), long: String) + Positional( + description: String, + parse: fn(String) -> Result(a, String), + name: String, + ) + Labelled( + description: String, + parse: fn(Option(String)) -> Result(a, String), + long: String, + default: Option(a), + ) +} + +type CommandState { + CommandState( + positionals: List(String), + flags: Set(String), + labelled: Dict(String, String), + ) +} + +type CliState { + CliState( + path: List(String), + explanation: String, + usage: Usage, + commands: Dict(String, String), + ) +} + +type Usage { + Usage( + positionals: List(#(String, String)), + flags: Dict(String, String), + labelled: Dict(String, String), + ) +} + +pub opaque type Command(a) { + Command( + usage: Usage, + parse: fn(CommandState) -> #(CommandState, 
Result(a, List(Error))), + ) +} + +type OptionKind { + LabelledKind + FlagKind +} + +// --- COMMANDS ---------------------------------------------------------------- + +pub fn command(a: a) -> Command(a) { + let empty_usage = + Usage(positionals: [], flags: dict.new(), labelled: dict.new()) + Command(usage: empty_usage, parse: fn(state) { #(state, Ok(a)) }) +} + +pub fn next(fun: fn(a) -> b) -> fn(a) -> b { + fun +} + +type InitialStateResult { + FlagsHaveHelp + Parsed(CommandState) +} + +fn initial_state(args: List(String), usage: Usage, state: CommandState) { + case args { + [] -> + Ok(Parsed( + CommandState(..state, positionals: list.reverse(state.positionals)), + )) + + ["--help", ..] -> Ok(FlagsHaveHelp) + + ["--" <> flag, ..rest] -> { + case string.split_once(flag, on: "=") { + Ok(#(flag, value)) -> + case option_kind(usage, flag) { + Ok(LabelledKind) -> { + let labelled = dict.insert(state.labelled, flag, value) + let state = CommandState(..state, labelled: labelled) + initial_state(rest, usage, state) + } + Ok(FlagKind) -> Error(FlagWithValue(flag)) + Error(_) -> Error(UnknownOption(flag)) + } + + Error(_) -> + case option_kind(usage, flag) { + Ok(LabelledKind) -> + case rest { + [] | ["--" <> _, ..] 
-> Error(OptionWithNoValue(flag)) + [value, ..rest] -> { + let labelled = dict.insert(state.labelled, flag, value) + let state = CommandState(..state, labelled: labelled) + initial_state(rest, usage, state) + } + } + Ok(FlagKind) -> { + let flags = set.insert(state.flags, flag) + let state = CommandState(..state, flags: flags) + initial_state(rest, usage, state) + } + Error(_) -> Error(UnknownOption(flag)) + } + } + } + + [positional, ..rest] -> + initial_state( + rest, + usage, + CommandState(..state, positionals: [positional, ..state.positionals]), + ) + } +} + +fn option_kind(usage: Usage, option_name: String) -> Result(OptionKind, Nil) { + case dict.has_key(usage.flags, option_name) { + True -> Ok(FlagKind) + False -> + case dict.has_key(usage.labelled, option_name) { + True -> Ok(LabelledKind) + False -> Error(Nil) + } + } +} + +pub fn add( + to command: Command(fn(a) -> b), + option option: CliOption(a, _), +) -> Command(b) { + use state <- Command(usage: command.usage |> register_option(option)) + let #(state, rest_result) = command.parse(state) + let #(state, option_result) = parse_arg(state, option) + let result = case option_result, rest_result { + Ok(option), Ok(fun) -> Ok(fun(option)) + Ok(_), Error(errors) -> Error(errors) + Error(option_error), Ok(_) -> Error([option_error]) + Error(option_error), Error(result_errors) -> + Error([option_error, ..result_errors]) + } + #(state, result) +} + +fn register_option(usage: Usage, option: CliOption(a, _)) -> Usage { + case option { + Flag(description: description, long: long, ..) -> + Usage(..usage, flags: dict.insert(usage.flags, long, description)) + + Labelled(description: description, long: long, ..) -> + Usage(..usage, labelled: dict.insert(usage.labelled, long, description)) + + Positional(name: name, description: description, ..) 
-> + Usage(..usage, positionals: [#(name, description), ..usage.positionals]) + } +} + +fn parse_arg( + state: CommandState, + option: CliOption(a, _), +) -> #(CommandState, Result(a, Error)) { + case option { + Flag(long: long, parse: parse, ..) -> + case take_flag(state, long) { + Ok(state) -> #( + state, + parse(True) + |> result.map_error(CannotParseFlag(name: long, reason: _)), + ) + Error(_) -> #( + state, + parse(False) + |> result.map_error(CannotParseFlag(name: long, reason: _)), + ) + } + + Labelled(long: long, parse: parse, ..) -> + case take_labelled(state, long) { + Ok(#(state, value)) -> #( + state, + parse(Some(value)) + |> result.map_error(CannotParseLabelled(name: long, reason: _)), + ) + Error(_) -> #( + state, + parse(None) + |> result.map_error(CannotParseLabelled(name: long, reason: _)), + ) + } + + Positional(name: name, parse: parse, ..) -> + case take_positional(state) { + Ok(#(state, value)) -> #( + state, + parse(value) + |> result.map_error(CannotParsePositional(name: name, reason: _)), + ) + Error(_) -> #(state, Error(MissingPositionalArgument(name))) + } + } +} + +fn take_positional(state: CommandState) -> Result(#(CommandState, String), Nil) { + let CommandState(positionals: positionals, ..) = state + case positionals { + [] -> Error(Nil) + [positional, ..rest] -> + Ok(#(CommandState(..state, positionals: rest), positional)) + } +} + +fn take_flag(state: CommandState, long: String) -> Result(CommandState, Nil) { + let CommandState(flags: flags, ..) = state + case set.contains(flags, long) { + True -> Ok(CommandState(..state, flags: set.delete(flags, long))) + False -> Error(Nil) + } +} + +fn take_labelled( + state: CommandState, + long: String, +) -> Result(#(CommandState, String), Nil) { + let CommandState(labelled: labelled, ..) 
= state + case dict.get(labelled, long) { + Ok(value) -> + Ok(#(CommandState(..state, labelled: dict.delete(labelled, long)), value)) + Error(_) -> Error(Nil) + } +} + +// --- COMMAND ARGUMENTS ------------------------------------------------------- + +pub fn flag( + long long: String, + description description: String, +) -> CliOption(Bool, CannotHaveDefault) { + Flag(description: description, parse: fn(bool) { Ok(bool) }, long: long) +} + +pub fn positional( + name name: String, + description description: String, +) -> CliOption(String, CannotHaveDefault) { + Positional( + description: description, + parse: fn(string) { Ok(string) }, + name: name, + ) +} + +pub fn labelled( + long long: String, + description description: String, +) -> CliOption(Option(String), CanHaveDefault) { + Labelled( + description: description, + parse: fn(string) { Ok(string) }, + long: long, + default: None, + ) +} + +pub fn validate( + option: CliOption(a, d), + with fun: fn(a) -> Result(b, String), +) -> CliOption(b, d) { + case option { + Flag(description: description, parse: parse, long: long) -> + Flag(description: description, long: long, parse: fn(bool) { + parse(bool) |> result.then(fun) + }) + + Positional(description: description, parse: parse, name: name) -> + Positional(description: description, name: name, parse: fn(string) { + parse(string) |> result.then(fun) + }) + + Labelled( + description: description, + parse: parse, + long: long, + default: default, + ) -> + Labelled( + description: description, + long: long, + default: default + |> option.map(fn(value) { + case fun(value) { + Error(_) -> None + Ok(value) -> Some(value) + } + }) + |> option.flatten, + parse: fn(string) { parse(string) |> result.then(fun) }, + ) + } +} + +pub fn validate_if_present( + option: CliOption(Option(a), d), + with fun: fn(a) -> Result(b, String), +) -> CliOption(Option(b), d) { + validate(option, fn(option) { + case option { + Some(value) -> fun(value) |> result.map(Some) + None -> Ok(None) + 
} + }) +} + +pub fn default( + option: CliOption(Option(a), CanHaveDefault), + value: a, +) -> CliOption(a, AlreadyHasDefault) { + case option { + Flag(..) | Positional(..) -> panic as "shouldn't allow a default" + Labelled(description: description, parse: parse, long: long, default: _) -> + Labelled( + description: description, + long: long, + default: Some(value), + parse: fn(string) { + case parse(string) { + Ok(Some(a)) -> Ok(a) + Ok(None) -> Ok(value) + Error(error) -> Error(error) + } + }, + ) + } +} + +// --- woo + +pub fn new(command_name: String, command: Command(a)) -> Cli(a) { + Cli( + command_name: command_name, + explanation: "", + summary: "", + command: command, + commands: dict.new(), + ) +} + +pub fn explain(cli: Cli(a), explanation: String) -> Cli(a) { + Cli(..cli, explanation: explanation) +} + +pub fn summarise(cli: Cli(a), summary: String) -> Cli(a) { + Cli(..cli, summary: summary) +} + +pub fn add_subcommand(cli: Cli(a), subcommand: Cli(a)) -> Cli(a) { + let commands = dict.insert(cli.commands, subcommand.command_name, subcommand) + Cli(..cli, commands: commands) +} + +pub type CliCommand { + Done + PrintFullHelp +} + +pub fn run(cli: Cli(a), args: List(String), then do: fn(a) -> CliCommand) -> Nil { + let output_doc = do_run([], cli, args, do) + case output_doc == doc.empty { + True -> Nil + False -> + output_doc + |> doc.to_string(term_size.columns() |> result.unwrap(80)) + |> io.println + } +} + +fn do_run( + path: List(String), + cli: Cli(a), + args: List(String), + then do: fn(a) -> CliCommand, +) -> Document { + let path = [cli.command_name, ..path] + + let run_current = fn() { + let state = CommandState([], set.new(), dict.new()) + let cli_state = + CliState( + path: path, + explanation: cli.explanation, + commands: dict.map_values(cli.commands, fn(_, cli) { cli.summary }), + usage: cli.command.usage, + ) + + case initial_state(args, cli.command.usage, state) { + Error(reason) -> pretty_error(cli_state, [reason]) + Ok(FlagsHaveHelp) -> 
pretty_help(cli_state) + Ok(Parsed(state)) -> { + let #(_command_state, values) = cli.command.parse(state) + // TODO)) I need additional controls of the final state to know if + // there's arguments that were not consumed. + case values { + Error(reason) -> pretty_error(cli_state, reason) + Ok(values) -> + case do(values) { + Done -> doc.empty + PrintFullHelp -> pretty_help(cli_state) + } + } + } + } + } + + case args { + [] -> run_current() + [first, ..rest] -> + case dict.get(cli.commands, first) { + Ok(subcommand) -> do_run(path, subcommand, rest, do) + Error(_) -> run_current() + } + } +} + +// --- PRETTY PRINTING --------------------------------------------------------- + +const indent = 2 + +fn pretty_error(cli_state: CliState, errors: List(Error)) -> _ { + io.debug(cli_state) + io.debug(errors) + panic +} + +fn pretty_help(cli: CliState) -> Document { + let CliState( + path: path, + explanation: explanation, + usage: usage, + commands: commands, + ) = cli + + let command = list.reverse(path) |> string.join(with: " ") + + [ + flexible_string(explanation), + usage_section(command, usage, commands), + positional_section(usage), + commands_section(commands), + flags_section(usage), + ] + |> list.filter(keeping: fn(arg) { arg != doc.empty }) + |> doc.join(with: doc.lines(2)) +} + +fn flexible_string(string: String) -> Document { + string.split(string, on: "\n") + |> list.map(fn(line) { + string.split(line, on: " ") + |> list.map(doc.from_string) + |> doc.join(with: doc.flex_space) + |> doc.group + }) + |> doc.join(with: doc.line) + |> doc.group +} + +fn usage_section( + command: String, + usage: Usage, + commands: Dict(String, String), +) -> Document { + let flags = colour("[flags]", Green) |> doc.prepend(doc.flex_space) + let commands = case dict.is_empty(commands) { + True -> doc.empty + False -> colour("", Yellow) |> doc.prepend(doc.flex_space) + } + + let named_args = + usage.positionals + |> list.reverse + |> list.map(fn(arg) { colour("<" <> arg.0 <> ">", 
Cyan) }) + |> doc.join(with: doc.flex_space) + |> doc.prepend(doc.flex_space) + + let usage_line = + [doc.from_string(command), flags, commands, named_args] + |> doc.concat + |> doc.group + |> doc.nest(by: indent) + + [ + doc.from_string(ansi.bold("Usage:")), + [doc.line, usage_line] + |> doc.concat + |> doc.nest(by: indent), + ] + |> doc.concat +} + +fn positional_section(usage: Usage) -> Document { + case list.reverse(usage.positionals) { + [] -> doc.empty + [first, ..rest] as args -> { + let widest_name = { + use n, arg <- list.fold(rest, from: string.length(first.0)) + int.max(n, string.length(arg.0)) + } + + let args = + list.map(args, fn(arg) { positional_line(arg, widest_name) }) + |> doc.join(with: doc.line) + + [ + colour(ansi.bold("Arguments:"), Cyan), + [doc.line, args] + |> doc.concat + |> doc.nest(by: indent), + ] + |> doc.concat + } + } +} + +fn positional_line(arg: #(String, String), widest_name: Int) { + let column_width = widest_name + 2 + let padding = string.repeat(" ", column_width - string.length(arg.0)) + + [ + colour(arg.0, Cyan), + doc.from_string(padding), + flexible_string(arg.1) + |> doc.group + |> doc.nest(by: column_width), + ] + |> doc.concat +} + +fn commands_section(commands: Dict(String, String)) -> Document { + let sorted_commands = + dict.to_list(commands) + |> list.sort(fn(one, other) { string.compare(one.0, other.0) }) + + case sorted_commands { + [] -> doc.empty + [first, ..rest] as commands -> { + let widest_name = { + use n, command <- list.fold(rest, from: string.length(first.0)) + int.max(n, string.length(command.0)) + } + + let commands = + list.map(commands, fn(command) { command_line(command, widest_name) }) + |> doc.join(with: doc.line) + + [ + colour(ansi.bold("Commands:"), Yellow), + [doc.line, commands] + |> doc.concat + |> doc.nest(by: indent), + ] + |> doc.concat + } + } +} + +fn command_line(command: #(String, String), widest_name: Int) { + let column_width = widest_name + 2 + let padding = string.repeat(" ", 
column_width - string.length(command.0)) + + [ + colour(command.0, Yellow), + doc.from_string(padding), + flexible_string(command.1) + |> doc.group + |> doc.nest(by: column_width), + ] + |> doc.concat +} + +fn flags_section(usage: Usage) -> Document { + let sorted_labelled = + dict.merge(usage.labelled, usage.flags) + |> dict.insert("help", "print this help text") + |> dict.to_list + |> list.sort(fn(one, other) { string.compare(one.0, other.0) }) + + case sorted_labelled { + [] -> doc.empty + [first, ..rest] as labelled -> { + let widest_name = { + use n, labelled <- list.fold(rest, from: labelled_width(first.0)) + int.max(n, labelled_width(labelled.0)) + } + + let labelled = + list.map(labelled, fn(labelled) { labelled_line(labelled, widest_name) }) + |> doc.join(with: doc.line) + + [ + colour(ansi.bold("Flags:"), Green), + [doc.line, labelled] + |> doc.concat + |> doc.nest(by: indent), + ] + |> doc.concat + } + } +} + +fn labelled_line(labelled: #(String, String), widest_name: Int) -> Document { + let column_width = widest_name + 2 + let padding = string.repeat(" ", column_width - labelled_width(labelled.0)) + + let long = colour("--" <> labelled.0, Green) + let flag_column = [long] |> doc.concat + + let description_column = + [flexible_string(labelled.1)] + |> doc.concat + |> doc.group + |> doc.nest(by: column_width) + + [flag_column, doc.from_string(padding), description_column] + |> doc.concat +} + +fn labelled_width(label: String) -> Int { + string.length(label) + 2 +} + +type Colour { + Green + Purple + Cyan + Yellow +} + +fn colour_code(colour: Colour) -> String { + case colour { + Green -> "32" + Purple -> "35" + Cyan -> "36" + Yellow -> "33" + } +} + +fn colour(string: String, colour: Colour) -> Document { + [ + doc.zero_width_string("\u{001B}[" <> colour_code(colour) <> "m"), + doc.from_string(string), + doc.zero_width_string("\u{001B}[39m"), + ] + |> doc.concat +} diff --git a/src/squirrel/internal/database/postgres.gleam 
b/src/squirrel/internal/database/postgres.gleam new file mode 100644 index 0000000..5198643 --- /dev/null +++ b/src/squirrel/internal/database/postgres.gleam @@ -0,0 +1,819 @@ +import eval +import gleam/bit_array +import gleam/bool +import gleam/dict.{type Dict} +import gleam/dynamic.{type DecodeErrors, type Dynamic} as d +import gleam/int +import gleam/json +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/result +import gleam/set.{type Set} +import gleam/string +import squirrel/internal/database/postgres_protocol as pg +import squirrel/internal/error.{ + type Error, type Pointer, type ValueIdentifierError, ByteIndex, + CannotParseQuery, PgCannotAuthenticate, PgCannotDecodeReceivedMessage, + PgCannotDescribeQuery, PgCannotReceiveMessage, PgCannotSendMessage, Pointer, + QueryHasInvalidColumn, QueryHasUnsupportedType, +} +import squirrel/internal/eval_extra +import squirrel/internal/gleam +import squirrel/internal/query.{ + type TypedQuery, type UntypedQuery, TypedQuery, UntypedQuery, +} + +const find_postgres_type_query = " +select + -- The name of the type or, if the type is an array, the name of its + -- elements' type. + case + when elem.typname is null then type.typname + else elem.typname + end as type, + -- Tells us how to interpret the firs column: if this is true then the first + -- column is the type of the elements of the array type. + -- Otherwise it means we've found a base type. + case + when elem.typname is null then false + else true + end as is_array +from + pg_type as type + left join pg_type as elem on type.typelem = elem.oid +where + type.oid = $1 +" + +const find_column_nullability_query = " +select + -- Whether the column has a not-null constraint. + attnotnull +from + pg_attribute +where + -- The oid of the table the column comes from. + attrelid = $1 + -- The index of the column we're looking for. 
+ and attnum = $2 +" + +// --- TYPES ------------------------------------------------------------------- + +type PgType { + PBase(name: String) + PArray(inner: PgType) + POption(inner: PgType) +} + +type Context { + Context( + db: pg.Connection, + gleam_types: Dict(Int, gleam.Type), + column_nullability: Dict(#(Int, Int), Nullability), + ) +} + +type Nullability { + Nullable + NotNullable +} + +type Plan { + Plan( + join_type: Option(JoinType), + parent_relation: Option(ParentRelation), + output: Option(List(String)), + plans: Option(List(Plan)), + ) +} + +type JoinType { + Full + Left + Right + Other +} + +type ParentRelation { + Inner + NotInner +} + +type Db(a) = + eval.Eval(a, Error, Context) + +pub type ConnectionOptions { + ConnectionOptions( + host: String, + port: Int, + user: String, + password: String, + database: String, + timeout: Int, + ) +} + +// --- POSTGRES TO GLEAM TYPES CONVERSIONS ------------------------------------- + +fn pg_to_gleam_type(type_: PgType) -> Result(gleam.Type, String) { + case type_ { + PArray(inner: inner) -> + pg_to_gleam_type(inner) + |> result.map(gleam.List) + |> result.map_error(fn(inner) { inner <> "[]" }) + + POption(inner: inner) -> + pg_to_gleam_type(inner) + |> result.map(gleam.Option) + |> result.map_error(fn(inner) { inner <> "?" 
}) + + PBase(name: name) -> + case name { + "bool" -> Ok(gleam.Bool) + "text" | "char" -> Ok(gleam.String) + "float4" | "float8" -> Ok(gleam.Float) + "int2" | "int4" | "int8" -> Ok(gleam.Int) + _ -> Error(name) + } + } +} + +// --- CLI ENTRY POINT --------------------------------------------------------- + +pub fn default_connection() -> ConnectionOptions { + ConnectionOptions( + host: "localhost", + port: 5432, + user: "root", + password: "", + database: "", + timeout: 1000, + ) +} + +pub fn main( + queries: List(UntypedQuery), + connection: ConnectionOptions, +) -> Result(List(TypedQuery), Error) { + let script = { + use _ <- eval.try(authenticate(connection)) + // Once the server has confirmed that it is ready to accept query requests we + // can start gathering information about all the different queries. + // After each one we need to make sure the server is ready to go on with the + // next one. + // + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + // + use query <- eval_extra.try_map(queries) + infer_types(query) + } + + script + |> eval.run(Context( + db: pg.connect(connection.host, connection.port, connection.timeout), + gleam_types: dict.new(), + column_nullability: dict.new(), + )) +} + +fn authenticate(connection: ConnectionOptions) -> Db(Nil) { + let params = [#("user", connection.user), #("database", connection.database)] + use _ <- eval.try(send(pg.FeStartupMessage(params))) + + use msg <- eval.try(receive()) + use _ <- eval.try(case msg { + pg.BeAuthenticationOk -> eval.return(Nil) + _ -> unexpected_message(PgCannotAuthenticate, "AuthenticationOk", msg) + }) + + use _ <- eval.try(wait_until_ready()) + eval.return(Nil) +} + +/// Returns type information about a query. +/// +fn infer_types(query: UntypedQuery) -> Db(TypedQuery) { + // Postgres doesn't give us 100% accurate data regarding a query's type. + // We'll need to perform a couple of database interrogations and do some + // guessing. 
+ // + // The big picture idea is the following: + // - We ask the server to prepare the query. + // - Postgres will reply with type information about the returned rows and the + // query's parameters. + let action = parameters_and_returns(query) + use #(parameters, returns) <- eval.try(action) + // - The parameters' types are just OIDs so we need to interrogate the + // database to learn the actual corresponding Gleam type. + use parameters <- eval.try(resolve_parameters(query, parameters)) + // - The returns' types are just OIDs so we have to do the same for those. + // - Here comes the tricky part: we can't know if a column is nullable just + // from the server's previous answer. + // - For columns coming from a database table we'll look it up and see if + // the column is nullable or not + // - But this is not enough! If a returned column comes from a left/right + // join it will be nullable even if it is not in the original table. + // To work around this we'll have to inspect the query plan. + use plan <- eval.try(query_plan(query, list.length(parameters))) + let nullables = nullables_from_plan(plan) + use returns <- eval.try(resolve_returns(query, returns, nullables)) + + query + |> query.add_types(parameters, returns) + |> eval.return +} + +fn parameters_and_returns(query: UntypedQuery) -> Db(_) { + // We need to send three messages: + // - `Parse` with the query to parse + // - `Describe` to ask the server to reply with a description of the query's + // return types and parameter types. This is what we need to understand the + // SQL inferred types and generate the corresponding Gleam types. + // - `Sync` to ask the server to immediately reply with the results of parsing + // + use _ <- eval.try( + send_all([ + pg.FeParse("", query.content, []), + pg.FeDescribe(pg.PreparedStatement, ""), + pg.FeSync, + ]), + ) + + // Error builder used in the following steps in case the message sequence + // doesn't go as planned. 
+ let cannot_describe = fn(expected, got) { + PgCannotDescribeQuery( + file: query.file, + query_name: gleam.identifier_to_string(query.name), + expected: expected, + got: got, + ) + } + + use msg <- eval.try(receive()) + case msg { + pg.BeErrorResponse(errors) -> + eval.throw(error_fields_to_parse_error(query, errors)) + pg.BeParseComplete -> { + use msg <- eval.try(receive()) + use parameters <- eval.try(case msg { + pg.BeParameterDescription(parameters) -> eval.return(parameters) + _ -> unexpected_message(cannot_describe, "ParameterDescription", msg) + }) + + use msg <- eval.try(receive()) + use rows <- eval.try(case msg { + pg.BeRowDescriptions(rows) -> eval.return(rows) + _ -> unexpected_message(cannot_describe, "RowDescriptions", msg) + }) + + use msg <- eval.try(receive()) + use _ <- eval.try(case msg { + pg.BeReadyForQuery(_) -> eval.return(Nil) + _ -> unexpected_message(cannot_describe, "ReadyForQuery", msg) + }) + + eval.return(#(parameters, rows)) + } + _ -> + unexpected_message(cannot_describe, "ParseComplete or ErrorResponse", msg) + } +} + +fn error_fields_to_parse_error( + query: UntypedQuery, + errors: Set(pg.ErrorOrNoticeField), +) -> Error { + let #(error_code, message, position, hint) = { + use acc, error_field <- set.fold(errors, from: #(None, None, None, None)) + let #(code, message, position, hint) = acc + case error_field { + pg.Code(code) -> #(Some(code), message, position, hint) + pg.Message(message) -> #(code, Some(message), position, hint) + pg.Hint(hint) -> #(code, message, position, Some(hint)) + pg.Position(position) -> + case int.parse(position) { + Ok(position) -> #(code, message, Some(position), hint) + Error(_) -> acc + } + _ -> acc + } + } + + let pointer = case message, position { + Some(message), Some(position) -> + Some(Pointer(point_to: ByteIndex(position), message: message)) + _, _ -> None + } + + cannot_parse_error(query, error_code, hint, pointer) +} + +fn resolve_parameters( + query: UntypedQuery, + parameters: List(Int), 
+) -> Db(List(gleam.Type)) { + use oid <- eval_extra.try_map(parameters) + find_gleam_type(query, oid) +} + +/// Looks up a type with the given id in the Postgres registry. +/// +/// > ⚠️ This function assumes that the oid is present in the database and +/// > will crash otherwise. This should only be called with oids coming from +/// > a database interrogation. +/// +fn find_gleam_type(query: UntypedQuery, oid: Int) -> Db(gleam.Type) { + // We first look for the Gleam type corresponding to this id in the cache to + // avoid hammering the db with needless queries. + use <- with_cached_gleam_type(oid) + + // The only parameter to this query is the oid of the type to look up: + // that's a 32bit integer (its oid needed to prepare the query is 23). + let params = [pg.Parameter(<>)] + let run_query = find_postgres_type_query |> run_query(params, [23]) + use res <- eval.try(run_query) + + // We know the output must only contain two values: the name and a boolean to + // check whether it is an array or not. + // It's safe to assert because this query is hard coded in our code and the + // output shape cannot change without us changing that query. + let assert [name, is_array] = res + + // We then decode the bitarrays we got as a result: + // - `name` is just a string + // - `is_array` is a pg boolean + let assert Ok(name) = bit_array.to_string(name) + let type_ = case bit_array_to_bool(is_array) { + True -> PArray(PBase(name)) + False -> PBase(name) + } + + pg_to_gleam_type(type_) + |> result.map_error(unsupported_type_error(query, _)) + |> eval.from_result +} + +/// Returns the query plan for a given query. +/// `parameters` is the number of parameter placeholders in the query. +/// +fn query_plan(query: UntypedQuery, parameters: Int) -> Db(Plan) { + // We ask postgres to give us the query plan. 
To do that we need to fill in + // all the possible holes in the user supplied query with null values; + // otherwise, the server would complain that it has arguments that are not + // bound. + let query = "explain (format json, verbose) " <> query.content + let params = list.repeat(pg.Null, parameters) + let run_query = run_query(query, params, []) + use res <- eval.try(run_query) + + // We know the output will only contain a single row that is the json string + // containing the query plan. + let assert [plan] = res + let assert Ok([plan, ..]) = json.decode_bits(plan, json_plans_decoder) + eval.return(plan) +} + +/// Given a query plan, returns a set with the indices of the output columns +/// that can contain null values. +/// +fn nullables_from_plan(plan: Plan) -> Set(Int) { + let outputs = case plan.output { + Some(outputs) -> list.index_fold(outputs, dict.new(), dict.insert) + None -> dict.new() + } + + do_nullables_from_plan(plan, outputs, set.new()) +} + +fn do_nullables_from_plan( + plan: Plan, + // A dict from "column name" to its position in the query output. + query_outputs: Dict(String, Int), + nullables: Set(Int), +) -> Set(Int) { + let nullables = case plan.output, plan.join_type, plan.parent_relation { + // - All the outputs of a full join must be marked as nullable + // - All the outputs of an inner half join must be marked as nullable + Some(outputs), Some(Full), _ | Some(outputs), _, Some(Inner) -> { + use nullables, output <- list.fold(outputs, from: nullables) + case dict.get(query_outputs, output) { + Ok(i) -> set.insert(nullables, i) + Error(_) -> nullables + } + } + _, _, _ -> nullables + } + + case plan.plans, plan.join_type { + // If this is an inner half join we keep inspecting the children to mark + // their outputs as nullable. 
+ Some(plans), Some(Left) | Some(plans), Some(Right) -> { + use nullables, plan <- list.fold(plans, from: nullables) + do_nullables_from_plan(plan, query_outputs, nullables) + } + _, _ -> nullables + } +} + +/// Given a list of `RowDescriptionFields` it turns those into Gleam fields with +/// a name and a type. +/// +/// This also uses nullability info coming from the query plan to figure out if +/// a column can be nullable or not: +/// - If the column name ends with `!` it will be forced to be not nullable +/// - If the column name ends with `?` it will be forced to be nullable +/// - If the column appears in the `nullables` set then it will be nullable +/// - Otherwise we look for its metadata in the database and if it has a +/// not-null constraint it will be not nullable; otherwise it will be nullable +/// +fn resolve_returns( + query: UntypedQuery, + returns: List(pg.RowDescriptionField), + nullables: Set(Int), +) -> Db(List(gleam.Field)) { + use column, i <- eval_extra.try_index_map(returns) + let pg.RowDescriptionField( + data_type_oid: type_oid, + attr_number: column, + table_oid: table, + name: name, + .., + ) = column + + use type_ <- eval.try(find_gleam_type(query, type_oid)) + + let ends_with_exclamation_mark = string.ends_with(name, "!") + let ends_with_question_mark = string.ends_with(name, "?") + use nullability <- eval.try(case ends_with_exclamation_mark { + True -> eval.return(NotNullable) + False -> + case ends_with_question_mark { + True -> eval.return(Nullable) + False -> + case set.contains(nullables, i) { + True -> eval.return(Nullable) + False -> column_nullability(table: table, column: column) + } + } + }) + + let type_ = case nullability { + Nullable -> gleam.Option(type_) + NotNullable -> type_ + } + + let try_convert_name = + case ends_with_exclamation_mark || ends_with_question_mark { + True -> string.drop_right(name, 1) + False -> name + } + |> gleam.identifier + |> result.map_error(invalid_column_error(query, name, _)) + + use name 
<- eval.try(eval.from_result(try_convert_name)) + + let field = gleam.Field(label: name, type_: type_) + eval.return(field) +} + +fn column_nullability(table table: Int, column column: Int) -> Db(Nullability) { + // We first check if the table+column is cached to avoid making redundant + // queries to the database. + use <- with_cached_column(table: table, column: column) + + // If the table oid is 0 that means the column doesn't come from any table so + // we just assume it's not nullable. + use <- bool.guard(when: table == 0, return: eval.return(NotNullable)) + + // This query has 2 parameters: + // - the oid of the table (a 32bit integer, oid is 23) + // - the index of the column (a 32 bit integer, oid is 23) + let params = [pg.Parameter(<>), pg.Parameter(<>)] + let run_query = find_column_nullability_query |> run_query(params, [23, 23]) + use res <- eval.try(run_query) + + // We know the output will have only one column, that is the boolean + // telling us if the column has a not-null constraint. + let assert [has_non_null_constraint] = res + case bit_array_to_bool(has_non_null_constraint) { + True -> eval.return(NotNullable) + False -> eval.return(Nullable) + } +} + +// --- DB ACTION HELPERS ------------------------------------------------------- + +/// Runs a query against the database. +/// - `parameters` are the parameter values that need to be supplied in place of +/// the query placeholders +/// - `parameters_object_ids` are the oids describing the type of each +/// parameter. +/// +/// > ⚠️ The `parameters_object_ids` should have the same length as +/// > `parameters` and correctly describe each parameter's type. This function +/// > makes no attempt whatsoever to verify this assumption is correct so be +/// > careful! +/// +/// > ⚠️ This function makes the assumption that the query will only return one +/// > single row. 
This is totally fine here because we only use this to run +/// > specific hard coded queries that are guaranteed to return a single row. +/// +fn run_query( + query: String, + parameters: List(pg.ParameterValue), + parameters_object_ids: List(Int), +) -> Db(List(BitArray)) { + // The message exchange to run a query works as follows: + // - `Parse` we ask the server to parse the query, we do not give it a name + // - `Bind` we bind the query to the unnamed portal so that it is ready to + // be executed. + // - `Execute` we ask the server to run the unnamed portal and return all + // the rows (0 means return all rows). + // - `Close` we ask to close the unnamed query and the unnamed portal to free + // their resources. + // - `Sync` this acts as a synchronization point that needs to go at the end + // of the sequence before the next one. + // + // As you can see in the receiving part, each message we send corresponds to a + // specific answer from the server: + // - `ParseComplete` the query was parsed correctly + // - `BindComplete` the query was bound to a portal + // - `MessageDataRow` the result(s) coming from the query execution + // - `CommandComplete` when the result coming from the query is over + // - `CloseComplete` the portal/statement was closed + // - `ReadyForQuery` final reply to the sync message signalling we can go on + // making new requests + use _ <- eval.try( + send_all([ + pg.FeParse("", query, parameters_object_ids), + pg.FeBind( + portal: "", + statement_name: "", + parameter_format: pg.FormatAll(pg.Binary), + parameters:, + result_format: pg.FormatAll(pg.Binary), + ), + pg.FeExecute("", 0), + pg.FeClose(pg.PreparedStatement, ""), + pg.FeClose(pg.Portal, ""), + pg.FeSync, + ]), + ) + + use msg <- eval.try(receive()) + let assert pg.BeParseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeBindComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeMessageDataRow(res) = msg + use msg <- eval.try(receive()) + let assert 
pg.BeCommandComplete(_, _) = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeReadyForQuery(_) = msg + eval.return(res) +} + +/// Receive a single message from the database. +/// +fn receive() -> Db(pg.BackendMessage) { + use Context(db: db, ..) as context <- eval.from + case pg.receive(db) { + Ok(#(db, msg)) -> #(Context(..context, db: db), Ok(msg)) + Error(pg.ReadDecodeError(error)) -> #( + context, + Error(PgCannotDecodeReceivedMessage(string.inspect(error))), + ) + Error(pg.SocketError(error)) -> #( + context, + Error(PgCannotReceiveMessage(string.inspect(error))), + ) + } +} + +/// Send a single message to the database. +/// +fn send(message message: pg.FrontendMessage) -> Db(Nil) { + use Context(db: db, ..) as context <- eval.from + + let result = + message + |> pg.encode_frontend_message + |> pg.send(db, _) + + let #(db, result) = case result { + Ok(db) -> #(db, Ok(Nil)) + Error(error) -> #(db, Error(PgCannotSendMessage(string.inspect(error)))) + } + + #(Context(..context, db: db), result) +} + +/// Send many messages, one after the other. +/// +fn send_all(messages messages: List(pg.FrontendMessage)) -> Db(Nil) { + use acc, msg <- eval_extra.try_fold(messages, from: Nil) + use _ <- eval.try(send(msg)) + eval.return(acc) +} + +/// Start receiving and discarding messages until a `ReadyForQuery` message is +/// received. +/// +fn wait_until_ready() -> Db(Nil) { + use _ <- eval.try(send(pg.FeFlush)) + do_wait_until_ready() +} + +fn do_wait_until_ready() -> Db(Nil) { + use msg <- eval.try(receive()) + case msg { + pg.BeReadyForQuery(_) -> eval.return(Nil) + _ -> do_wait_until_ready() + } +} + +/// Throws an error built from an expected message and an unexpected message +/// that is turned into a string. 
+/// +fn unexpected_message( + builder: fn(String, String) -> Error, + expected expected: String, + got got: pg.BackendMessage, +) { + builder(expected, string.inspect(got)) |> eval.throw +} + +/// Looks up a type with the given id in a global cache. +/// If the type is present it immediately returns it. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. +/// +fn with_cached_gleam_type( + lookup oid: Int, + otherwise do: fn() -> Db(gleam.Type), +) -> Db(gleam.Type) { + use context: Context <- eval.from + case dict.get(context.gleam_types, oid) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #(Context(gleam_types: gleam_types, ..) as context, Ok(type_)) -> { + let gleam_types = dict.insert(gleam_types, oid, type_) + let new_context = Context(..context, gleam_types: gleam_types) + #(new_context, Ok(type_)) + } + } + } +} + +/// Looks up the nullability of a table's column. +/// If the nullability is cached it immediately returns it. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. 
+/// +fn with_cached_column( + table table_oid: Int, + column column: Int, + otherwise do: fn() -> Db(Nullability), +) -> Db(Nullability) { + use context: Context <- eval.from + let key = #(table_oid, column) + case dict.get(context.column_nullability, key) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #( + Context( + column_nullability: column_nullability, + .., + ) as context, + Ok(type_), + ) -> { + let column_nullability = dict.insert(column_nullability, key, type_) + let new_context = + Context(..context, column_nullability: column_nullability) + #(new_context, Ok(type_)) + } + } + } +} + +// --- HELPERS TO BUILD ERRORS ------------------------------------------------- + +fn unsupported_type_error(query: UntypedQuery, type_: String) -> Error { + let UntypedQuery( + content: content, + file: file, + name: name, + starting_line: starting_line, + ) = query + QueryHasUnsupportedType( + file: file, + name: gleam.identifier_to_string(name), + content: content, + type_: type_, + starting_line: starting_line, + ) +} + +fn cannot_parse_error( + query: UntypedQuery, + error_code: Option(String), + hint: Option(String), + pointer: Option(Pointer), +) -> Error { + let UntypedQuery( + content: content, + file: file, + name: name, + starting_line: starting_line, + ) = query + CannotParseQuery( + content: content, + file: file, + name: gleam.identifier_to_string(name), + error_code: error_code, + hint: hint, + pointer: pointer, + starting_line: starting_line, + ) +} + +fn invalid_column_error( + query: UntypedQuery, + column_name: String, + reason: ValueIdentifierError, +) -> Error { + let UntypedQuery( + name: _, + file: file, + content: content, + starting_line: starting_line, + ) = query + QueryHasInvalidColumn( + file: file, + column_name: column_name, + suggested_name: gleam.similar_identifier_string(column_name) + |> option.from_result, + content: content, + reason: reason, + starting_line: 
starting_line, + ) +} + +// --- DECODERS ---------------------------------------------------------------- + +fn json_plans_decoder(data: Dynamic) -> Result(List(Plan), DecodeErrors) { + d.list(d.field("Plan", plan_decoder))(data) +} + +fn plan_decoder(data: Dynamic) -> Result(Plan, DecodeErrors) { + d.decode4( + Plan, + d.optional_field("Join Type", join_type_decoder), + d.optional_field("Parent Relationship", parent_relation_decoder), + d.optional_field("Output", d.list(d.string)), + d.optional_field("Plans", d.list(plan_decoder)), + )(data) +} + +fn join_type_decoder(data: Dynamic) -> Result(JoinType, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Full" -> Full + "Left" -> Left + "Right" -> Right + _ -> Other + } +} + +fn parent_relation_decoder( + data: Dynamic, +) -> Result(ParentRelation, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Inner" -> Inner + _ -> NotInner + } +} + +// --- UTILS ------------------------------------------------------------------- + +/// Turns a bit array into a boolean value. +/// Returns `False` if the bit array is all `0`s or empty, `True` otherwise. +/// +fn bit_array_to_bool(bit_array: BitArray) -> Bool { + case bit_array { + <<0, rest:bits>> -> bit_array_to_bool(rest) + <<>> -> False + _ -> True + } +} diff --git a/src/squirrel/internal/database/postgres_protocol.gleam b/src/squirrel/internal/database/postgres_protocol.gleam new file mode 100644 index 0000000..10cb525 --- /dev/null +++ b/src/squirrel/internal/database/postgres_protocol.gleam @@ -0,0 +1,1387 @@ +//// Vendored version of https://hex.pm/packages/postgresql_protocol. +//// - do not fail with an unexpected Command if the outer message is ok +//// +//// This library parses and generates packages for the PostgreSQL Binary Protocol +//// It also provides a basic connection abstraction, but this hasn't been used +//// outside of tests. 
+ +import gleam/bit_array +import gleam/bool +import gleam/dict +import gleam/int +import gleam/list +import gleam/result.{try} +import gleam/set +import mug + +pub type Connection { + Connection(socket: mug.Socket, buffer: BitArray, timeout: Int) +} + +pub fn connect(host, port, timeout) { + let assert Ok(socket) = + mug.connect(mug.ConnectionOptions(host: host, port: port, timeout: timeout)) + + Connection(socket: socket, buffer: <<>>, timeout: timeout) +} + +pub type StateInitial { + StateInitial(parameters: dict.Dict(String, String)) +} + +pub type State { + State( + process_id: Int, + secret_key: Int, + parameters: dict.Dict(String, String), + oids: dict.Dict(Int, fn(BitArray) -> Result(Int, Nil)), + ) +} + +fn default_oids() { + dict.new() + |> dict.insert(23, fn(col: BitArray) { + use converted <- try(bit_array.to_string(col)) + int.parse(converted) + }) +} + +pub fn start(conn, params) { + let assert Ok(conn) = + conn + |> send(encode_startup_message(params)) + + let assert Ok(#(conn, state)) = + conn + |> receive_startup_rec(StateInitial(dict.new())) + + #(conn, state) +} + +fn receive_startup_rec(conn: Connection, state: StateInitial) { + case receive(conn) { + Ok(#(conn, BeAuthenticationOk)) -> receive_startup_rec(conn, state) + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_rec( + conn, + StateInitial(parameters: dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeBackendKeyData(secret_key: secret_key, process_id: process_id))) -> + receive_startup_1( + conn, + State( + parameters: state.parameters, + secret_key: secret_key, + process_id: process_id, + oids: default_oids(), + ), + ) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } +} + +fn receive_startup_1(conn: Connection, state: State) { + case receive(conn) { + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_1( + conn, + State(..state, parameters: 
dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeReadyForQuery(_))) -> Ok(#(conn, state)) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } +} + +pub type StartupFailed { + StartupFailedWithUnexpectedMessage(BackendMessage) + StartupFailedWithError(ReadError) +} + +pub type ReadError { + SocketError(mug.Error) + ReadDecodeError(MessageDecodingError) +} + +pub type MessageDecodingError { + MessageDecodingError(String) + MessageIncomplete(BitArray) + UnknownMessage(data: BitArray) +} + +fn dec_err(desc: String, data: BitArray) -> Result(a, MessageDecodingError) { + Error(MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data))) +} + +fn msg_dec_err(desc: String, data: BitArray) -> MessageDecodingError { + MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data)) +} + +pub type Command { + Insert + Delete + Update + Merge + Select + Move + Fetch + Copy +} + +/// Messages originating from the PostgreSQL backend (server) +pub type BackendMessage { + BeBindComplete + BeCloseComplete + BeCommandComplete(Command, Int) + BeCopyData(data: BitArray) + BeCopyDone + BeAuthenticationOk + BeAuthenticationKerberosV5 + BeAuthenticationCleartextPassword + BeAuthenticationMD5Password(salt: BitArray) + BeAuthenticationGSS + BeAuthenticationGSSContinue(auth_data: BitArray) + BeAuthenticationSSPI + BeAuthenticationSASL(mechanisms: List(String)) + BeAuthenticationSASLContinue(data: BitArray) + BeAuthenticationSASLFinal(data: BitArray) + BeReadyForQuery(TransactionStatus) + BeRowDescriptions(List(RowDescriptionField)) + BeMessageDataRow(List(BitArray)) + BeBackendKeyData(process_id: Int, secret_key: Int) + BeParameterStatus(name: String, value: String) + BeCopyResponse( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), + ) + BeNegotiateProtocolVersion( + newest_minor: Int, + unrecognized_options: List(String), + ) + BeNoData + 
BeNoticeResponse(set.Set(ErrorOrNoticeField)) + BeNotificationResponse(process_id: Int, channel: String, payload: String) + BeParameterDescription(List(Int)) + BeParseComplete + BePortalSuspended + BeErrorResponse(set.Set(ErrorOrNoticeField)) +} + +/// Direction of a BeCopyResponse +pub type CopyDirection { + In + Out + Both +} + +/// Indicates encoding of column values +pub type Format { + Text + Binary +} + +/// Lists of parameters can have different encodings. +pub type FormatValue { + FormatAllText + FormatAll(Format) + Formats(List(Format)) +} + +fn encode_format_value(format) -> BitArray { + case format { + FormatAllText -> <> + FormatAll(Text) -> <<1:16, encode_format(Text):16>> + FormatAll(Binary) -> <<1:16, encode_format(Binary):16>> + Formats(formats) -> { + let size = list.length(formats) + list.fold(formats, <>, fn(sum, fmt) { + <> + }) + } + } +} + +pub type ParameterValues = + List(ParameterValue) + +pub type ParameterValue { + Parameter(BitArray) + Null +} + +fn parameters_to_bytes(parameters: ParameterValues) { + let mapped = + parameters + |> list.map(fn(parameter) { + case parameter { + Parameter(value) -> <> + Null -> <<-1:32>> + } + }) + + <> +} + +pub type FrontendMessage { + FeBind( + portal: String, + statement_name: String, + parameter_format: FormatValue, + parameters: ParameterValues, + result_format: FormatValue, + ) + FeCancelRequest(process_id: Int, secret_key: Int) + FeClose(what: What, name: String) + FeCopyData(data: BitArray) + FeCopyDone + FeCopyFail(error: String) + FeDescribe(what: What, name: String) + FeExecute(portal: String, return_row_count: Int) + FeFlush + FeFunctionCall( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, + ) + FeGssEncRequest + FeParse(name: String, query: String, parameter_object_ids: List(Int)) + FeQuery(query: String) + FeStartupMessage(params: List(#(String, String))) + FeSslRequest + FeTerminate + FeSync + FeAmbigous(FeAmbigous) +} + +// These all 
share the same message type +pub type FeAmbigous { + FeGssResponse(data: BitArray) + FeSaslInitialResponse(name: String, data: BitArray) + FeSaslResponse(data: BitArray) + FePasswordMessage(password: String) +} + +pub type What { + PreparedStatement + Portal +} + +fn wire_what(what) { + case what { + Portal -> <<"P":utf8>> + PreparedStatement -> <<"S":utf8>> + } +} + +fn decode_what(binary) { + case binary { + <<"P":utf8, rest:bytes>> -> Ok(#(Portal, rest)) + <<"S":utf8, rest:bytes>> -> Ok(#(PreparedStatement, rest)) + _ -> dec_err("only portal and prepared statement are allowed", binary) + } +} + +fn encode_message_data_row(columns) { + <> + |> list.fold( + columns, + _, + fn(sum, column) { + let len = bit_array.byte_size(column) + <> + }, + ) + |> encode("D", _) +} + +fn encode_error_response(fields: set.Set(ErrorOrNoticeField)) -> BitArray { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("E", _) +} + +fn encode_field(field) { + case field { + Severity(value) -> <<"S":utf8, value:utf8, 0>> + SeverityLocalized(value) -> <<"V":utf8, value:utf8, 0>> + Code(value) -> <<"C":utf8, value:utf8, 0>> + Message(value) -> <<"M":utf8, value:utf8, 0>> + Detail(value) -> <<"D":utf8, value:utf8, 0>> + Hint(value) -> <<"H":utf8, value:utf8, 0>> + Position(value) -> <<"P":utf8, value:utf8, 0>> + InternalPosition(value) -> <<"p":utf8, value:utf8, 0>> + InternalQuery(value) -> <<"q":utf8, value:utf8, 0>> + Where(value) -> <<"W":utf8, value:utf8, 0>> + Schema(value) -> <<"s":utf8, value:utf8, 0>> + Table(value) -> <<"t":utf8, value:utf8, 0>> + Column(value) -> <<"c":utf8, value:utf8, 0>> + DataType(value) -> <<"d":utf8, value:utf8, 0>> + Constraint(value) -> <<"n":utf8, value:utf8, 0>> + File(value) -> <<"F":utf8, value:utf8, 0>> + Line(value) -> <<"L":utf8, value:utf8, 0>> + Routine(value) -> <<"R":utf8, value:utf8, 0>> + Unknown(key, value) -> <> + } +} + +fn encode_authentication_sasl(mechanisms) -> BitArray { + list.fold(mechanisms, <<10:32>>, fn(sum, mechanism) 
{ + <> + }) + |> encode("R", _) +} + +fn encode_command_complete(command, rows_num) { + let rows = int.to_string(rows_num) + + let data = + case command { + Insert -> "INSERT 0 " <> rows + Delete -> "DELETE " <> rows + Update -> "UPDATE " <> rows + Merge -> "MERGE " <> rows + Select -> "SELECT " <> rows + Move -> "MOVE " <> rows + Fetch -> "FETCH " <> rows + Copy -> "COPY " <> rows + } + |> encode_string() + + encode("C", data) +} + +fn encode_copy_response( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), +) { + let data = encode_copy_response_rec(overall_format, codes) + case direction { + In -> encode("G", data) + Out -> encode("H", data) + Both -> encode("W", data) + } +} + +// The format codes to be used for each column. Each must presently be zero +// (text) or one (binary). All must be zero if the overall copy format is +// textual. +fn encode_copy_response_rec(overall_format, codes) { + case overall_format { + Text -> + codes + |> list.fold( + <>, + fn(sum, _code) { <> }, + ) + Binary -> + codes + |> list.fold( + <>, + fn(sum, code) { <> }, + ) + } +} + +fn encode_parameter_status(name: String, value: String) -> BitArray { + encode("S", <>) +} + +fn encode_negotiate_protocol_version( + newest_minor: Int, + unrecognized_options: List(String), +) { + list.fold( + unrecognized_options, + <>, + fn(sum, option) { <> }, + ) + |> encode("v", _) +} + +fn encode_notice_response(fields: set.Set(ErrorOrNoticeField)) { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("N", _) +} + +fn encode_notification_response( + process_id: Int, + channel: String, + payload: String, +) { + encode("A", <>) +} + +fn encode_parameter_description(descriptions: List(Int)) { + descriptions + |> list.fold(<>, fn(sum, description) { + <> + }) + |> encode("t", _) +} + +fn encode_row_descriptions(fields: List(RowDescriptionField)) { + fields + |> list.fold(<>, fn(sum, field) { + << + sum:bits, + encode_string(field.name):bits, + field.table_oid:32, + 
field.attr_number:16, + field.data_type_oid:32, + field.data_type_size:16, + field.type_modifier:32, + field.format_code:16, + >> + }) + |> encode("T", _) +} + +pub fn encode_backend_message(message: BackendMessage) -> BitArray { + case message { + BeMessageDataRow(columns) -> encode_message_data_row(columns) + BeErrorResponse(fields) -> encode_error_response(fields) + BeAuthenticationOk -> encode("R", <<0:32>>) + BeAuthenticationKerberosV5 -> encode("R", <<2:32>>) + BeAuthenticationCleartextPassword -> encode("R", <<3:32>>) + BeAuthenticationMD5Password(salt: salt) -> encode("R", <<5:32, salt:bits>>) + BeAuthenticationGSS -> encode("R", <<7:32>>) + BeAuthenticationGSSContinue(data) -> encode("R", <<8:32, data:bits>>) + BeAuthenticationSSPI -> encode("R", <<9:32>>) + BeAuthenticationSASL(a) -> encode_authentication_sasl(a) + BeAuthenticationSASLContinue(data) -> encode("R", <<11:32, data:bits>>) + BeAuthenticationSASLFinal(data: data) -> encode("R", <<12:32, data:bits>>) + BeBackendKeyData(pid, sk) -> encode("K", <>) + BeBindComplete -> encode("2", <<>>) + BeCloseComplete -> encode("3", <<>>) + BeCommandComplete(a, b) -> encode_command_complete(a, b) + BeCopyData(data) -> encode("d", data) + BeCopyDone -> encode("c", <<>>) + BeCopyResponse(a, b, c) -> encode_copy_response(a, b, c) + BeParameterStatus(name, value) -> encode_parameter_status(name, value) + BeNegotiateProtocolVersion(a, b) -> encode_negotiate_protocol_version(a, b) + BeNoData -> encode("n", <<>>) + BeNoticeResponse(data) -> encode_notice_response(data) + BeNotificationResponse(a, b, c) -> encode_notification_response(a, b, c) + BeParameterDescription(a) -> encode_parameter_description(a) + BeParseComplete -> encode("1", <<>>) + BePortalSuspended -> encode("s", <<>>) + BeRowDescriptions(a) -> encode_row_descriptions(a) + BeReadyForQuery(TransactionStatusIdle) -> encode("Z", <<"I":utf8>>) + BeReadyForQuery(TransactionStatusInTransaction) -> encode("Z", <<"T":utf8>>) + 
BeReadyForQuery(TransactionStatusFailed) -> encode("Z", <<"E":utf8>>) + } +} + +pub fn encode_frontend_message(message: FrontendMessage) { + case message { + FeBind(a, b, c, d, e) -> encode_bind(a, b, c, d, e) + FeCancelRequest(process_id: process_id, secret_key: secret_key) -> << + 16:32, + 1234:16, + 5678:16, + process_id:32, + secret_key:32, + >> + FeClose(what, name) -> + encode("C", <>) + FeCopyData(data) -> encode("d", data) + FeCopyDone -> <<"c":utf8, 4:32>> + FeCopyFail(error) -> encode("f", encode_string(error)) + FeDescribe(what, name) -> + encode("D", <>) + FeExecute(portal, count) -> + encode("E", <>) + FeFlush -> <<"H":utf8, 4:32>> + FeFunctionCall(a, b, c, d) -> encode_function_call(a, b, c, d) + FeGssEncRequest -> <<8:32, 1234:16, 5680:16>> + FeAmbigous(FeGssResponse(data)) -> encode("p", data) + FeParse(a, b, c) -> encode_parse(a, b, c) + FeAmbigous(FePasswordMessage(password)) -> + encode("p", encode_string(password)) + FeQuery(query) -> encode("Q", encode_string(query)) + FeAmbigous(FeSaslInitialResponse(a, b)) -> + encode_sasl_initial_response(a, b) + FeAmbigous(FeSaslResponse(data)) -> encode("p", data) + FeStartupMessage(params) -> encode_startup_message(params) + FeSslRequest -> <<8:32, 1234:16, 5679:16>> + FeSync -> <<"S":utf8, 4:32>> + FeTerminate -> <<"X":utf8, 4:32>> + } +} + +fn encode(type_char, data) { + case data { + <<>> -> <> + _ -> { + let len = bit_array.byte_size(data) + 4 + <> + } + } +} + +fn encode_string(str) { + <> +} + +pub const protocol_version_major = <<3:16>> + +pub const protocol_version_minor = <<0:16>> + +pub const protocol_version = << + protocol_version_major:bits, + protocol_version_minor:bits, +>> + +fn encode_startup_message(params) { + let packet = + params + |> list.fold(<>, fn(builder, element) { + let #(key, value) = element + <> + }) + + let size = bit_array.byte_size(packet) + 5 + + <> +} + +fn encode_sasl_initial_response(name, data) { + let len = bit_array.byte_size(data) + encode("p", <>) +} + +fn 
encode_parse(name, query, parameter_object_ids) { + let oids = + list.fold(parameter_object_ids, <<>>, fn(sum, oid) { <> }) + let len = list.length(parameter_object_ids) + encode("P", << + encode_string(name):bits, + encode_string(query):bits, + len:16, + oids:bits, + >>) +} + +fn encode_bind( + portal, + statement_name, + parameter_format, + parameters, + result_format, +) { + encode("B", << + portal:utf8, + 0, + statement_name:utf8, + 0, + encode_format_value(parameter_format):bits, + parameters_to_bytes(parameters):bits, + encode_format_value(result_format):bits, + >>) +} + +fn encode_function_call( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, +) { + encode("F", << + object_id:32, + encode_format_value(argument_format):bits, + parameters_to_bytes(arguments):bits, + encode_format(result_format):16, + >>) +} + +/// Send a message to the database +pub fn send_builder(conn: Connection, message) { + case mug.send_builder(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Send a message to the database +pub fn send(conn: Connection, message) { + case mug.send(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Receive a single message from the backend +pub fn receive( + conn: Connection, +) -> Result(#(Connection, BackendMessage), ReadError) { + case decode_backend_packet(conn.buffer) { + Ok(#(message, rest)) -> Ok(#(with_buffer(conn, rest), message)) + Error(MessageIncomplete(_)) -> { + case mug.receive(conn.socket, conn.timeout) { + Ok(packet) -> + receive(with_buffer(conn, <>)) + Error(err) -> Error(SocketError(err)) + } + } + Error(err) -> Error(ReadDecodeError(err)) + } +} + +fn with_buffer(conn: Connection, buffer: BitArray) { + Connection(..conn, buffer: buffer) +} + +// decode a single message from the packet +pub fn decode_backend_packet( + packet: BitArray, +) -> Result(#(BackendMessage, BitArray), MessageDecodingError) { + case 
packet { + <> -> { + let len = length - 4 + case tail { + <> -> + decode_backend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + _ -> + case bit_array.byte_size(packet) < 5 { + True -> Error(MessageIncomplete(packet)) + False -> dec_err("packet size too small", packet) + } + } +} + +// decode a backend message +pub fn decode_backend_message(binary) { + case binary { + <<"D":utf8, count:16, data:bytes>> -> decode_message_data_row(count, data) + <<"E":utf8, data:bytes>> -> decode_error_response(data) + <<"R":utf8, 0:32>> -> Ok(BeAuthenticationOk) + <<"R":utf8, 2:32>> -> Ok(BeAuthenticationKerberosV5) + <<"R":utf8, 3:32>> -> Ok(BeAuthenticationCleartextPassword) + <<"R":utf8, 5:32, salt:bytes>> -> + Ok(BeAuthenticationMD5Password(salt: salt)) + <<"R":utf8, 7:32>> -> Ok(BeAuthenticationGSS) + <<"R":utf8, 8:32, auth_data:bytes>> -> + Ok(BeAuthenticationGSSContinue(auth_data: auth_data)) + <<"R":utf8, 9:32>> -> Ok(BeAuthenticationSSPI) + <<"R":utf8, 10:32, data:bytes>> -> decode_authentication_sasl(data) + <<"R":utf8, 11:32, data:bytes>> -> + Ok(BeAuthenticationSASLContinue(data: data)) + <<"R":utf8, 12:32, data:bytes>> -> Ok(BeAuthenticationSASLFinal(data: data)) + <<"K":utf8, pid:32, sk:32>> -> + Ok(BeBackendKeyData(process_id: pid, secret_key: sk)) + <<"2":utf8>> -> Ok(BeBindComplete) + <<"3":utf8>> -> Ok(BeCloseComplete) + <<"C":utf8, data:bytes>> -> decode_command_complete(data) + <<"d":utf8, data:bytes>> -> Ok(BeCopyData(data)) + <<"c":utf8>> -> Ok(BeCopyDone) + <<"G":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(In, format, count, data) + <<"H":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Out, format, count, data) + <<"W":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Both, format, count, data) + <<"S":utf8, data:bytes>> -> decode_parameter_status(data) + <<"v":utf8, version:32, count:32, data:bytes>> -> + decode_negotiate_protocol_version(version, count, 
data) + <<"n":utf8>> -> Ok(BeNoData) + <<"N":utf8, data:bytes>> -> decode_notice_response(data) + <<"A":utf8, process_id:32, data:bytes>> -> + decode_notification_response(process_id, data) + <<"t":utf8, count:16, data:bytes>> -> + decode_parameter_description(count, data, []) + <<"1":utf8>> -> Ok(BeParseComplete) + <<"s":utf8>> -> Ok(BePortalSuspended) + <<"T":utf8, count:16, data:bytes>> -> decode_row_descriptions(count, data) + <<"Z":utf8, "I":utf8>> -> Ok(BeReadyForQuery(TransactionStatusIdle)) + <<"Z":utf8, "T":utf8>> -> + Ok(BeReadyForQuery(TransactionStatusInTransaction)) + <<"Z":utf8, "E":utf8>> -> Ok(BeReadyForQuery(TransactionStatusFailed)) + _ -> Error(UnknownMessage(binary)) + } +} + +/// decode a single message from the packet buffer. +/// note that FeCancelRequest, FeSslRequest, FeGssEncRequest, and +/// FeStartupMessage messages can only be decoded here because they don't follow +/// the standard message format. +pub fn decode_frontend_packet( + packet: BitArray, +) -> Result(#(FrontendMessage, BitArray), MessageDecodingError) { + case packet { + <<16:32, 1234:16, 5678:16, process_id:32, secret_key:32, next:bytes>> -> + Ok(#( + FeCancelRequest(process_id: process_id, secret_key: secret_key), + next, + )) + <<8:32, 1234:16, 5679:16, next:bytes>> -> Ok(#(FeSslRequest, next)) + <<8:32, 1234:16, 5680:16, next:bytes>> -> Ok(#(FeGssEncRequest, next)) + // not sure if there's a way to use the `protocol_version` constant here + <> -> + decode_startup_message(next, length - 8, []) + <> -> { + let len = length - 4 + case tail { + <> -> + decode_frontend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + <<_:48>> -> Error(MessageIncomplete(packet)) + _ -> dec_err("invalid message", packet) + } +} + +/// decode a frontend message (also see decode_frontend_packet for messages that +/// can't be decoded here) +pub fn decode_frontend_message( + binary: BitArray, +) -> Result(FrontendMessage, MessageDecodingError) { 
+ case binary { + <<"B":utf8, data:bytes>> -> decode_bind(data) + <<"C":utf8, data:bytes>> -> decode_close(data) + <<"d":utf8, data:bytes>> -> Ok(FeCopyData(data)) + <<"c":utf8>> -> Ok(FeCopyDone) + <<"f":utf8, data:bytes>> -> decode_copy_fail(data) + <<"D":utf8, data:bytes>> -> decode_describe(data) + <<"E":utf8, data:bytes>> -> decode_execute(data) + <<"H":utf8>> -> Ok(FeFlush) + <<"F":utf8, data:bytes>> -> decode_function_call(data) + <<"p":utf8, data:bytes>> -> Ok(FeAmbigous(FeGssResponse(data))) + <<"P":utf8, data:bytes>> -> decode_parse(data) + <<"Q":utf8, data:bytes>> -> decode_query(data) + <<"S":utf8>> -> Ok(FeSync) + <<"X":utf8>> -> Ok(FeTerminate) + _ -> Error(UnknownMessage(data: binary)) + } +} + +fn decode_startup_message(binary, size, result) { + case binary { + <> -> + decode_startup_message_pairs(data, []) + |> result.map(fn(r) { #(r, next) }) + _ -> dec_err("invalid startup message", binary) + } +} + +fn decode_startup_message_pairs(binary, result) { + case binary { + <<0>> -> Ok(FeStartupMessage(params: list.reverse(result))) + _ -> { + use #(key, binary) <- try(decode_string(binary)) + use #(value, binary) <- try(decode_string(binary)) + decode_startup_message_pairs(binary, [#(key, value), ..result]) + } + } +} + +fn decode_query(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(query, rest) <- try(decode_string(binary)) + case rest { + <<>> -> Ok(FeQuery(query)) + _ -> dec_err("Query message too long", binary) + } +} + +// FeQuery(query) -> encode("Q", encode_string(query)) + +fn decode_parse(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(name, binary) <- try(decode_string(binary)) + use #(query, binary) <- try(decode_string(binary)) + use parameter_object_ids <- try(decode_parameter_object_ids(binary)) + Ok(FeParse( + name: name, + query: query, + parameter_object_ids: parameter_object_ids, + )) +} + +fn decode_parameter_object_ids(binary) { + case binary { + <> -> decode_parameter_object_ids_rec(rest, 
count, []) + _ -> dec_err("expected object id count", binary) + } +} + +fn decode_parameter_object_ids_rec(binary, count, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <> -> + decode_parameter_object_ids_rec(rest, count - 1, [id, ..result]) + _, _ -> dec_err("expected parameter object id", binary) + } +} + +fn decode_function_call(binary) -> Result(FrontendMessage, MessageDecodingError) { + case binary { + <> -> { + use #(argument_format, rest) <- try(read_parameter_format(rest)) + use #(arguments, rest) <- try(read_parameters(rest, argument_format)) + use #(result_format, rest) <- try(read_format(rest)) + case rest { + <<>> -> + Ok(FeFunctionCall( + argument_format: argument_format, + arguments: arguments, + object_id: object_id, + result_format: result_format, + )) + _ -> dec_err("invalid function call, data remains", rest) + } + } + _ -> dec_err("invalid function call, no object id found", binary) + } +} + +fn decode_execute(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + case binary { + <> -> Ok(FeExecute(portal, count)) + _ -> dec_err("no execute return_row_count found", binary) + } +} + +fn decode_describe(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeDescribe(what, name)) + _ -> dec_err("Describe message too long", binary) + } +} + +fn decode_copy_fail(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(error, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeCopyFail(error)) + _ -> dec_err("CopyFail message too long", binary) + } +} + +fn decode_close(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeClose(what, name)) + _ -> dec_err("Close 
message too long", binary) + } +} + +// FeClose(what, name) -> +// encode("C", <>) + +fn decode_bind(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + use #(statement_name, binary) <- try(decode_string(binary)) + use #(parameter_format, binary) <- try(read_parameter_format(binary)) + use #(parameters, binary) <- try(read_parameters(binary, parameter_format)) + use #(result_format, binary) <- try(read_parameter_format(binary)) + case binary { + <<>> -> + Ok(FeBind( + portal: portal, + statement_name: statement_name, + parameter_format: parameter_format, + parameters: parameters, + result_format: result_format, + )) + _ -> dec_err("Bind message too long", binary) + } +} + +fn decode_string( + binary: BitArray, +) -> Result(#(String, BitArray), MessageDecodingError) { + case binary_split(binary, <<0>>, []) { + [head, tail] -> + case bit_array.to_string(head) { + Ok(str) -> Ok(#(str, tail)) + Error(Nil) -> dec_err("invalid string encoding", head) + } + _ -> dec_err("invalid string", binary) + } +} + +fn read_parameter_format( + binary: BitArray, +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(FormatAllText, rest)) + <<1:16, 0:16, rest:bytes>> -> Ok(#(FormatAll(Text), rest)) + <<1:16, 1:16, rest:bytes>> -> Ok(#(FormatAll(Binary), rest)) + <> -> read_wire_formats(n, rest, []) + _ -> dec_err("invalid parameter format", binary) + } +} + +fn read_parameters(binary: BitArray, parameter_format: FormatValue) { + case binary { + <> -> { + case parameter_format { + FormatAllText -> list.repeat(Text, count) + FormatAll(format) -> list.repeat(format, count) + Formats(formats) -> formats + } + |> read_parameters_rec(count, rest, []) + } + _ -> dec_err("parameters without count", binary) + } +} + +fn read_parameters_rec( + formats: List(Format), + count: Int, + binary: BitArray, + result: List(ParameterValue), +) -> Result(#(List(ParameterValue), BitArray), 
MessageDecodingError) { + let actual = list.length(formats) + use <- bool.guard( + actual != count, + dec_err( + "expected " + <> int.to_string(count) + <> " parameters, but got " + <> int.to_string(actual), + binary, + ), + ) + + case count, formats, binary { + 0, _, rest -> Ok(#(list.reverse(result), rest)) + _, [_, ..rest_formats], <<-1:32-signed, rest:bytes>> -> { + read_parameters_rec(rest_formats, count - 1, rest, [Null, ..result]) + } + _, [format, ..rest_formats], <> -> + read_parameters_rec(rest_formats, count - 1, rest, [ + read_parameter(format, value), + ..result + ]) + _, _, rest -> dec_err("invalid parameter value", rest) + } +} + +fn read_parameter(format: Format, value: BitArray) { + case format { + Text -> Parameter(value) + Binary -> Parameter(value) + } +} + +fn read_wire_formats( + count: Int, + binary: BitArray, + result: List(Format), +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(#(Formats(list.reverse(result)), binary)) + _, <<0:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Text, ..result]) + _, <<1:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Binary, ..result]) + _, _ -> dec_err("unknown format", binary) + } +} + +fn decode_command_complete( + binary: BitArray, +) -> Result(BackendMessage, MessageDecodingError) { + { + use fine <- try(case binary { + <<"INSERT 0 ":utf8, rows:bytes>> -> Ok(#(Insert, rows)) + <<"DELETE ":utf8, rows:bytes>> -> Ok(#(Delete, rows)) + <<"UPDATE ":utf8, rows:bytes>> -> Ok(#(Update, rows)) + <<"MERGE ":utf8, rows:bytes>> -> Ok(#(Merge, rows)) + <<"SELECT ":utf8, rows:bytes>> -> Ok(#(Select, rows)) + <<"MOVE ":utf8, rows:bytes>> -> Ok(#(Move, rows)) + <<"FETCH ":utf8, rows:bytes>> -> Ok(#(Fetch, rows)) + <<"COPY ":utf8, rows:bytes>> -> Ok(#(Copy, rows)) + _ -> dec_err("invalid command", binary) + }) + + let #(command, rows_raw) = fine + let len = bit_array.byte_size(rows_raw) - 1 + + use rows_bits <- try(case rows_raw { + <> -> 
Ok(rows_bits) + _ -> dec_err("invalid command row count", binary) + }) + + use rows_string <- try( + bit_array.to_string(rows_bits) + |> result.replace_error(msg_dec_err( + "failed to convert row count to string", + rows_bits, + )), + ) + + use rows <- try( + int.parse(rows_string) + |> result.replace_error(msg_dec_err( + "failed to convert row count to int", + rows_bits, + )), + ) + + Ok(BeCommandComplete(command, rows)) + } + |> result.or(Ok(BeCommandComplete(Insert, -1))) +} + +pub type TransactionStatus { + TransactionStatusIdle + TransactionStatusInTransaction + TransactionStatusFailed +} + +fn decode_parameter_description(count, binary, results) { + case count, binary { + 0, <<>> -> Ok(BeParameterDescription(list.reverse(results))) + _, <> -> + decode_parameter_description(count - 1, tail, [value, ..results]) + _, _ -> dec_err("invalid parameter description", binary) + } +} + +fn decode_notification_response( + process_id, + binary, +) -> Result(BackendMessage, MessageDecodingError) { + use strings <- try(read_strings(binary, 2, [])) + case strings { + [channel, payload] -> + Ok(BeNotificationResponse( + process_id: process_id, + channel: channel, + payload: payload, + )) + _ -> dec_err("invalid notification response encoding", binary) + } +} + +fn decode_negotiate_protocol_version(version, count, binary) { + use options <- try(read_strings(binary, count, [])) + Ok(BeNegotiateProtocolVersion(version, options)) +} + +fn decode_row_descriptions(count, binary) { + use fields <- try(read_row_descriptions(count, binary, [])) + Ok(BeRowDescriptions(fields)) +} + +fn decode_parameter_status(binary) { + use strings <- try(decode_strings(binary)) + case strings { + [name, value] -> Ok(BeParameterStatus(name: name, value: value)) + _ -> dec_err("invalid parameter status", binary) + } +} + +fn decode_authentication_sasl(binary) { + use strings <- try(decode_strings(binary)) + Ok(BeAuthenticationSASL(strings)) +} + +fn decode_copy_response(direction, format_raw, count, 
rest) { + use overall_format <- try(decode_format(format_raw)) + use <- bool.guard( + bit_array.byte_size(rest) != count * 2, + dec_err("size must be count * 2", rest), + ) + + use codes <- try(decode_format_codes(rest, [])) + + case overall_format == Text { + False -> Ok(BeCopyResponse(direction, overall_format, codes)) + True -> + case list.all(codes, fn(code) { code == Text }) { + True -> Ok(BeCopyResponse(direction, overall_format, codes)) + False -> dec_err("invalid copy response format", rest) + } + } +} + +fn decode_format_codes( + binary: BitArray, + result: List(Format), +) -> Result(List(Format), MessageDecodingError) { + case binary { + <> -> + case decode_format(code) { + Ok(format) -> decode_format_codes(tail, [format, ..result]) + Error(err) -> Error(err) + } + <<>> -> Ok(list.reverse(result)) + _ -> dec_err("invalid format codes", binary) + } +} + +fn decode_format(num: Int) { + case num { + 0 -> Ok(Text) + 1 -> Ok(Binary) + _ -> dec_err("invalid format code: " <> int.to_string(num), <<>>) + } +} + +fn read_format(binary) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(Text, rest)) + <<1:16, rest:bytes>> -> Ok(#(Binary, rest)) + _ -> dec_err("invalid format code", binary) + } +} + +fn encode_format(format_raw) -> Int { + case format_raw { + Text -> 0 + Binary -> 1 + } +} + +pub type DataRow { + DataRow(List(BitArray)) +} + +fn decode_message_data_row( + count, + rest, +) -> Result(BackendMessage, MessageDecodingError) { + case decode_message_data_row_rec(rest, count, []) { + Ok(cols) -> + case list.length(cols) == count { + True -> Ok(BeMessageDataRow(cols)) + False -> dec_err("column count doesn't match", rest) + } + Error(err) -> Error(err) + } +} + +fn decode_message_data_row_rec( + binary, + count, + result, +) -> Result(List(BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(list.reverse(result)) + _, <> -> + decode_message_data_row_rec(rest, count - 1, [value, ..result]) + _, _ -> + dec_err( + "failed to parse data row at 
count " <> int.to_string(count), + binary, + ) + } +} + +pub type RowDescriptionField { + RowDescriptionField( + // The field name. + name: String, + // If the field can be identified as a column of a specific table, the + // object ID of the table; otherwise zero. + table_oid: Int, + // If the field can be identified as a column of a specific table, the + // attribute number of the column; otherwise zero. + attr_number: Int, + // The object ID of the field's data type. + data_type_oid: Int, + // The data type size (see pg_type.typlen). Note that negative values denote + // variable-width types. + data_type_size: Int, + // The type modifier (see pg_attribute.atttypmod). The meaning of the + // modifier is type-specific. + type_modifier: Int, + // The format code being used for the field. Currently will be zero (text) + // or one (binary). In a RowDescription returned from the statement variant + // of Describe, the format code is not yet known and will always be zero. + format_code: Int, + ) +} + +fn read_row_descriptions(count, binary, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <<>> -> dec_err("row description count mismatch", binary) + _, _ -> + case read_row_description_field(binary) { + Ok(#(field, tail)) -> + read_row_descriptions(count - 1, tail, [field, ..result]) + Error(err) -> Error(err) + } + } +} + +fn read_row_description_field(binary) { + case read_string(binary) { + Ok(#( + name, + << + table_oid:32, + attr_number:16, + data_type_oid:32, + data_type_size:16, + type_modifier:32, + format_code:16, + tail:bytes, + >>, + )) -> + Ok(#( + RowDescriptionField( + name: name, + table_oid: table_oid, + attr_number: attr_number, + data_type_oid: data_type_oid, + data_type_size: data_type_size, + type_modifier: type_modifier, + format_code: format_code, + ), + tail, + )) + Ok(#(_, tail)) -> dec_err("failed to parse row description field", tail) + Error(_) -> dec_err("failed to decode row description field name", binary) + } +} + 
+fn decode_strings(binary) { + let length = bit_array.byte_size(binary) - 1 + case binary { + <<>> -> Ok([]) + <> -> { + binary_split(head, <<0>>, [Global]) + |> list.map(bit_array.to_string) + |> result.all() + |> result.replace_error(msg_dec_err("invalid strings encoding", binary)) + } + _ -> dec_err("string size didn't match", binary) + } +} + +fn read_strings(binary, count, result) { + case count { + 0 -> Ok(list.reverse(result)) + _ -> { + case read_string(binary) { + Ok(#(value, rest)) -> read_strings(rest, count - 1, [value, ..result]) + Error(err) -> Error(err) + } + } + } +} + +fn read_string(binary) { + case binary_split(binary, <<0>>, []) { + [<<>>, <<>>] -> Ok(#("", <<>>)) + [head, tail] -> { + bit_array.to_string(head) + |> result.replace_error(msg_dec_err("invalid string encoding", head)) + |> result.map(fn(s) { #(s, tail) }) + } + _ -> dec_err("invalid string", binary) + } +} + +fn decode_notice_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeNoticeResponse(fields)) +} + +fn decode_error_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeErrorResponse(fields)) +} + +pub type ErrorOrNoticeField { + Code(String) + Detail(String) + File(String) + Hint(String) + Line(String) + Message(String) + Position(String) + Routine(String) + SeverityLocalized(String) + Severity(String) + Where(String) + Column(String) + DataType(String) + Constraint(String) + InternalPosition(String) + InternalQuery(String) + Schema(String) + Table(String) + Unknown(key: BitArray, value: String) +} + +fn decode_fields(binary) { + case decode_fields_rec(binary, []) { + Ok(fields) -> + fields + |> list.map(fn(key_value_raw) { + let #(key, value) = key_value_raw + case key { + <<"S":utf8>> -> Severity(value) + <<"V":utf8>> -> SeverityLocalized(value) + <<"C":utf8>> -> Code(value) + <<"M":utf8>> -> Message(value) + <<"D":utf8>> -> Detail(value) + <<"H":utf8>> -> Hint(value) + <<"P":utf8>> -> Position(value) + <<"p":utf8>> -> 
InternalPosition(value) + <<"q":utf8>> -> InternalQuery(value) + <<"W":utf8>> -> Where(value) + <<"s":utf8>> -> Schema(value) + <<"t":utf8>> -> Table(value) + <<"c":utf8>> -> Column(value) + <<"d":utf8>> -> DataType(value) + <<"n":utf8>> -> Constraint(value) + <<"F":utf8>> -> File(value) + <<"L":utf8>> -> Line(value) + <<"R":utf8>> -> Routine(value) + _ -> Unknown(key, value) + } + }) + |> set.from_list() + |> Ok + Error(err) -> Error(err) + } +} + +fn decode_fields_rec(binary, result) { + case binary { + <<0>> | <<>> -> Ok(result) + <> -> { + case binary_split(rest, <<0>>, []) { + [head, tail] -> { + case bit_array.to_string(head) { + Ok(value) -> + decode_fields_rec(tail, [#(field_type, value), ..result]) + Error(Nil) -> dec_err("invalid field encoding", binary) + } + } + _ -> dec_err("invalid field separator", binary) + } + } + _ -> dec_err("invalid field", binary) + } +} + +type BinarySplitOption { + Global +} + +@external(erlang, "binary", "split") +fn binary_split( + subject: BitArray, + pattern: BitArray, + options: List(BinarySplitOption), +) -> List(BitArray) diff --git a/src/squirrel/internal/error.gleam b/src/squirrel/internal/error.gleam new file mode 100644 index 0000000..7b810d9 --- /dev/null +++ b/src/squirrel/internal/error.gleam @@ -0,0 +1,642 @@ +import glam/doc.{type Document} +import gleam/int +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/regex +import gleam/result +import gleam/string +import gleam_community/ansi +import simplifile + +pub type Error { + + // --- POSTGRES RELATED ERRORS ----------------------------------------------- + /// When authentication workflow goes wrong. + /// TODO)) For now I only support no authentication so people might report + /// this issue. + /// + PgCannotAuthenticate(expected: String, got: String) + + /// When there's an error with the underlying socket and I cannot send + /// messages to the server. 
+ /// + PgCannotSendMessage(reason: String) + + /// This comes from the `postgres_protocol` module, it happens if the server + /// sends back a malformed message or there's a type of messages decoding is + /// not implemented for. + /// This should never happen and warrants a bug report! + /// + PgCannotDecodeReceivedMessage(reason: String) + + /// When there's an error with the underlying socket and I cannot receive + /// messages to the server. + /// + PgCannotReceiveMessage(reason: String) + + /// When I cannot get a query description back from the postgres server. + /// + PgCannotDescribeQuery( + file: String, + query_name: String, + expected: String, + got: String, + ) + + // --- OTHER GENERIC ERRORS -------------------------------------------------- + /// When I cannot read a file containing queries. + /// + CannotReadFile(file: String, reason: simplifile.FileError) + + /// When the generated code cannot be written to a file. + /// + CannotWriteToFile(file: String, reason: simplifile.FileError) + + /// When the input file has no queries that the cool can generate code from. + /// + FileWithNoQueries(file: String) + + /// If a query has a name that is not a valid Gleam identifier. Instead of + /// trying to magically come up with a name we fail and report the error. + /// + QueryHasInvalidName( + file: String, + name: String, + suggested_name: Option(String), + name_line: String, + starting_line: Int, + reason: ValueIdentifierError, + ) + + /// If a query returns a column that is not a valid Gleam identifier. Instead + /// of trying to magically come up with a name we fail and report the error. + /// + QueryHasInvalidColumn( + file: String, + column_name: String, + suggested_name: Option(String), + content: String, + starting_line: Int, + reason: ValueIdentifierError, + ) + + /// When there's a param/return type that cannot be converted into a Gleam + /// type. 
+ /// + QueryHasUnsupportedType( + file: String, + name: String, + content: String, + starting_line: Int, + type_: String, + ) + + /// If the query contains an error and cannot be parsed by the DBMS. + /// + CannotParseQuery( + file: String, + name: String, + content: String, + starting_line: Int, + error_code: Option(String), + pointer: Option(Pointer), + hint: Option(String), + ) +} + +pub type ValueIdentifierError { + DoesntStartWithLowercaseLetter + ContainsInvalidGrapheme(at: Int, grapheme: String) + IsEmpty +} + +pub type Pointer { + Pointer(point_to: PointerKind, message: String) +} + +pub type PointerKind { + Name(name: String) + ByteIndex(position: Int) +} + +pub fn to_doc(error: Error) -> Document { + let printable_error = case error { + PgCannotSendMessage(reason: reason) -> + printable_error("Cannot send message") + |> add_paragraph( + "I ran into an unexpected error while trying to talk to the Postgres +database server.", + ) + |> report_bug(reason) + + PgCannotDecodeReceivedMessage(reason: reason) -> + printable_error("Cannot decode message") + |> add_paragraph( + "I ran into an unexpected error while trying to decode a message +received from the Postgres database server.", + ) + |> report_bug(reason) + + PgCannotReceiveMessage(reason: reason) -> + printable_error("Cannot receive message") + |> add_paragraph( + "I ran into an unexpected error while trying to listen to the Postgres +database server.", + ) + |> report_bug(reason) + + CannotReadFile(file: file, reason: reason) -> + printable_error("Cannot read file") + |> add_paragraph( + "I couldn't read " + <> style_file(file) + <> " because of the following error: " + <> simplifile.describe_error(reason), + ) + + CannotWriteToFile(file: file, reason: reason) -> + printable_error("Cannot write to file") + |> add_paragraph( + "I couldn't write to " + <> style_file(file) + <> " because of the following error: " + <> simplifile.describe_error(reason), + ) + + FileWithNoQueries(file: file) -> + 
printable_error("File with no queries") + |> add_paragraph("I couldn't find any query in " <> style_file(file)) + |> hint("Each query should be preceded by a comment with its +name like this one: " <> style_inline_code("-- name: name_of_the_query") <> ", +or I won't be able to pick it up!") + + QueryHasInvalidName( + file: file, + name: name, + suggested_name: suggested_name, + name_line: name_line, + reason: reason, + starting_line: starting_line, + ) -> + printable_error("Query with invalid name") + |> add_code_paragraph( + file: file, + content: name_line, + starting_line: starting_line, + point: Some( + Pointer( + point_to: case reason { + IsEmpty -> Name("name:") + _ -> Name(name) + }, + message: case suggested_name { + None -> "This is not a valid Gleam name" + Some(suggestion) -> + "This is not a valid Gleam name, maybe try " + <> style_inline_code(suggestion) + <> "?" + }, + ), + ), + ) + |> hint( + "A query name must start with a lowercase letter and can only +contain lowercase letters, numbers and underscores.", + ) + + QueryHasInvalidColumn( + file: file, + column_name: column_name, + suggested_name: suggested_name, + content: content, + reason: reason, + starting_line: starting_line, + ) -> + case reason { + IsEmpty -> + printable_error("Column with empty name") + |> add_code_paragraph( + file: file, + content: content, + point: None, + starting_line: starting_line, + ) + |> add_paragraph( + "A column returned by this query has the empty string as a name, +all columns should have a valid Gleam name as name.", + ) + + _ -> + printable_error("Column with invalid name") + |> add_code_paragraph( + file: file, + content: content, + starting_line: starting_line, + point: Some( + Pointer(point_to: Name(column_name), message: case + suggested_name + { + None -> "This is not a valid Gleam name" + Some(suggestion) -> + "This is not a valid Gleam name, maybe try " + <> style_inline_code(suggestion) + <> "?" 
+ }), + ), + ) + |> hint( + "A column name must start with a lowercase letter and can only +contain lowercase letters, numbers and underscores.", + ) + } + + QueryHasUnsupportedType( + file: file, + name: _, + content: content, + type_: type_, + starting_line: starting_line, + ) -> + printable_error("Unsupported type") + |> add_code_paragraph( + file: file, + content: content, + point: None, + starting_line: starting_line, + ) + |> add_paragraph( + "One of the rows returned by this query has type " + <> style_inline_code(type_) + <> " which I cannot currently generate code for.", + ) + |> call_to_action(for: "this type to be supported") + + CannotParseQuery( + file: file, + name: _name, + content: content, + starting_line: starting_line, + error_code: error_code, + hint: hint, + pointer: pointer, + ) -> + printable_error(case error_code { + Some(code) -> "Invalid query [" <> code <> "]" + None -> "Invalid query" + }) + |> add_code_paragraph( + file: file, + content: content, + point: pointer, + starting_line: starting_line, + ) + |> maybe_hint(hint) + + PgCannotAuthenticate(expected: expected, got: got) -> + printable_error("Cannot authenticate") + |> add_paragraph( + "I ran into an unexpected problem while trying to authenticate with the +Postgres server. This is most definitely a bug!", + ) + |> report_bug("Expected: " <> expected <> ", Got: " <> got) + + PgCannotDescribeQuery( + file: file, + query_name: query_name, + expected: expected, + got: got, + ) -> + printable_error("Cannot inspect query") + |> add_paragraph("I ran into an unexpected problem while trying to figure +out the types of query " <> style_inline_code(query_name) <> " +defined in " <> style_file(file) <> ". 
This is most definitely a bug!") + |> report_bug("Expected: " <> expected <> ", Got: " <> got) + } + + printable_error_to_doc(printable_error) +} + +fn style_file(file: String) -> String { + ansi.underline(file) +} + +fn style_inline_code(code: String) -> String { + "`" <> code <> "`" +} + +fn style_link(link: String) -> String { + ansi.underline(link) +} + +// --- ERROR PRETTY PRINTING --------------------------------------------------- + +const indent = 2 + +type PrintableError { + PrintableError( + title: String, + body: List(Paragraph), + report_bug: Option(String), + call_to_action: Option(String), + hint: Option(String), + ) +} + +type Paragraph { + Simple(String) + Code( + file: String, + content: String, + pointer: Option(Pointer), + starting_line: Int, + ) +} + +fn printable_error(title: String) -> PrintableError { + PrintableError( + title: title, + body: [], + report_bug: None, + hint: None, + call_to_action: None, + ) +} + +fn add_paragraph(error: PrintableError, string: String) -> PrintableError { + PrintableError(..error, body: list.append(error.body, [Simple(string)])) +} + +fn add_code_paragraph( + error: PrintableError, + file file: String, + content content: String, + point point: Option(Pointer), + starting_line starting_line: Int, +) -> PrintableError { + PrintableError( + ..error, + body: list.append(error.body, [ + Code( + file: file, + content: content, + pointer: point, + starting_line: starting_line, + ), + ]), + ) +} + +fn report_bug(error: PrintableError, report_bug: String) -> PrintableError { + PrintableError(..error, report_bug: Some(report_bug)) +} + +fn hint(error: PrintableError, hint: String) -> PrintableError { + PrintableError(..error, hint: Some(hint)) +} + +fn maybe_hint(error: PrintableError, hint: Option(String)) -> PrintableError { + PrintableError(..error, hint: hint) +} + +fn call_to_action(error: PrintableError, for wanted: String) -> PrintableError { + PrintableError(..error, call_to_action: Some(wanted)) +} + +fn 
printable_error_to_doc(error: PrintableError) -> Document { + let PrintableError( + title: title, + body: body, + report_bug: report_bug, + call_to_action: call_to_action, + hint: hint, + ) = error + + [ + title_doc(title), + body_doc(body), + option_to_doc(report_bug, report_bug_doc), + option_to_doc(call_to_action, call_to_action_doc), + option_to_doc(hint, hint_doc), + ] + |> list.filter(keeping: fn(doc) { doc != doc.empty }) + |> doc.join(with: doc.lines(2)) + |> doc.group +} + +fn title_doc(title: String) -> Document { + doc.from_string(ansi.red(ansi.bold("Error: ") <> title)) +} + +fn body_doc(body: List(Paragraph)) -> Document { + list.map(body, paragraph_doc) + |> doc.join(with: doc.line) + |> doc.group +} + +fn paragraph_doc(paragraph: Paragraph) -> Document { + case paragraph { + Simple(string) -> flexible_string(string) + Code( + file: file, + content: content, + pointer: pointer, + starting_line: starting_line, + ) -> + code_doc( + file: file, + content: content, + pointer: pointer, + starting_line: starting_line, + ) + } +} + +fn code_doc( + file file: String, + content content: String, + pointer pointer: Option(Pointer), + starting_line starting_line: Int, +) { + let pointer = + option.to_result(pointer, Nil) + |> result.then(pointer_doc(_, content)) + + let content = syntax_highlight(content) + let lines = string.split(content, on: "\n") + let lines_count = list.length(lines) + let assert Ok(digits) = int.digits(lines_count + starting_line, 10) + let max_digits = list.length(digits) + + let code_lines = { + use line, i <- list.index_map(lines) + let prefix = + int.to_string(i + starting_line) + |> string.pad_left(to: max_digits + 2, with: " ") + + case pointer { + Ok(#(pointer_line, from, pointer_doc)) if pointer_line == i -> [ + doc.from_string(ansi.dim(prefix <> " │ ")), + doc.from_string(line), + [doc.line, pointer_doc] + |> doc.concat + |> doc.nest(by: from + max_digits + 5), + ] + + Ok(_) | Error(_) -> [ + doc.from_string(ansi.dim(prefix <> " │ 
")), + doc.from_string(ansi.dim(line)), + ] + } + |> doc.concat + } + + let padding = string.repeat(" ", max_digits + 3) + [ + doc.from_string(padding <> ansi.dim("╭─ " <> file)), + case starting_line { + 1 -> doc.from_string(padding <> ansi.dim("│ ")) + _ -> doc.from_string(padding <> ansi.dim("┆ ")) + }, + ..code_lines + ] + |> doc.join(with: doc.line) + |> doc.append(doc.line) + |> doc.append(doc.from_string(padding <> ansi.dim("┆"))) + |> doc.group +} + +fn pointer_doc( + pointer: Pointer, + content: String, +) -> Result(#(Int, Int, Document), Nil) { + let Pointer(kind, message) = pointer + use #(line, from, to) <- result.try(find_span(kind, content)) + let width = to - from + 1 + let doc = + [ + doc.zero_width_string("\u{001B}[31m"), + doc.from_string("┬" <> string.repeat("─", width - 1)), + doc.line, + doc.from_string("╰─ "), + flexible_string(message) + |> doc.nest(by: 3), + doc.zero_width_string("\u{001B}[0m"), + ] + |> doc.concat + |> doc.group + + Ok(#(line, from, doc)) +} + +fn find_span(kind: PointerKind, string: String) -> Result(#(Int, Int, Int), Nil) { + case kind { + Name(name) -> find_name_span(name, string.length(name), string, 0, 0) + ByteIndex(n) -> find_byte_span(n - 1, string, 0, 0) + } +} + +fn find_name_span( + name: String, + name_len: Int, + string: String, + row: Int, + col: Int, +) -> Result(#(Int, Int, Int), Nil) { + case string.starts_with(string, name) { + True -> Ok(#(row, col, col + name_len - 1)) + False -> + case string.pop_grapheme(string) { + Ok(#("\n", rest)) -> find_name_span(name, name_len, rest, row + 1, 0) + Ok(#(_, rest)) -> find_name_span(name, name_len, rest, row, col + 1) + Error(_) -> Error(Nil) + } + } +} + +fn find_byte_span( + position: Int, + string: String, + row: Int, + col: Int, +) -> Result(#(Int, Int, Int), Nil) { + case position { + 0 -> Ok(#(row, col, col)) + n -> + case string.pop_grapheme(string) { + Ok(#("\n", rest)) -> find_byte_span(n - 1, rest, row + 1, 0) + Ok(#(_, rest)) -> find_byte_span(n - 1, 
rest, row, col + 1) + Error(_) -> Error(Nil) + } + } +} + +const keywords = [ + "and", "any", "as", "asc", "begin", "between", "by", "case", "count", "desc", + "distinct", "else", "end", "exists", "from", "full", "group", "having", "if", + "in", "inner", "insert", "into", "join", "key", "left", "like", "not", "null", + "on", "or", "order", "primary", "revert", "right", "select", "set", "table", + "top", "trigger", "union", "update", "use", "values", "view", "where", "with", +] + +fn syntax_highlight(content: String) -> String { + let keywords = string.join(keywords, with: "|") + let not_inside_string = "(?=(?:[^']*'[^']*')*[^']*$)" + + let assert Ok(keyword) = + { "\\b(" <> keywords <> ")\\b" <> not_inside_string } + |> regex.compile(regex.Options(True, False)) + let assert Ok(number) = + regex.from_string("(? not_inside_string) + let assert Ok(comment) = regex.from_string("(^\\s*--.*)") + let assert Ok(string) = regex.from_string("(\\'.*\\')") + let assert Ok(hole) = regex.from_string("(\\$\\d+)" <> not_inside_string) + + content + |> regex.replace(each: comment, with: "\u{001B}[2m\\1\u{001B}[0m") + |> regex.replace(each: keyword, with: "\u{001B}[36m\\1\u{001B}[39m") + |> regex.replace(each: string, with: "\u{001B}[33m\\1\u{001B}[39m") + |> regex.replace(each: number, with: "\u{001B}[32m\\1\u{001B}[39m") + |> regex.replace(each: hole, with: "\u{001B}[35m\\1\u{001B}[39m") +} + +fn report_bug_doc(additional_info: String) -> Document { + [ + flexible_string( + "Please open an issue at " + <> style_link("https://github.com/giacomocavalieri/squirrel/issues/new") + <> " with some details about what you where doing, including the following message:", + ), + doc.line |> doc.nest(by: indent), + doc.from_string(additional_info), + ] + |> doc.concat + |> doc.group +} + +fn call_to_action_doc(wanted: String) -> Document { + flexible_string( + "If you would like for " + <> wanted + <> ", please open an issue at " + <> 
style_link("https://github.com/giacomocavalieri/squirrel/issues/new"), + ) +} + +fn hint_doc(hint: String) -> Document { + flexible_string("Hint: " <> hint) +} + +fn flexible_string(string: String) -> Document { + string.split(string, on: "\n") + |> list.flat_map(string.split(_, on: " ")) + |> list.map(doc.from_string) + |> doc.join(with: doc.flex_space) + |> doc.group +} + +fn option_to_doc(option: Option(a), fun: fn(a) -> Document) -> Document { + case option { + Some(a) -> fun(a) + None -> doc.empty + } +} diff --git a/src/squirrel/internal/eval_extra.gleam b/src/squirrel/internal/eval_extra.gleam new file mode 100644 index 0000000..70a7c72 --- /dev/null +++ b/src/squirrel/internal/eval_extra.gleam @@ -0,0 +1,61 @@ +import eval.{type Eval} +import gleam/list + +pub fn try_map( + list: List(a), + fun: fn(a) -> Eval(b, _, _), +) -> Eval(List(b), _, _) { + try_fold(list, [], fn(acc, item) { + use mapped_item <- eval.try(fun(item)) + eval.return([mapped_item, ..acc]) + }) + |> eval.map(list.reverse) +} + +pub fn try_index_map( + list: List(a), + fun: fn(a, Int) -> Eval(b, _, _), +) -> Eval(List(b), _, _) { + try_index_fold(list, [], fn(acc, item, i) { + use mapped_item <- eval.try(fun(item, i)) + eval.return([mapped_item, ..acc]) + }) + |> eval.map(list.reverse) +} + +pub fn try_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first)) + try_fold(rest, acc, fun) + } + } +} + +pub fn try_index_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + do_try_index_fold(0, over: list, from: acc, with: fun) +} + +fn do_try_index_fold( + index: Int, + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first, index)) + 
do_try_index_fold(index + 1, rest, acc, fun) + } + } +} diff --git a/src/squirrel/internal/gleam.gleam b/src/squirrel/internal/gleam.gleam new file mode 100644 index 0000000..bf1bbc6 --- /dev/null +++ b/src/squirrel/internal/gleam.gleam @@ -0,0 +1,203 @@ +import gleam/list +import gleam/string +import justin +import squirrel/internal/error.{ + type ValueIdentifierError, ContainsInvalidGrapheme, IsEmpty, +} + +pub type Type { + List(Type) + Option(Type) + Int + Float + Bool + String +} + +pub type Field { + Field(label: ValueIdentifier, type_: Type) +} + +pub opaque type ValueIdentifier { + ValueIdentifier(String) +} + +/// Returns `Ok(ValueIdentifier)` if the given string is a valid Gleam identifier +/// (discard identifiers starting with an '_' are rejected), an `Error` otherwise. +/// +/// > 💡 A valid identifier can be described by the following regex: +/// > `[a-z][a-z0-9_]*`. +pub fn identifier( + from name: String, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // A valid identifier needs to start with a lowercase letter. + // We do not accept _discard identifier as valid. + case name { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest -> to_identifier_rest(name, rest, 1) + _ -> + case string.pop_grapheme(name) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(0, g)) + Error(_) -> Error(IsEmpty) + } + } +} + +fn to_identifier_rest( + name: String, + rest: String, + position: Int, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // The rest of an identifier can only contain lowercase letters, _, numbers, + // or be empty. In all other cases it's not valid. 
+ case rest { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest + | "_" <> rest + | "0" <> rest + | "1" <> rest + | "2" <> rest + | "3" <> rest + | "4" <> rest + | "5" <> rest + | "6" <> rest + | "7" <> rest + | "8" <> rest + | "9" <> rest -> to_identifier_rest(name, rest, position + 1) + "" -> Ok(ValueIdentifier(name)) + _ -> + case string.pop_grapheme(rest) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(position, g)) + Error(_) -> panic as "unreachable: empty identifier rest should be ok" + } + } +} + +pub fn identifier_to_string(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + name +} + +pub fn identifier_to_type_name(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + + justin.pascal_case(name) + |> string.to_graphemes + // We want to remove any leftover "_" that might still be present after the + // conversion if the identifier had consecutive "_". 
+ |> list.filter(keeping: fn(c) { c != "_" }) + |> string.join(with: "") +} + +pub fn similar_identifier_string(string: String) -> Result(String, Nil) { + let proposal = + string.trim(string) + |> justin.snake_case + |> string.to_graphemes + |> list.drop_while(fn(g) { g == "_" || is_digit(g) }) + |> list.filter(keeping: is_identifier_char) + |> string.join(with: "") + + case proposal { + "" -> Error(Nil) + _ -> Ok(proposal) + } +} + +fn is_digit(char: String) -> Bool { + case char { + "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" -> True + _ -> False + } +} + +fn is_identifier_char(char: String) -> Bool { + case char { + "a" + | "b" + | "c" + | "d" + | "e" + | "f" + | "g" + | "h" + | "i" + | "j" + | "k" + | "l" + | "m" + | "n" + | "o" + | "p" + | "q" + | "r" + | "s" + | "t" + | "u" + | "v" + | "w" + | "x" + | "y" + | "z" + | "_" + | "0" + | "1" + | "2" + | "3" + | "4" + | "5" + | "6" + | "7" + | "8" + | "9" -> True + _ -> False + } +} diff --git a/src/squirrel/internal/query.gleam b/src/squirrel/internal/query.gleam new file mode 100644 index 0000000..032f82d --- /dev/null +++ b/src/squirrel/internal/query.gleam @@ -0,0 +1,537 @@ +import glam/doc.{type Document} +import gleam/bit_array +import gleam/int +import gleam/list +import gleam/option +import gleam/result +import gleam/string +import simplifile +import squirrel/internal/error.{ + type Error, type ValueIdentifierError, CannotReadFile, QueryHasInvalidName, +} +import squirrel/internal/gleam + +pub type UntypedQuery { + UntypedQuery( + file: String, + starting_line: Int, + name: gleam.ValueIdentifier, + content: String, + ) +} + +pub type TypedQuery { + TypedQuery( + file: String, + starting_line: Int, + name: gleam.ValueIdentifier, + content: String, + params: List(gleam.Type), + returns: List(gleam.Field), + ) +} + +pub fn add_types( + to query: UntypedQuery, + params params: List(gleam.Type), + returns returns: List(gleam.Field), +) -> TypedQuery { + let UntypedQuery( + file: file, + name: name, 
+ content: content, + starting_line: starting_line, + ) = query + TypedQuery( + file: file, + name: name, + content: content, + starting_line: starting_line, + params: params, + returns: returns, + ) +} + +// --- PARSING ----------------------------------------------------------------- + +/// Reads `file` and parses all the named queries it contains. +/// +/// A failure to read the file is reported as a `CannotReadFile` error. Parsing +/// starts from byte position 0 and line 1. +/// +pub fn from_file(file: String) -> Result(List(UntypedQuery), Error) { + let read_file = + simplifile.read(file) + |> result.map_error(CannotReadFile(file, _)) + + use content <- result.try(read_file) + let content = <<content:utf8>> + parse_queries(file, content, content, 0, 1, []) +} + +/// Scans `string` one byte at a time looking for the `-- name:` comment that +/// starts a query; `position` and `line` track how far into `original` the +/// scan has got. +/// +/// Returns the queries in the order they were found, or a `QueryHasInvalidName` +/// error as soon as a name comment holding an invalid name is met. +/// +fn parse_queries( + file: String, + original: BitArray, + string: BitArray, + position: Int, + line: Int, + queries: List(UntypedQuery), +) -> Result(List(UntypedQuery), Error) { + case string { + <<>> -> Ok(list.reverse(queries)) + <<"--":utf8, rest:bits>> -> + case name_comment(rest) { + NotANameComment -> + parse_queries(file, original, rest, position + 2, line, queries) + + InvalidNameComment(name: name, reason: reason) -> + Error(QueryHasInvalidName( + file: file, + name: name, + suggested_name: gleam.similar_identifier_string(name) + |> option.from_result, + name_line: take_line(string), + reason: reason, + starting_line: line, + )) + + ValidNameComment(identifier) -> + parse_query( + file, + original, + rest, + identifier, + position, + line, + position + 2, + line, + queries, + ) + } + <<"\n":utf8, rest:bits>> -> + parse_queries(file, original, rest, position + 1, line + 1, queries) + // Only a newline moves to the next line: any other byte just advances the + // position. + <<_, rest:bits>> -> + parse_queries(file, original, rest, position + 1, line, queries) + _ -> panic as "non byte aligned Gleam String" + } +} + +type NameCommentResult { + NotANameComment + InvalidNameComment(name: String, reason: ValueIdentifierError) + ValidNameComment(identifier: gleam.ValueIdentifier) +} + +fn name_comment(string: BitArray) -> NameCommentResult { + case string { + // We ignore all whitespace between the start of the comment and its + // content. 
+ <<" ":utf8, rest:bits>> + | <<"\t":utf8, rest:bits>> + | <<"\r":utf8, rest:bits>> -> name_comment(rest) + + <<"name:":utf8, rest:bits>> -> take_name(rest, rest, 0) + <<_, _:bits>> | <<>> -> NotANameComment + _ -> panic as "non byte aligned Gleam String" + } +} + +/// Consumes `string` up to the end of the line (or of the input), then takes +/// the first `size` bytes of `original`, trims them and checks that the result +/// is a valid Gleam identifier. +/// +fn take_name( + original: BitArray, + string: BitArray, + size: Int, +) -> NameCommentResult { + case string { + <<>> | <<"\n":utf8, _:bits>> -> { + let assert Ok(name) = bit_array.slice(original, 0, size) + let assert Ok(name) = bit_array.to_string(name) + let name = string.trim(name) + case gleam.identifier(name) { + Ok(identifier) -> ValidNameComment(identifier) + Error(reason) -> InvalidNameComment(name: name, reason: reason) + } + } + <<_, rest:bits>> -> take_name(original, rest, size + 1) + _ -> panic as "non byte aligned Gleam String" + } +} + +/// Scans the body of the query called `name`, whose text starts at byte `start` +/// of `original` and at line `starting_line`. +/// +/// The query ends at the end of the input or right before the next `-- name:` +/// comment (be it valid or not); in the latter case the scan carries on with +/// `parse_queries` starting from that comment. +/// +fn parse_query( + file: String, + original: BitArray, + string: BitArray, + name: gleam.ValueIdentifier, + start: Int, + starting_line: Int, + position: Int, + line: Int, + queries: List(UntypedQuery), +) -> Result(List(UntypedQuery), Error) { + case string { + <<>> -> { + let assert Ok(query) = bit_array.slice(original, start, position - start) + let assert Ok(content) = bit_array.to_string(query) + let query = + UntypedQuery( + file: file, + name: name, + content: content, + starting_line: starting_line, + ) + Ok(list.reverse([query, ..queries])) + } + + <<"--":utf8, rest:bits>> -> + case name_comment(rest) { + NotANameComment -> + parse_query( + file, + original, + rest, + name, + start, + starting_line, + position + 2, + line, + queries, + ) + + ValidNameComment(_) | InvalidNameComment(_, _) -> { + let assert Ok(query) = + bit_array.slice(original, start, position - start) + let assert Ok(content) = bit_array.to_string(query) + let query = + UntypedQuery( + file: file, + name: name, + content: content, + // The query began at `starting_line`; `line` is where the next + // name comment sits. + starting_line: starting_line, + ) + parse_queries(file, original, string, position, line, [ + query, + ..queries + ]) + } + } + + <<"\n":utf8, rest:bits>> -> + 
parse_query( + file, + original, + rest, + name, + start, + starting_line, + position + 1, + line + 1, + queries, + ) + + <<_, rest:bits>> -> + parse_query( + file, + original, + rest, + name, + start, + starting_line, + position + 1, + line, + queries, + ) + + _ -> panic as "non byte aligned Gleam String" + } +} + +fn take_line(string: BitArray) -> String { + let assert Ok(line) = do_take_line(string, string, 0) + let assert Ok(line) = bit_array.to_string(line) + line +} + +fn do_take_line( + original: BitArray, + line: BitArray, + size: Int, +) -> Result(BitArray, Nil) { + case line { + <<>> -> bit_array.slice(original, 0, size) + <<"\n":utf8, _:bits>> -> bit_array.slice(original, 0, size) + <<_, rest:bits>> -> { + do_take_line(original, rest, size + 1) + } + _ -> panic as "non byte aligned Gleam String" + } +} + +// --- CODE GENERATION --------------------------------------------------------- + +pub fn generate_code(version: String, query: TypedQuery) { + let TypedQuery( + file: file, + name: name, + content: content, + params: params, + returns: returns, + starting_line: _, + ) = query + + let arg_name = fn(i) { "arg_" <> int.to_string(i + 1) } + let inputs = list.index_map(params, fn(_, i) { arg_name(i) }) + let inputs_encoders = + list.index_map(params, fn(p, i) { + gleam_type_to_encoder(p, arg_name(i)) |> doc.from_string + }) + + let function_name = gleam.identifier_to_string(name) + let constructor_name = gleam.identifier_to_type_name(name) <> "Row" + + let record_doc = + "/// A row you get from running the `" <> function_name <> "` query +/// defined in `" <> file <> "`. +/// +/// > 🐿️ This type definition was generated automatically using " <> version <> " of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +///" + + let fun_doc = "/// Runs the `" <> function_name <> "` query +/// defined in `" <> file <> "`. 
+/// +/// > 🐿️ This function was generated automatically using " <> version <> " of the +/// > [squirrel package](https://github.com/giacomocavalieri/squirrel). +///" + + [ + doc.from_string(record_doc), + doc.line, + record(constructor_name, returns), + doc.lines(2), + doc.from_string(fun_doc), + doc.line, + fun(function_name, ["db", ..inputs], [ + var("decoder", decoder(constructor_name, returns)), + pipe_call("pgo.execute", string(content), [ + doc.from_string("db"), + list(inputs_encoders), + doc.from_string("decode.from(decoder, _)"), + ]), + ]), + ] + |> doc.concat + |> doc.to_string(80) +} + +/// The source code of the `decode` decoder for a value of the given type. +/// +fn gleam_type_to_decoder(type_: gleam.Type) -> String { + case type_ { + gleam.List(type_) -> "decode.list(" <> gleam_type_to_decoder(type_) <> ")" + gleam.Int -> "decode.int" + gleam.Float -> "decode.float" + gleam.Bool -> "decode.bool" + gleam.String -> "decode.string" + gleam.Option(type_) -> + "decode.optional(" <> gleam_type_to_decoder(type_) <> ")" + } +} + +/// The source code of the expression turning the variable called `name`, that +/// holds a value of the given type, into a `pgo` value. +/// +fn gleam_type_to_encoder(type_: gleam.Type, name: String) { + case type_ { + gleam.List(type_) -> + // The mapped list is the variable called `name`: its name has to be + // spliced into the generated code. + "pgo.array(list.map(" + <> name + <> ", fn(a) {" + <> gleam_type_to_encoder(type_, "a") + <> "}))" + gleam.Option(type_) -> + "pgo.nullable(fn(a) {" + <> gleam_type_to_encoder(type_, "a") + <> "}, " + <> name + <> ")" + gleam.Int -> "pgo.int(" <> name <> ")" + gleam.Float -> "pgo.float(" <> name <> ")" + gleam.Bool -> "pgo.bool(" <> name <> ")" + gleam.String -> "pgo.text(" <> name <> ")" + } +} + +// --- CODE GENERATION PRETTY PRINTING ----------------------------------------- + +const indent = 2 + +/// A pretty printed public type definition with a single constructor that has +/// the same name as the type and one labelled field for each given field. +/// +pub fn record(name: String, fields: List(gleam.Field)) -> Document { + let fields = + list.map(fields, fn(field) { + let label = gleam.identifier_to_string(field.label) + + [doc.from_string(label <> ": "), pretty_gleam_type(field.type_)] + |> doc.concat + |> doc.group + }) + + [ + doc.from_string("pub type " <> name <> " {"), + [doc.line, call(name, fields)] + |> doc.concat + |> doc.nest(by: indent), + doc.line, + 
doc.from_string("}"), + ] + |> doc.concat + |> doc.group +} + +fn pretty_gleam_type(type_: gleam.Type) -> Document { + case type_ { + gleam.List(type_) -> call("List", [pretty_gleam_type(type_)]) + gleam.Option(type_) -> call("Option", [pretty_gleam_type(type_)]) + gleam.Int -> doc.from_string("Int") + gleam.Float -> doc.from_string("Float") + gleam.Bool -> doc.from_string("Bool") + gleam.String -> doc.from_string("String") + } +} + +/// A pretty printed public function definition. +/// +pub fn fun(name: String, args: List(String), body: List(Document)) -> Document { + let args = list.map(args, doc.from_string) + + [ + doc.from_string("pub fn " <> name), + comma_list("(", args, ") "), + block([body |> doc.join(with: doc.lines(2))]), + doc.line, + ] + |> doc.concat + |> doc.group +} + +/// A pretty printed let assignment. +/// +pub fn var(name: String, body: Document) -> Document { + [ + doc.from_string("let " <> name <> " ="), + [doc.space, body] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + ] + |> doc.concat +} + +/// A pretty printed Gleam string. +/// +/// > ⚠️ This function escapes all `\` and `"` inside the original string to +/// > avoid generating invalid Gleam code. +/// +pub fn string(content: String) -> Document { + let escaped_string = + content + |> string.replace(each: "\\", with: "\\\\") + |> string.replace(each: "\"", with: "\\\"") + |> doc.from_string + + [doc.from_string("\""), escaped_string, doc.from_string("\"")] + |> doc.concat +} + +/// A pretty printed Gleam list. +/// +pub fn list(elems: List(Document)) -> Document { + comma_list("[", elems, "]") +} + +/// A pretty printed decoder that decodes an n-item dynamic tuple using the +/// `decode` package. 
+/// +pub fn decoder(constructor: String, returns: List(gleam.Field)) -> Document { + let parameters = + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string("use " <> label <> " <- decode.parameter") + }) + + let pipes = + list.index_map(returns, fn(field, i) { + let position = int.to_string(i) |> doc.from_string + let decoder = gleam_type_to_decoder(field.type_) |> doc.from_string + call("|> decode.field", [position, decoder]) + }) + + let labelled_names = + // TODO)) remove the redundant label once zed updates syntax highlight for gleam + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string(label <> ": " <> label) + }) + + [ + call_block("decode.into", [ + doc.join(parameters, with: doc.line), + doc.line, + call(constructor, labelled_names), + ]), + doc.line, + doc.join(pipes, with: doc.line), + ] + |> doc.concat() + |> doc.group +} + +/// A pretty printed function call where the first argument is piped into +/// the function. +/// +pub fn pipe_call( + function: String, + first: Document, + rest: List(Document), +) -> Document { + [first, doc.line, call("|> " <> function, rest)] + |> doc.concat +} + +/// A pretty printed function call. +/// +fn call(function: String, args: List(Document)) -> Document { + [doc.from_string(function), comma_list("(", args, ")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed function call where the only argument is a single block. +/// +fn call_block(function: String, body: List(Document)) -> Document { + [doc.from_string(function <> "("), block(body), doc.from_string(")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed Gleam block. 
+/// +fn block(body: List(Document)) -> Document { + [ + doc.from_string("{"), + [doc.line, ..body] + |> doc.concat + |> doc.nest(by: indent), + doc.line, + doc.from_string("}"), + ] + |> doc.concat + |> doc.force_break +} + +/// A comma separated list of items with some given open and closed delimiters. +/// +fn comma_list(open: String, content: List(Document), close: String) -> Document { + [ + doc.from_string(open), + [ + // We want the first break to be nested + // in case the group is broken. + doc.soft_break, + doc.join(content, doc.break(", ", ",")), + ] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + doc.break("", ","), + doc.from_string(close), + ] + |> doc.concat + |> doc.group +} diff --git a/test/squirrel_test.gleam b/test/squirrel_test.gleam new file mode 100644 index 0000000..3831e7a --- /dev/null +++ b/test/squirrel_test.gleam @@ -0,0 +1,12 @@ +import gleeunit +import gleeunit/should + +pub fn main() { + gleeunit.main() +} + +// gleeunit test functions end in `_test` +pub fn hello_world_test() { + 1 + |> should.equal(1) +}