diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..b8556db --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,23 @@ +name: test + +on: + push: + branches: + - master + - main + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: erlef/setup-beam@v1 + with: + otp-version: "26.0.2" + gleam-version: "1.4.0-rc1" + rebar3-version: "3" + # elixir-version: "1.15.4" + - run: gleam deps download + - run: gleam test + - run: gleam format --check src test diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..599be4e --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.beam +*.ez +/build +erl_crash.dump diff --git a/README.md b/README.md new file mode 100644 index 0000000..a64df9d --- /dev/null +++ b/README.md @@ -0,0 +1,21 @@ +# squirrel + +[![Package Version](https://img.shields.io/hexpm/v/squirrel)](https://hex.pm/packages/squirrel) +[![Hex Docs](https://img.shields.io/badge/hex-docs-ffaff3)](https://hexdocs.pm/squirrel/) + +```gleam +import squirrel + +pub fn main() { + // TODO: An example of the project in use +} +``` + +Further documentation can be found at . 
+ +## Development + +```sh +gleam run # Run the project +gleam test # Run the tests +``` diff --git a/gleam.toml b/gleam.toml new file mode 100644 index 0000000..eaac12f --- /dev/null +++ b/gleam.toml @@ -0,0 +1,17 @@ +name = "squirrel" +version = "0.1.0" +description = "🐿️ Type safe SQL in Gleam" +licences = ["Apache-2.0"] +repository = { type = "github", user = "giacomocavalieri", repo = "squirrel" } + +[dependencies] +gleam_stdlib = ">= 0.34.0 and < 2.0.0" +simplifile = ">= 2.0.1 and < 3.0.0" +eval = ">= 1.0.0 and < 2.0.0" +gleam_json = ">= 1.0.0 and < 2.0.0" +mug = ">= 1.1.0 and < 2.0.0" +glam = ">= 2.0.1 and < 3.0.0" +justin = ">= 1.0.1 and < 2.0.0" + +[dev-dependencies] +gleeunit = ">= 1.0.0 and < 2.0.0" diff --git a/manifest.toml b/manifest.toml new file mode 100644 index 0000000..38ff89c --- /dev/null +++ b/manifest.toml @@ -0,0 +1,26 @@ +# This file was generated by Gleam +# You typically do not need to edit this file + +packages = [ + { name = "eval", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "eval", source = "hex", outer_checksum = "264DAF4B49DF807F303CA4A4E4EBC012070429E40BE384C58FE094C4958F9BDA" }, + { name = "filepath", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "filepath", source = "hex", outer_checksum = "EFB6FF65C98B2A16378ABC3EE2B14124168C0CE5201553DE652E2644DCFDB594" }, + { name = "glam", version = "2.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glam", source = "hex", outer_checksum = "66EC3BCD632E51EED029678F8DF419659C1E57B1A93D874C5131FE220DFAD2B2" }, + { name = "gleam_erlang", version = "0.25.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "054D571A7092D2A9727B3E5D183B7507DAB0DA41556EC9133606F09C15497373" }, + { name = "gleam_json", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = 
"hex", outer_checksum = "9063D14D25406326C0255BDA0021541E797D8A7A12573D849462CAFED459F6EB" }, + { name = "gleam_stdlib", version = "0.39.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "2D7DE885A6EA7F1D5015D1698920C9BAF7241102836CE0C3837A4F160128A9C4" }, + { name = "gleeunit", version = "1.2.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "F7A7228925D3EE7D0813C922E062BFD6D7E9310F0BEE585D3A42F3307E3CFD13" }, + { name = "justin", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "justin", source = "hex", outer_checksum = "7FA0C6DB78640C6DC5FBFD59BF3456009F3F8B485BF6825E97E1EB44E9A1E2CD" }, + { name = "mug", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_erlang", "gleam_stdlib"], otp_app = "mug", source = "hex", outer_checksum = "85A61E67A7A8C25F4460D9CBEF1C09C68FC06ABBC6FF893B0A1F42AE01CBB546" }, + { name = "simplifile", version = "2.0.1", build_tools = ["gleam"], requirements = ["filepath", "gleam_stdlib"], otp_app = "simplifile", source = "hex", outer_checksum = "5FFEBD0CAB39BDD343C3E1CCA6438B2848847DC170BA2386DF9D7064F34DF000" }, + { name = "thoas", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "E38697EDFFD6E91BD12CEA41B155115282630075C2A727E7A6B2947F5408B86A" }, +] + +[requirements] +eval = { version = ">= 1.0.0 and < 2.0.0" } +glam = { version = ">= 2.0.1 and < 3.0.0" } +gleam_json = { version = ">= 1.0.0 and < 2.0.0" } +gleam_stdlib = { version = ">= 0.34.0 and < 2.0.0" } +gleeunit = { version = ">= 1.0.0 and < 2.0.0" } +justin = { version = ">= 1.0.1 and < 2.0.0" } +mug = { version = ">= 1.1.0 and < 2.0.0" } +simplifile = { version = ">= 2.0.1 and < 3.0.0" } diff --git a/src/squirrel.gleam b/src/squirrel.gleam new file mode 100644 index 0000000..41d8777 --- /dev/null +++ b/src/squirrel.gleam @@ -0,0 +1,24 @@ 
+import gleam/list +import simplifile +import squirrel/internal/database/postgres +import squirrel/internal/query + +// TODO LIST: +// - [ ] How does one decide where the query goes? +// - [ ] How do I deal with duplicate names then? +// - [ ] Make a nice CLI tool +// - [ ] How does one connect to the db? +// + +/// Entry point for the `squirrel` CLI. +/// +pub fn main() { + let assert Ok(queries) = + "" + |> query.from_file + + let assert Ok(queries) = postgres.main(queries) + + use query <- list.each(queries) + simplifile.append(query.to_function(query) <> "\n\n", to: "") +} diff --git a/src/squirrel/internal/database/postgres.gleam b/src/squirrel/internal/database/postgres.gleam new file mode 100644 index 0000000..88d755a --- /dev/null +++ b/src/squirrel/internal/database/postgres.gleam @@ -0,0 +1,682 @@ +import eval +import gleam/bit_array +import gleam/bool +import gleam/dict.{type Dict} +import gleam/dynamic.{type DecodeErrors, type Dynamic} as d +import gleam/int +import gleam/json +import gleam/list +import gleam/option.{type Option, None, Some} +import gleam/result +import gleam/set.{type Set} +import gleam/string +import squirrel/internal/database/postgres_protocol as pg +import squirrel/internal/error.{ + type Error, CannotParseQuery, PgCannotDecodeReceivedMessage, + PgCannotReceiveMessage, PgCannotSendMessage, QueryHasInvalidColumn, +} +import squirrel/internal/eval_extra +import squirrel/internal/gleam +import squirrel/internal/query.{type Query, type TypedQuery, Query, TypedQuery} + +const find_postgres_type_query = " +select + -- The name of the type or, if the type is an array, the name of its + -- elements' type. + case + when elem.typname is null then type.typname + else elem.typname + end as type, + -- Tells us how to interpret the firs column: if this is true then the first + -- column is the type of the elements of the array type. + -- Otherwise it means we've found a base type. 
+ case + when elem.typname is null then false + else true + end as is_array +from + pg_type as type + left join pg_type as elem on type.typelem = elem.oid +where + type.oid = $1 +" + +const find_column_nullability_query = " +select + -- Whether the column has a not-null constraint. + attnotnull +from + pg_attribute +where + -- The oid of the table the column comes from. + attrelid = $1 + -- The index of the column we're looking for. + and attnum = $2 +" + +// --- TYPES ------------------------------------------------------------------- + +type PgType { + PBase(name: String) + PArray(inner: PgType) + POption(inner: PgType) +} + +type Context { + Context( + db: pg.Connection, + gleam_types: Dict(Int, gleam.Type), + column_nullability: Dict(#(Int, Int), Nullability), + ) +} + +type Nullability { + Nullable + NotNullable +} + +type Plan { + Plan( + join_type: Option(JoinType), + parent_relation: Option(ParentRelation), + output: Option(List(String)), + plans: Option(List(Plan)), + ) +} + +type JoinType { + Full + Left + Right + Other +} + +type ParentRelation { + Inner + NotInner +} + +type Db(a) = + eval.Eval(a, Error, Context) + +// --- POSTGRES TO GLEAM TYPES CONVERSIONS ------------------------------------- + +fn pg_to_gleam_type(type_: PgType) -> gleam.Type { + case type_ { + PArray(inner: inner) -> gleam.List(pg_to_gleam_type(inner)) + POption(inner: inner) -> gleam.Option(pg_to_gleam_type(inner)) + PBase(name: name) -> + case name { + "bool" -> gleam.Bool + "text" | "char" -> gleam.String + "float4" | "float8" -> gleam.Float + "int2" | "int4" | "int8" -> gleam.Int + _ -> panic as "TODO)) unsupported type" + } + } +} + +// --- CLI ENTRY POINT --------------------------------------------------------- + +pub fn main(queries: List(Query)) -> Result(List(TypedQuery), Error) { + let script = { + use _ <- eval.try(authenticate()) + // Once the server has confirmed that it is ready to accept query requests we + // can start gathering information about all the different 
queries. + // After each one we need to make sure the server is ready to go on with the + // next one. + // + // https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY + // + use query <- eval_extra.try_map(queries) + infer_types(query) + } + + script + |> eval.run(Context( + db: pg.connect("localhost", 5432, 1000), + gleam_types: dict.new(), + column_nullability: dict.new(), + )) +} + +fn authenticate() -> Db(Nil) { + let params = [#("user", "giacomocavalieri"), #("database", "prova")] + use _ <- eval.try(send(pg.FeStartupMessage(params))) + // TODO)) Deal with possible failures + use msg <- eval.try(receive()) + let assert pg.BeAuthenticationOk = msg + use _ <- eval.try(wait_until_ready()) + eval.return(Nil) +} + +/// Returns type information about a query. +/// +fn infer_types(query: Query) -> Db(TypedQuery) { + // Postgres doesn't give us 100% accurate data regarding a query's type. + // We'll need to perform a couple of database interrogations and do some + // guessing. + // + // The big picture idea is the following: + // - We ask the server to prepare the query. + // - Postgres will reply with type information about the returned rows and the + // query's parameters. + let action = parameters_and_returns(query) + use #(parameters, returns) <- eval.try(action) + // - The parameters' types are just OIDs so we need to interrogate the + // database to learn the actual corresponding Gleam type. + use parameters <- eval.try(resolve_parameters(parameters)) + // - The returns' types are just OIDs so we have to do the same for those. + // - Here comes the tricky part: we can't know if a column is nullable just + // from the server's previous answer. + // - For columns coming from a database table we'll look it up and see if + // the column is nullable or not + // - But this is not enough! If a returned column comes from a left/right + // join it will be nullable even if it is not in the original table. 
+ // To work around this we'll have to inspect the query plan. + use plan <- eval.try(query_plan(query, list.length(parameters))) + let nullables = nullables_from_plan(plan) + use returns <- eval.try(resolve_returns(query, returns, nullables)) + + query + |> query.add_types(parameters, returns) + |> eval.return +} + +fn parameters_and_returns(query: Query) -> Db(_) { + // We need to send three messages: + // - `Parse` with the query to parse + // - `Describe` to ask the server to reply with a description of the query's + // return types and parameter types. This is what we need to understand the + // SQL inferred types and generate the corresponding Gleam types. + // - `Sync` to ask the server to immediately reply with the results of parsing + // + use _ <- eval.try( + send_all([ + pg.FeParse("", query.content, []), + pg.FeDescribe(pg.PreparedStatement, ""), + pg.FeSync, + ]), + ) + + use msg <- eval.try(receive()) + case msg { + pg.BeErrorResponse(errors) -> + eval.throw(error_fields_to_parse_error(query, errors)) + pg.BeParseComplete -> { + use msg <- eval.try(receive()) + let assert pg.BeParameterDescription(parameters) = msg + use msg <- eval.try(receive()) + let assert pg.BeRowDescriptions(rows) = msg + use msg <- eval.try(receive()) + let assert pg.BeReadyForQuery(_) = msg + eval.return(#(parameters, rows)) + } + _ -> panic as "parse complete: unexpected message sequence" + } +} + +fn error_fields_to_parse_error( + query: Query, + errors: Set(pg.ErrorOrNoticeField), +) -> Error { + let #(error_code, message, position, hint) = { + use acc, error_field <- set.fold(errors, from: #(None, None, None, None)) + let #(code, message, position, hint) = acc + case error_field { + pg.Code(code) -> #(Some(code), message, position, hint) + pg.Message(message) -> #(code, Some(message), position, hint) + pg.Hint(hint) -> #(code, message, position, Some(hint)) + pg.Position(position) -> + case int.parse(position) { + Ok(position) -> #(code, message, Some(position), hint) + 
Error(_) -> acc + } + _ -> acc + } + } + + let Query(content: content, file: file, name: name) = query + CannotParseQuery( + content: content, + file: file, + name: gleam.identifier_to_string(name), + error_code: error_code, + hint: hint, + message: message, + position: position, + ) +} + +fn resolve_parameters(parameters: List(Int)) -> Db(List(gleam.Type)) { + use oid <- eval_extra.try_map(parameters) + find_gleam_type(oid) +} + +/// Looks up a type with the given id in the Postgres registry. +/// +/// > ⚠️ This function assumes that the oid is present in the database and +/// > will crash otherwise. This should only be called with oids coming from +/// > a database interrogation. +/// +fn find_gleam_type(oid: Int) -> Db(gleam.Type) { + // We first look for the Gleam type corresponding to this id in the cache to + // avoid hammering the db with needless queries. + use <- with_cached_gleam_type(oid) + + // The only parameter to this query is the oid of the type to lookup: + // that's a 32bit integer (its oid needed to prepare the query is 23). + let params = [pg.Parameter(<>)] + let run_query = find_postgres_type_query |> run_query(params, [23]) + use res <- eval.try(run_query) + + // We know the output must only contain two values: the name and a boolean to + // check whether it is an array or not. + // It's safe to assert because this query is hard coded in our code and the + // output shape cannot change without us changing that query. + let assert [name, is_array] = res + + // We then decode the bitarrays we got as a result: + // - `name` is just a string + // - `is_array` is a pg boolean + let assert Ok(name) = bit_array.to_string(name) + let type_ = case bit_array_to_bool(is_array) { + True -> PArray(PBase(name)) + False -> PBase(name) + } + eval.return(pg_to_gleam_type(type_)) +} + +/// Returns the query plan for a given query. +/// `parameters` is the number of parameter placeholders in the query. 
+/// +fn query_plan(query: Query, parameters: Int) -> Db(Plan) { + // We ask postgres to give us the query plan. To do that we need to fill in + // all the possible holes in the user supplied query with null values; + // otherwise, the server would complain that it has arguments that are not + // bound. + let query = "explain (format json, verbose) " <> query.content + let params = list.repeat(pg.Null, parameters) + let run_query = run_query(query, params, []) + use res <- eval.try(run_query) + + // We know the output will only contain a single row that is the json string + // containing the query plan. + let assert [plan] = res + let assert Ok([plan, ..]) = json.decode_bits(plan, json_plans_decoder) + eval.return(plan) +} + +/// Given a query plan, returns a set with the indices of the output columns +/// that can contain null values. +/// +fn nullables_from_plan(plan: Plan) -> Set(Int) { + let outputs = case plan.output { + Some(outputs) -> list.index_fold(outputs, dict.new(), dict.insert) + None -> dict.new() + } + + do_nullables_from_plan(plan, outputs, set.new()) +} + +fn do_nullables_from_plan( + plan: Plan, + // A dict from "column name" to its position in the query output. + query_outputs: Dict(String, Int), + nullables: Set(Int), +) -> Set(Int) { + let nullables = case plan.output, plan.join_type, plan.parent_relation { + // - All the outputs of a full join must be marked as nullable + // - All the outputs of an inner half join must be marked as nullable + Some(outputs), Some(Full), _ | Some(outputs), _, Some(Inner) -> { + use nullables, output <- list.fold(outputs, from: nullables) + case dict.get(query_outputs, output) { + Ok(i) -> set.insert(nullables, i) + Error(_) -> nullables + } + } + _, _, _ -> nullables + } + + case plan.plans, plan.join_type { + // If this is an inner half join we keep inspecting the children to mark + // their outputs as nullable. 
+ Some(plans), Some(Left) | Some(plans), Some(Right) -> { + use nullables, plan <- list.fold(plans, from: nullables) + do_nullables_from_plan(plan, query_outputs, nullables) + } + _, _ -> nullables + } +} + +/// Given a list of `RowDescriptionFields` it turns those into Gleam fields with +/// a name and a type. +/// +/// This also uses nullability info coming from the query plan to figure out if +/// a column can be nullable or not: +/// - If the column name ends with `!` it will be forced to be not nullable +/// - If the column name ends with `?` it will be forced to be nullable +/// - If the column appears in the `nullables` set then it will be nullable +/// - Otherwise we look for its metadata in the database and if it has a +/// not-null constraint it will be not nullable; otherwise it will be nullable +/// +fn resolve_returns( + query: Query, + returns: List(pg.RowDescriptionField), + nullables: Set(Int), +) -> Db(List(gleam.Field)) { + use column, i <- eval_extra.try_index_map(returns) + let pg.RowDescriptionField( + data_type_oid: type_oid, + attr_number: column, + table_oid: table, + name: name, + .., + ) = column + + use type_ <- eval.try(find_gleam_type(type_oid)) + + let ends_with_exclamation_mark = string.ends_with(name, "!") + let ends_with_question_mark = string.ends_with(name, "?") + use nullability <- eval.try(case ends_with_exclamation_mark { + True -> eval.return(NotNullable) + False -> + case ends_with_question_mark { + True -> eval.return(Nullable) + False -> + case set.contains(nullables, i) { + True -> eval.return(Nullable) + False -> column_nullability(table: table, column: column) + } + } + }) + + let type_ = case nullability { + Nullable -> gleam.Option(type_) + NotNullable -> type_ + } + + let try_convert_name = + case ends_with_exclamation_mark || ends_with_question_mark { + True -> string.drop_right(name, 1) + False -> name + } + |> gleam.identifier + |> result.map_error(QueryHasInvalidColumn( + file: query.file, + query_name: 
gleam.identifier_to_string(query.name), + column_name: name, + reason: _, + )) + + use name <- eval.try(eval.from_result(try_convert_name)) + + let field = gleam.Field(label: name, type_: type_) + eval.return(field) +} + +fn column_nullability(table table: Int, column column: Int) -> Db(Nullability) { + // We first check if the table+column is cached to avoid making redundant + // queries to the database. + use <- with_cached_column(table: table, column: column) + + // If the table oid is 0 that means the column doesn't come from any table so + // we just assume it's not nullable. + use <- bool.guard(when: table == 0, return: eval.return(NotNullable)) + + // This query has 2 parameters: + // - the oid of the table (a 32bit integer, oid is 23) + // - the index of the column (a 32 bit integer, oid is 23) + let params = [pg.Parameter(<>), pg.Parameter(<>)] + let run_query = find_column_nullability_query |> run_query(params, [23, 23]) + use res <- eval.try(run_query) + + // We know the output will only have only one column, that is the boolean + // telling us if the column has a not-null constraint. + let assert [has_non_null_constraint] = res + case bit_array_to_bool(has_non_null_constraint) { + True -> eval.return(NotNullable) + False -> eval.return(Nullable) + } +} + +// --- DB ACTION HELPERS ------------------------------------------------------- + +/// Runs a query against the database. +/// - `parameters` are the parameter values that need to be supplied in place of +/// the query placeholders +/// - `parameters_object_ids` are the oids describing the type of each +/// parameter. +/// +/// > ⚠️ The `parameters_objects_ids` should have the same length of +/// > `parameters` and correctly describe each parameter's type. This function +/// > makes no attempt whatsoever to verify this assumption is correct so be +/// > careful! +/// +/// > ⚠️ This function makes the assumption that the query will only return one +/// > single row. 
This is totally fine here because we only use this to run +/// > specific hard coded queries that are guaranteed to return a single row. +/// +fn run_query( + query: String, + parameters: List(pg.ParameterValue), + parameters_object_ids: List(Int), +) -> Db(List(BitArray)) { + // The message exchange to run a query works as follows: + // - `Parse` we ask the server to parse the query, we do not give it a name + // - `Bind` we bind the query to the unnamed portal so that it is ready to + // be executed. + // - `Execute` we ask the server to run the unnamed portal and return all + // the rows (0 means return all rows). + // - `Close` we ask to close the unnamed query and the unnamed portal to free + // their resources. + // - `Sync` this acts as a synchronization point that needs to go at the end + // of the sequence before the next one. + // + // As you can see in the receiving part, each message we send corresponds to a + // specific answer from the server: + // - `ParseComplete` the query was parsed correctly + // - `BindComplete` the query was bound to a portal + // - `MessageDataRow` the result(s) coming from the query execution + // - `CommandComplete` when the result coming from the query is over + // - `CloseComplete` the portal/statement was closed + // - `ReadyForQuery` final reply to the sync message signalling we can go on + // making new requests + use _ <- eval.try( + send_all([ + pg.FeParse("", query, parameters_object_ids), + pg.FeBind( + portal: "", + statement_name: "", + parameter_format: pg.FormatAll(pg.Binary), + parameters:, + result_format: pg.FormatAll(pg.Binary), + ), + pg.FeExecute("", 0), + pg.FeClose(pg.PreparedStatement, ""), + pg.FeClose(pg.Portal, ""), + pg.FeSync, + ]), + ) + + use msg <- eval.try(receive()) + let assert pg.BeParseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeBindComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeMessageDataRow(res) = msg + use msg <- eval.try(receive()) + let assert 
pg.BeCommandComplete(_, _) = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeCloseComplete = msg + use msg <- eval.try(receive()) + let assert pg.BeReadyForQuery(_) = msg + eval.return(res) +} + +/// Receive a single message from the database. +/// +fn receive() -> Db(pg.BackendMessage) { + use Context(db: db, ..) as context <- eval.from + case pg.receive(db) { + Ok(#(db, msg)) -> #(Context(..context, db: db), Ok(msg)) + Error(pg.ReadDecodeError(error)) -> #( + context, + Error(PgCannotDecodeReceivedMessage(string.inspect(error))), + ) + Error(pg.SocketError(error)) -> #( + context, + Error(PgCannotReceiveMessage(string.inspect(error))), + ) + } +} + +/// Send a single message to the database. +/// +fn send(message message: pg.FrontendMessage) -> Db(Nil) { + use Context(db: db, ..) as context <- eval.from + + let result = + message + |> pg.encode_frontend_message + |> pg.send(db, _) + + let #(db, result) = case result { + Ok(db) -> #(db, Ok(Nil)) + Error(error) -> #(db, Error(PgCannotSendMessage(string.inspect(error)))) + } + + #(Context(..context, db: db), result) +} + +/// Send many messages, one after the other. +/// +fn send_all(messages messages: List(pg.FrontendMessage)) -> Db(Nil) { + use acc, msg <- eval_extra.try_fold(messages, from: Nil) + use _ <- eval.try(send(msg)) + eval.return(acc) +} + +/// Start receiving and discarding messages until a `ReadyForQuery` message is +/// received. +/// +fn wait_until_ready() -> Db(Nil) { + use _ <- eval.try(send(pg.FeFlush)) + do_wait_until_ready() +} + +fn do_wait_until_ready() -> Db(Nil) { + use msg <- eval.try(receive()) + case msg { + pg.BeReadyForQuery(_) -> eval.return(Nil) + _ -> do_wait_until_ready() + } +} + +/// Looks up for a type with the given id in a global cache. +/// If the type is present it immediately returns it. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. 
+/// +fn with_cached_gleam_type( + lookup oid: Int, + otherwise do: fn() -> Db(gleam.Type), +) -> Db(gleam.Type) { + use context: Context <- eval.from + case dict.get(context.gleam_types, oid) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #(Context(gleam_types: gleam_types, ..) as context, Ok(type_)) -> { + let gleam_types = dict.insert(gleam_types, oid, type_) + let new_context = Context(..context, gleam_types: gleam_types) + #(new_context, Ok(type_)) + } + } + } +} + +/// Looks up the nullability of a table's column. +/// If the nullability is cached it is immediately returned. +/// Otherwise it runs the database action to fetch it and then caches it to be +/// reused later. +/// +fn with_cached_column( + table table_oid: Int, + column column: Int, + otherwise do: fn() -> Db(Nullability), +) -> Db(Nullability) { + use context: Context <- eval.from + let key = #(table_oid, column) + case dict.get(context.column_nullability, key) { + Ok(type_) -> #(context, Ok(type_)) + Error(_) -> + case eval.step(do(), context) { + #(_, Error(_)) as result -> result + #( + Context( + column_nullability: column_nullability, + .., + ) as context, + Ok(type_), + ) -> { + let column_nullability = dict.insert(column_nullability, key, type_) + let new_context = + Context(..context, column_nullability: column_nullability) + #(new_context, Ok(type_)) + } + } + } +} + +// --- DECODERS ---------------------------------------------------------------- + +fn json_plans_decoder(data: Dynamic) -> Result(List(Plan), DecodeErrors) { + d.list(d.field("Plan", plan_decoder))(data) +} + +fn plan_decoder(data: Dynamic) -> Result(Plan, DecodeErrors) { + d.decode4( + Plan, + d.optional_field("Join Type", join_type_decoder), + d.optional_field("Parent Relationship", parent_relation_decoder), + d.optional_field("Output", d.list(d.string)), + d.optional_field("Plans", d.list(plan_decoder)), + )(data) +} + +fn 
join_type_decoder(data: Dynamic) -> Result(JoinType, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Full" -> Full + "Left" -> Left + "Right" -> Right + _ -> Other + } +} + +fn parent_relation_decoder( + data: Dynamic, +) -> Result(ParentRelation, DecodeErrors) { + use data <- result.map(d.string(data)) + case data { + "Inner" -> Inner + _ -> NotInner + } +} + +// --- UTILS ------------------------------------------------------------------- + +/// Turns a bit array into a boolean value. +/// Returns `False` if the bit array is all `0`s or empty, `True` otherwise. +/// +fn bit_array_to_bool(bit_array: BitArray) -> Bool { + case bit_array { + <<0, rest:bits>> -> bit_array_to_bool(rest) + <<>> -> False + _ -> True + } +} diff --git a/src/squirrel/internal/database/postgres_protocol.gleam b/src/squirrel/internal/database/postgres_protocol.gleam new file mode 100644 index 0000000..10cb525 --- /dev/null +++ b/src/squirrel/internal/database/postgres_protocol.gleam @@ -0,0 +1,1387 @@ +//// Vendored version of https://hex.pm/packages/postgresql_protocol. +//// - do not fail with an unexpected Command if the outer message is ok +//// +//// This library parses and generates packages for the PostgreSQL Binary Protocol +//// It also provides a basic connection abstraction, but this hasn't been used +//// outside of tests. 
+ +import gleam/bit_array +import gleam/bool +import gleam/dict +import gleam/int +import gleam/list +import gleam/result.{try} +import gleam/set +import mug + +pub type Connection { + Connection(socket: mug.Socket, buffer: BitArray, timeout: Int) +} + +pub fn connect(host, port, timeout) { + let assert Ok(socket) = + mug.connect(mug.ConnectionOptions(host: host, port: port, timeout: timeout)) + + Connection(socket: socket, buffer: <<>>, timeout: timeout) +} + +pub type StateInitial { + StateInitial(parameters: dict.Dict(String, String)) +} + +pub type State { + State( + process_id: Int, + secret_key: Int, + parameters: dict.Dict(String, String), + oids: dict.Dict(Int, fn(BitArray) -> Result(Int, Nil)), + ) +} + +fn default_oids() { + dict.new() + |> dict.insert(23, fn(col: BitArray) { + use converted <- try(bit_array.to_string(col)) + int.parse(converted) + }) +} + +pub fn start(conn, params) { + let assert Ok(conn) = + conn + |> send(encode_startup_message(params)) + + let assert Ok(#(conn, state)) = + conn + |> receive_startup_rec(StateInitial(dict.new())) + + #(conn, state) +} + +fn receive_startup_rec(conn: Connection, state: StateInitial) { + case receive(conn) { + Ok(#(conn, BeAuthenticationOk)) -> receive_startup_rec(conn, state) + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_rec( + conn, + StateInitial(parameters: dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeBackendKeyData(secret_key: secret_key, process_id: process_id))) -> + receive_startup_1( + conn, + State( + parameters: state.parameters, + secret_key: secret_key, + process_id: process_id, + oids: default_oids(), + ), + ) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } +} + +fn receive_startup_1(conn: Connection, state: State) { + case receive(conn) { + Ok(#(conn, BeParameterStatus(name: name, value: value))) -> + receive_startup_1( + conn, + State(..state, parameters: 
dict.insert(state.parameters, name, value)), + ) + Ok(#(conn, BeReadyForQuery(_))) -> Ok(#(conn, state)) + Ok(#(_conn, msg)) -> Error(StartupFailedWithUnexpectedMessage(msg)) + Error(err) -> Error(StartupFailedWithError(err)) + } +} + +pub type StartupFailed { + StartupFailedWithUnexpectedMessage(BackendMessage) + StartupFailedWithError(ReadError) +} + +pub type ReadError { + SocketError(mug.Error) + ReadDecodeError(MessageDecodingError) +} + +pub type MessageDecodingError { + MessageDecodingError(String) + MessageIncomplete(BitArray) + UnknownMessage(data: BitArray) +} + +fn dec_err(desc: String, data: BitArray) -> Result(a, MessageDecodingError) { + Error(MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data))) +} + +fn msg_dec_err(desc: String, data: BitArray) -> MessageDecodingError { + MessageDecodingError(desc <> "; data: " <> bit_array.inspect(data)) +} + +pub type Command { + Insert + Delete + Update + Merge + Select + Move + Fetch + Copy +} + +/// Messages originating from the PostgreSQL backend (server) +pub type BackendMessage { + BeBindComplete + BeCloseComplete + BeCommandComplete(Command, Int) + BeCopyData(data: BitArray) + BeCopyDone + BeAuthenticationOk + BeAuthenticationKerberosV5 + BeAuthenticationCleartextPassword + BeAuthenticationMD5Password(salt: BitArray) + BeAuthenticationGSS + BeAuthenticationGSSContinue(auth_data: BitArray) + BeAuthenticationSSPI + BeAuthenticationSASL(mechanisms: List(String)) + BeAuthenticationSASLContinue(data: BitArray) + BeAuthenticationSASLFinal(data: BitArray) + BeReadyForQuery(TransactionStatus) + BeRowDescriptions(List(RowDescriptionField)) + BeMessageDataRow(List(BitArray)) + BeBackendKeyData(process_id: Int, secret_key: Int) + BeParameterStatus(name: String, value: String) + BeCopyResponse( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), + ) + BeNegotiateProtocolVersion( + newest_minor: Int, + unrecognized_options: List(String), + ) + BeNoData + 
BeNoticeResponse(set.Set(ErrorOrNoticeField)) + BeNotificationResponse(process_id: Int, channel: String, payload: String) + BeParameterDescription(List(Int)) + BeParseComplete + BePortalSuspended + BeErrorResponse(set.Set(ErrorOrNoticeField)) +} + +/// Direction of a BeCopyResponse +pub type CopyDirection { + In + Out + Both +} + +/// Indicates encoding of column values +pub type Format { + Text + Binary +} + +/// Lists of parameters can have different encodings. +pub type FormatValue { + FormatAllText + FormatAll(Format) + Formats(List(Format)) +} + +fn encode_format_value(format) -> BitArray { + case format { + FormatAllText -> <> + FormatAll(Text) -> <<1:16, encode_format(Text):16>> + FormatAll(Binary) -> <<1:16, encode_format(Binary):16>> + Formats(formats) -> { + let size = list.length(formats) + list.fold(formats, <>, fn(sum, fmt) { + <> + }) + } + } +} + +pub type ParameterValues = + List(ParameterValue) + +pub type ParameterValue { + Parameter(BitArray) + Null +} + +fn parameters_to_bytes(parameters: ParameterValues) { + let mapped = + parameters + |> list.map(fn(parameter) { + case parameter { + Parameter(value) -> <> + Null -> <<-1:32>> + } + }) + + <> +} + +pub type FrontendMessage { + FeBind( + portal: String, + statement_name: String, + parameter_format: FormatValue, + parameters: ParameterValues, + result_format: FormatValue, + ) + FeCancelRequest(process_id: Int, secret_key: Int) + FeClose(what: What, name: String) + FeCopyData(data: BitArray) + FeCopyDone + FeCopyFail(error: String) + FeDescribe(what: What, name: String) + FeExecute(portal: String, return_row_count: Int) + FeFlush + FeFunctionCall( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, + ) + FeGssEncRequest + FeParse(name: String, query: String, parameter_object_ids: List(Int)) + FeQuery(query: String) + FeStartupMessage(params: List(#(String, String))) + FeSslRequest + FeTerminate + FeSync + FeAmbigous(FeAmbigous) +} + +// These all 
share the same message type +pub type FeAmbigous { + FeGssResponse(data: BitArray) + FeSaslInitialResponse(name: String, data: BitArray) + FeSaslResponse(data: BitArray) + FePasswordMessage(password: String) +} + +pub type What { + PreparedStatement + Portal +} + +fn wire_what(what) { + case what { + Portal -> <<"P":utf8>> + PreparedStatement -> <<"S":utf8>> + } +} + +fn decode_what(binary) { + case binary { + <<"P":utf8, rest:bytes>> -> Ok(#(Portal, rest)) + <<"S":utf8, rest:bytes>> -> Ok(#(PreparedStatement, rest)) + _ -> dec_err("only portal and prepared statement are allowed", binary) + } +} + +fn encode_message_data_row(columns) { + <> + |> list.fold( + columns, + _, + fn(sum, column) { + let len = bit_array.byte_size(column) + <> + }, + ) + |> encode("D", _) +} + +fn encode_error_response(fields: set.Set(ErrorOrNoticeField)) -> BitArray { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("E", _) +} + +fn encode_field(field) { + case field { + Severity(value) -> <<"S":utf8, value:utf8, 0>> + SeverityLocalized(value) -> <<"V":utf8, value:utf8, 0>> + Code(value) -> <<"C":utf8, value:utf8, 0>> + Message(value) -> <<"M":utf8, value:utf8, 0>> + Detail(value) -> <<"D":utf8, value:utf8, 0>> + Hint(value) -> <<"H":utf8, value:utf8, 0>> + Position(value) -> <<"P":utf8, value:utf8, 0>> + InternalPosition(value) -> <<"p":utf8, value:utf8, 0>> + InternalQuery(value) -> <<"q":utf8, value:utf8, 0>> + Where(value) -> <<"W":utf8, value:utf8, 0>> + Schema(value) -> <<"s":utf8, value:utf8, 0>> + Table(value) -> <<"t":utf8, value:utf8, 0>> + Column(value) -> <<"c":utf8, value:utf8, 0>> + DataType(value) -> <<"d":utf8, value:utf8, 0>> + Constraint(value) -> <<"n":utf8, value:utf8, 0>> + File(value) -> <<"F":utf8, value:utf8, 0>> + Line(value) -> <<"L":utf8, value:utf8, 0>> + Routine(value) -> <<"R":utf8, value:utf8, 0>> + Unknown(key, value) -> <> + } +} + +fn encode_authentication_sasl(mechanisms) -> BitArray { + list.fold(mechanisms, <<10:32>>, fn(sum, mechanism) 
{ + <> + }) + |> encode("R", _) +} + +fn encode_command_complete(command, rows_num) { + let rows = int.to_string(rows_num) + + let data = + case command { + Insert -> "INSERT 0 " <> rows + Delete -> "DELETE " <> rows + Update -> "UPDATE " <> rows + Merge -> "MERGE " <> rows + Select -> "SELECT " <> rows + Move -> "MOVE " <> rows + Fetch -> "FETCH " <> rows + Copy -> "COPY " <> rows + } + |> encode_string() + + encode("C", data) +} + +fn encode_copy_response( + direction: CopyDirection, + overall_format: Format, + codes: List(Format), +) { + let data = encode_copy_response_rec(overall_format, codes) + case direction { + In -> encode("G", data) + Out -> encode("H", data) + Both -> encode("W", data) + } +} + +// The format codes to be used for each column. Each must presently be zero +// (text) or one (binary). All must be zero if the overall copy format is +// textual. +fn encode_copy_response_rec(overall_format, codes) { + case overall_format { + Text -> + codes + |> list.fold( + <>, + fn(sum, _code) { <> }, + ) + Binary -> + codes + |> list.fold( + <>, + fn(sum, code) { <> }, + ) + } +} + +fn encode_parameter_status(name: String, value: String) -> BitArray { + encode("S", <>) +} + +fn encode_negotiate_protocol_version( + newest_minor: Int, + unrecognized_options: List(String), +) { + list.fold( + unrecognized_options, + <>, + fn(sum, option) { <> }, + ) + |> encode("v", _) +} + +fn encode_notice_response(fields: set.Set(ErrorOrNoticeField)) { + fields + |> set.fold(<<>>, fn(sum, field) { <> }) + |> encode("N", _) +} + +fn encode_notification_response( + process_id: Int, + channel: String, + payload: String, +) { + encode("A", <>) +} + +fn encode_parameter_description(descriptions: List(Int)) { + descriptions + |> list.fold(<>, fn(sum, description) { + <> + }) + |> encode("t", _) +} + +fn encode_row_descriptions(fields: List(RowDescriptionField)) { + fields + |> list.fold(<>, fn(sum, field) { + << + sum:bits, + encode_string(field.name):bits, + field.table_oid:32, + 
field.attr_number:16, + field.data_type_oid:32, + field.data_type_size:16, + field.type_modifier:32, + field.format_code:16, + >> + }) + |> encode("T", _) +} + +pub fn encode_backend_message(message: BackendMessage) -> BitArray { + case message { + BeMessageDataRow(columns) -> encode_message_data_row(columns) + BeErrorResponse(fields) -> encode_error_response(fields) + BeAuthenticationOk -> encode("R", <<0:32>>) + BeAuthenticationKerberosV5 -> encode("R", <<2:32>>) + BeAuthenticationCleartextPassword -> encode("R", <<3:32>>) + BeAuthenticationMD5Password(salt: salt) -> encode("R", <<5:32, salt:bits>>) + BeAuthenticationGSS -> encode("R", <<7:32>>) + BeAuthenticationGSSContinue(data) -> encode("R", <<8:32, data:bits>>) + BeAuthenticationSSPI -> encode("R", <<9:32>>) + BeAuthenticationSASL(a) -> encode_authentication_sasl(a) + BeAuthenticationSASLContinue(data) -> encode("R", <<11:32, data:bits>>) + BeAuthenticationSASLFinal(data: data) -> encode("R", <<12:32, data:bits>>) + BeBackendKeyData(pid, sk) -> encode("K", <>) + BeBindComplete -> encode("2", <<>>) + BeCloseComplete -> encode("3", <<>>) + BeCommandComplete(a, b) -> encode_command_complete(a, b) + BeCopyData(data) -> encode("d", data) + BeCopyDone -> encode("c", <<>>) + BeCopyResponse(a, b, c) -> encode_copy_response(a, b, c) + BeParameterStatus(name, value) -> encode_parameter_status(name, value) + BeNegotiateProtocolVersion(a, b) -> encode_negotiate_protocol_version(a, b) + BeNoData -> encode("n", <<>>) + BeNoticeResponse(data) -> encode_notice_response(data) + BeNotificationResponse(a, b, c) -> encode_notification_response(a, b, c) + BeParameterDescription(a) -> encode_parameter_description(a) + BeParseComplete -> encode("1", <<>>) + BePortalSuspended -> encode("s", <<>>) + BeRowDescriptions(a) -> encode_row_descriptions(a) + BeReadyForQuery(TransactionStatusIdle) -> encode("Z", <<"I":utf8>>) + BeReadyForQuery(TransactionStatusInTransaction) -> encode("Z", <<"T":utf8>>) + 
BeReadyForQuery(TransactionStatusFailed) -> encode("Z", <<"E":utf8>>) + } +} + +pub fn encode_frontend_message(message: FrontendMessage) { + case message { + FeBind(a, b, c, d, e) -> encode_bind(a, b, c, d, e) + FeCancelRequest(process_id: process_id, secret_key: secret_key) -> << + 16:32, + 1234:16, + 5678:16, + process_id:32, + secret_key:32, + >> + FeClose(what, name) -> + encode("C", <>) + FeCopyData(data) -> encode("d", data) + FeCopyDone -> <<"c":utf8, 4:32>> + FeCopyFail(error) -> encode("f", encode_string(error)) + FeDescribe(what, name) -> + encode("D", <>) + FeExecute(portal, count) -> + encode("E", <>) + FeFlush -> <<"H":utf8, 4:32>> + FeFunctionCall(a, b, c, d) -> encode_function_call(a, b, c, d) + FeGssEncRequest -> <<8:32, 1234:16, 5680:16>> + FeAmbigous(FeGssResponse(data)) -> encode("p", data) + FeParse(a, b, c) -> encode_parse(a, b, c) + FeAmbigous(FePasswordMessage(password)) -> + encode("p", encode_string(password)) + FeQuery(query) -> encode("Q", encode_string(query)) + FeAmbigous(FeSaslInitialResponse(a, b)) -> + encode_sasl_initial_response(a, b) + FeAmbigous(FeSaslResponse(data)) -> encode("p", data) + FeStartupMessage(params) -> encode_startup_message(params) + FeSslRequest -> <<8:32, 1234:16, 5679:16>> + FeSync -> <<"S":utf8, 4:32>> + FeTerminate -> <<"X":utf8, 4:32>> + } +} + +fn encode(type_char, data) { + case data { + <<>> -> <> + _ -> { + let len = bit_array.byte_size(data) + 4 + <> + } + } +} + +fn encode_string(str) { + <> +} + +pub const protocol_version_major = <<3:16>> + +pub const protocol_version_minor = <<0:16>> + +pub const protocol_version = << + protocol_version_major:bits, + protocol_version_minor:bits, +>> + +fn encode_startup_message(params) { + let packet = + params + |> list.fold(<>, fn(builder, element) { + let #(key, value) = element + <> + }) + + let size = bit_array.byte_size(packet) + 5 + + <> +} + +fn encode_sasl_initial_response(name, data) { + let len = bit_array.byte_size(data) + encode("p", <>) +} + +fn 
encode_parse(name, query, parameter_object_ids) { + let oids = + list.fold(parameter_object_ids, <<>>, fn(sum, oid) { <> }) + let len = list.length(parameter_object_ids) + encode("P", << + encode_string(name):bits, + encode_string(query):bits, + len:16, + oids:bits, + >>) +} + +fn encode_bind( + portal, + statement_name, + parameter_format, + parameters, + result_format, +) { + encode("B", << + portal:utf8, + 0, + statement_name:utf8, + 0, + encode_format_value(parameter_format):bits, + parameters_to_bytes(parameters):bits, + encode_format_value(result_format):bits, + >>) +} + +fn encode_function_call( + object_id: Int, + argument_format: FormatValue, + arguments: ParameterValues, + result_format: Format, +) { + encode("F", << + object_id:32, + encode_format_value(argument_format):bits, + parameters_to_bytes(arguments):bits, + encode_format(result_format):16, + >>) +} + +/// Send a message to the database +pub fn send_builder(conn: Connection, message) { + case mug.send_builder(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Send a message to the database +pub fn send(conn: Connection, message) { + case mug.send(conn.socket, message) { + Ok(Nil) -> Ok(conn) + Error(err) -> Error(err) + } +} + +/// Receive a single message from the backend +pub fn receive( + conn: Connection, +) -> Result(#(Connection, BackendMessage), ReadError) { + case decode_backend_packet(conn.buffer) { + Ok(#(message, rest)) -> Ok(#(with_buffer(conn, rest), message)) + Error(MessageIncomplete(_)) -> { + case mug.receive(conn.socket, conn.timeout) { + Ok(packet) -> + receive(with_buffer(conn, <>)) + Error(err) -> Error(SocketError(err)) + } + } + Error(err) -> Error(ReadDecodeError(err)) + } +} + +fn with_buffer(conn: Connection, buffer: BitArray) { + Connection(..conn, buffer: buffer) +} + +// decode a single message from the packet +pub fn decode_backend_packet( + packet: BitArray, +) -> Result(#(BackendMessage, BitArray), MessageDecodingError) { + case 
packet { + <> -> { + let len = length - 4 + case tail { + <> -> + decode_backend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + _ -> + case bit_array.byte_size(packet) < 5 { + True -> Error(MessageIncomplete(packet)) + False -> dec_err("packet size too small", packet) + } + } +} + +// decode a backend message +pub fn decode_backend_message(binary) { + case binary { + <<"D":utf8, count:16, data:bytes>> -> decode_message_data_row(count, data) + <<"E":utf8, data:bytes>> -> decode_error_response(data) + <<"R":utf8, 0:32>> -> Ok(BeAuthenticationOk) + <<"R":utf8, 2:32>> -> Ok(BeAuthenticationKerberosV5) + <<"R":utf8, 3:32>> -> Ok(BeAuthenticationCleartextPassword) + <<"R":utf8, 5:32, salt:bytes>> -> + Ok(BeAuthenticationMD5Password(salt: salt)) + <<"R":utf8, 7:32>> -> Ok(BeAuthenticationGSS) + <<"R":utf8, 8:32, auth_data:bytes>> -> + Ok(BeAuthenticationGSSContinue(auth_data: auth_data)) + <<"R":utf8, 9:32>> -> Ok(BeAuthenticationSSPI) + <<"R":utf8, 10:32, data:bytes>> -> decode_authentication_sasl(data) + <<"R":utf8, 11:32, data:bytes>> -> + Ok(BeAuthenticationSASLContinue(data: data)) + <<"R":utf8, 12:32, data:bytes>> -> Ok(BeAuthenticationSASLFinal(data: data)) + <<"K":utf8, pid:32, sk:32>> -> + Ok(BeBackendKeyData(process_id: pid, secret_key: sk)) + <<"2":utf8>> -> Ok(BeBindComplete) + <<"3":utf8>> -> Ok(BeCloseComplete) + <<"C":utf8, data:bytes>> -> decode_command_complete(data) + <<"d":utf8, data:bytes>> -> Ok(BeCopyData(data)) + <<"c":utf8>> -> Ok(BeCopyDone) + <<"G":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(In, format, count, data) + <<"H":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Out, format, count, data) + <<"W":utf8, format:8, count:16, data:bytes>> -> + decode_copy_response(Both, format, count, data) + <<"S":utf8, data:bytes>> -> decode_parameter_status(data) + <<"v":utf8, version:32, count:32, data:bytes>> -> + decode_negotiate_protocol_version(version, count, 
data) + <<"n":utf8>> -> Ok(BeNoData) + <<"N":utf8, data:bytes>> -> decode_notice_response(data) + <<"A":utf8, process_id:32, data:bytes>> -> + decode_notification_response(process_id, data) + <<"t":utf8, count:16, data:bytes>> -> + decode_parameter_description(count, data, []) + <<"1":utf8>> -> Ok(BeParseComplete) + <<"s":utf8>> -> Ok(BePortalSuspended) + <<"T":utf8, count:16, data:bytes>> -> decode_row_descriptions(count, data) + <<"Z":utf8, "I":utf8>> -> Ok(BeReadyForQuery(TransactionStatusIdle)) + <<"Z":utf8, "T":utf8>> -> + Ok(BeReadyForQuery(TransactionStatusInTransaction)) + <<"Z":utf8, "E":utf8>> -> Ok(BeReadyForQuery(TransactionStatusFailed)) + _ -> Error(UnknownMessage(binary)) + } +} + +/// decode a single message from the packet buffer. +/// note that FeCancelRequest, FeSslRequest, FeGssEncRequest, and +/// FeStartupMessage messages can only be decoded here because they don't follow +/// the standard message format. +pub fn decode_frontend_packet( + packet: BitArray, +) -> Result(#(FrontendMessage, BitArray), MessageDecodingError) { + case packet { + <<16:32, 1234:16, 5678:16, process_id:32, secret_key:32, next:bytes>> -> + Ok(#( + FeCancelRequest(process_id: process_id, secret_key: secret_key), + next, + )) + <<8:32, 1234:16, 5679:16, next:bytes>> -> Ok(#(FeSslRequest, next)) + <<8:32, 1234:16, 5680:16, next:bytes>> -> Ok(#(FeGssEncRequest, next)) + // not sure if there's a way to use the `protocol_version` constant here + <> -> + decode_startup_message(next, length - 8, []) + <> -> { + let len = length - 4 + case tail { + <> -> + decode_frontend_message(<>) + |> result.map(fn(msg) { #(msg, next) }) + _ -> Error(MessageIncomplete(tail)) + } + } + <<_:48>> -> Error(MessageIncomplete(packet)) + _ -> dec_err("invalid message", packet) + } +} + +/// decode a frontend message (also see decode_frontend_packet for messages that +/// can't be decoded here) +pub fn decode_frontend_message( + binary: BitArray, +) -> Result(FrontendMessage, MessageDecodingError) { 
+ case binary { + <<"B":utf8, data:bytes>> -> decode_bind(data) + <<"C":utf8, data:bytes>> -> decode_close(data) + <<"d":utf8, data:bytes>> -> Ok(FeCopyData(data)) + <<"c":utf8>> -> Ok(FeCopyDone) + <<"f":utf8, data:bytes>> -> decode_copy_fail(data) + <<"D":utf8, data:bytes>> -> decode_describe(data) + <<"E":utf8, data:bytes>> -> decode_execute(data) + <<"H":utf8>> -> Ok(FeFlush) + <<"F":utf8, data:bytes>> -> decode_function_call(data) + <<"p":utf8, data:bytes>> -> Ok(FeAmbigous(FeGssResponse(data))) + <<"P":utf8, data:bytes>> -> decode_parse(data) + <<"Q":utf8, data:bytes>> -> decode_query(data) + <<"S":utf8>> -> Ok(FeSync) + <<"X":utf8>> -> Ok(FeTerminate) + _ -> Error(UnknownMessage(data: binary)) + } +} + +fn decode_startup_message(binary, size, result) { + case binary { + <> -> + decode_startup_message_pairs(data, []) + |> result.map(fn(r) { #(r, next) }) + _ -> dec_err("invalid startup message", binary) + } +} + +fn decode_startup_message_pairs(binary, result) { + case binary { + <<0>> -> Ok(FeStartupMessage(params: list.reverse(result))) + _ -> { + use #(key, binary) <- try(decode_string(binary)) + use #(value, binary) <- try(decode_string(binary)) + decode_startup_message_pairs(binary, [#(key, value), ..result]) + } + } +} + +fn decode_query(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(query, rest) <- try(decode_string(binary)) + case rest { + <<>> -> Ok(FeQuery(query)) + _ -> dec_err("Query message too long", binary) + } +} + +// FeQuery(query) -> encode("Q", encode_string(query)) + +fn decode_parse(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(name, binary) <- try(decode_string(binary)) + use #(query, binary) <- try(decode_string(binary)) + use parameter_object_ids <- try(decode_parameter_object_ids(binary)) + Ok(FeParse( + name: name, + query: query, + parameter_object_ids: parameter_object_ids, + )) +} + +fn decode_parameter_object_ids(binary) { + case binary { + <> -> decode_parameter_object_ids_rec(rest, 
count, []) + _ -> dec_err("expected object id count", binary) + } +} + +fn decode_parameter_object_ids_rec(binary, count, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <> -> + decode_parameter_object_ids_rec(rest, count - 1, [id, ..result]) + _, _ -> dec_err("expected parameter object id", binary) + } +} + +fn decode_function_call(binary) -> Result(FrontendMessage, MessageDecodingError) { + case binary { + <> -> { + use #(argument_format, rest) <- try(read_parameter_format(rest)) + use #(arguments, rest) <- try(read_parameters(rest, argument_format)) + use #(result_format, rest) <- try(read_format(rest)) + case rest { + <<>> -> + Ok(FeFunctionCall( + argument_format: argument_format, + arguments: arguments, + object_id: object_id, + result_format: result_format, + )) + _ -> dec_err("invalid function call, data remains", rest) + } + } + _ -> dec_err("invalid function call, no object id found", binary) + } +} + +fn decode_execute(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + case binary { + <> -> Ok(FeExecute(portal, count)) + _ -> dec_err("no execute return_row_count found", binary) + } +} + +fn decode_describe(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeDescribe(what, name)) + _ -> dec_err("Describe message too long", binary) + } +} + +fn decode_copy_fail(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(error, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeCopyFail(error)) + _ -> dec_err("CopyFail message too long", binary) + } +} + +fn decode_close(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(what, binary) <- try(decode_what(binary)) + use #(name, binary) <- try(decode_string(binary)) + case binary { + <<>> -> Ok(FeClose(what, name)) + _ -> dec_err("Close 
message too long", binary) + } +} + +// FeClose(what, name) -> +// encode("C", <>) + +fn decode_bind(binary) -> Result(FrontendMessage, MessageDecodingError) { + use #(portal, binary) <- try(decode_string(binary)) + use #(statement_name, binary) <- try(decode_string(binary)) + use #(parameter_format, binary) <- try(read_parameter_format(binary)) + use #(parameters, binary) <- try(read_parameters(binary, parameter_format)) + use #(result_format, binary) <- try(read_parameter_format(binary)) + case binary { + <<>> -> + Ok(FeBind( + portal: portal, + statement_name: statement_name, + parameter_format: parameter_format, + parameters: parameters, + result_format: result_format, + )) + _ -> dec_err("Bind message too long", binary) + } +} + +fn decode_string( + binary: BitArray, +) -> Result(#(String, BitArray), MessageDecodingError) { + case binary_split(binary, <<0>>, []) { + [head, tail] -> + case bit_array.to_string(head) { + Ok(str) -> Ok(#(str, tail)) + Error(Nil) -> dec_err("invalid string encoding", head) + } + _ -> dec_err("invalid string", binary) + } +} + +fn read_parameter_format( + binary: BitArray, +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(FormatAllText, rest)) + <<1:16, 0:16, rest:bytes>> -> Ok(#(FormatAll(Text), rest)) + <<1:16, 1:16, rest:bytes>> -> Ok(#(FormatAll(Binary), rest)) + <> -> read_wire_formats(n, rest, []) + _ -> dec_err("invalid parameter format", binary) + } +} + +fn read_parameters(binary: BitArray, parameter_format: FormatValue) { + case binary { + <> -> { + case parameter_format { + FormatAllText -> list.repeat(Text, count) + FormatAll(format) -> list.repeat(format, count) + Formats(formats) -> formats + } + |> read_parameters_rec(count, rest, []) + } + _ -> dec_err("parameters without count", binary) + } +} + +fn read_parameters_rec( + formats: List(Format), + count: Int, + binary: BitArray, + result: List(ParameterValue), +) -> Result(#(List(ParameterValue), BitArray), 
MessageDecodingError) { + let actual = list.length(formats) + use <- bool.guard( + actual != count, + dec_err( + "expected " + <> int.to_string(count) + <> " parameters, but got " + <> int.to_string(actual), + binary, + ), + ) + + case count, formats, binary { + 0, _, rest -> Ok(#(list.reverse(result), rest)) + _, [_, ..rest_formats], <<-1:32-signed, rest:bytes>> -> { + read_parameters_rec(rest_formats, count - 1, rest, [Null, ..result]) + } + _, [format, ..rest_formats], <> -> + read_parameters_rec(rest_formats, count - 1, rest, [ + read_parameter(format, value), + ..result + ]) + _, _, rest -> dec_err("invalid parameter value", rest) + } +} + +fn read_parameter(format: Format, value: BitArray) { + case format { + Text -> Parameter(value) + Binary -> Parameter(value) + } +} + +fn read_wire_formats( + count: Int, + binary: BitArray, + result: List(Format), +) -> Result(#(FormatValue, BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(#(Formats(list.reverse(result)), binary)) + _, <<0:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Text, ..result]) + _, <<1:16, rest:bytes>> -> + read_wire_formats(count - 1, rest, [Binary, ..result]) + _, _ -> dec_err("unknown format", binary) + } +} + +fn decode_command_complete( + binary: BitArray, +) -> Result(BackendMessage, MessageDecodingError) { + { + use fine <- try(case binary { + <<"INSERT 0 ":utf8, rows:bytes>> -> Ok(#(Insert, rows)) + <<"DELETE ":utf8, rows:bytes>> -> Ok(#(Delete, rows)) + <<"UPDATE ":utf8, rows:bytes>> -> Ok(#(Update, rows)) + <<"MERGE ":utf8, rows:bytes>> -> Ok(#(Merge, rows)) + <<"SELECT ":utf8, rows:bytes>> -> Ok(#(Select, rows)) + <<"MOVE ":utf8, rows:bytes>> -> Ok(#(Move, rows)) + <<"FETCH ":utf8, rows:bytes>> -> Ok(#(Fetch, rows)) + <<"COPY ":utf8, rows:bytes>> -> Ok(#(Copy, rows)) + _ -> dec_err("invalid command", binary) + }) + + let #(command, rows_raw) = fine + let len = bit_array.byte_size(rows_raw) - 1 + + use rows_bits <- try(case rows_raw { + <> -> 
Ok(rows_bits) + _ -> dec_err("invalid command row count", binary) + }) + + use rows_string <- try( + bit_array.to_string(rows_bits) + |> result.replace_error(msg_dec_err( + "failed to convert row count to string", + rows_bits, + )), + ) + + use rows <- try( + int.parse(rows_string) + |> result.replace_error(msg_dec_err( + "failed to convert row count to int", + rows_bits, + )), + ) + + Ok(BeCommandComplete(command, rows)) + } + |> result.or(Ok(BeCommandComplete(Insert, -1))) +} + +pub type TransactionStatus { + TransactionStatusIdle + TransactionStatusInTransaction + TransactionStatusFailed +} + +fn decode_parameter_description(count, binary, results) { + case count, binary { + 0, <<>> -> Ok(BeParameterDescription(list.reverse(results))) + _, <> -> + decode_parameter_description(count - 1, tail, [value, ..results]) + _, _ -> dec_err("invalid parameter description", binary) + } +} + +fn decode_notification_response( + process_id, + binary, +) -> Result(BackendMessage, MessageDecodingError) { + use strings <- try(read_strings(binary, 2, [])) + case strings { + [channel, payload] -> + Ok(BeNotificationResponse( + process_id: process_id, + channel: channel, + payload: payload, + )) + _ -> dec_err("invalid notification response encoding", binary) + } +} + +fn decode_negotiate_protocol_version(version, count, binary) { + use options <- try(read_strings(binary, count, [])) + Ok(BeNegotiateProtocolVersion(version, options)) +} + +fn decode_row_descriptions(count, binary) { + use fields <- try(read_row_descriptions(count, binary, [])) + Ok(BeRowDescriptions(fields)) +} + +fn decode_parameter_status(binary) { + use strings <- try(decode_strings(binary)) + case strings { + [name, value] -> Ok(BeParameterStatus(name: name, value: value)) + _ -> dec_err("invalid parameter status", binary) + } +} + +fn decode_authentication_sasl(binary) { + use strings <- try(decode_strings(binary)) + Ok(BeAuthenticationSASL(strings)) +} + +fn decode_copy_response(direction, format_raw, count, 
rest) { + use overall_format <- try(decode_format(format_raw)) + use <- bool.guard( + bit_array.byte_size(rest) != count * 2, + dec_err("size must be count * 2", rest), + ) + + use codes <- try(decode_format_codes(rest, [])) + + case overall_format == Text { + False -> Ok(BeCopyResponse(direction, overall_format, codes)) + True -> + case list.all(codes, fn(code) { code == Text }) { + True -> Ok(BeCopyResponse(direction, overall_format, codes)) + False -> dec_err("invalid copy response format", rest) + } + } +} + +fn decode_format_codes( + binary: BitArray, + result: List(Format), +) -> Result(List(Format), MessageDecodingError) { + case binary { + <> -> + case decode_format(code) { + Ok(format) -> decode_format_codes(tail, [format, ..result]) + Error(err) -> Error(err) + } + <<>> -> Ok(list.reverse(result)) + _ -> dec_err("invalid format codes", binary) + } +} + +fn decode_format(num: Int) { + case num { + 0 -> Ok(Text) + 1 -> Ok(Binary) + _ -> dec_err("invalid format code: " <> int.to_string(num), <<>>) + } +} + +fn read_format(binary) { + case binary { + <<0:16, rest:bytes>> -> Ok(#(Text, rest)) + <<1:16, rest:bytes>> -> Ok(#(Binary, rest)) + _ -> dec_err("invalid format code", binary) + } +} + +fn encode_format(format_raw) -> Int { + case format_raw { + Text -> 0 + Binary -> 1 + } +} + +pub type DataRow { + DataRow(List(BitArray)) +} + +fn decode_message_data_row( + count, + rest, +) -> Result(BackendMessage, MessageDecodingError) { + case decode_message_data_row_rec(rest, count, []) { + Ok(cols) -> + case list.length(cols) == count { + True -> Ok(BeMessageDataRow(cols)) + False -> dec_err("column count doesn't match", rest) + } + Error(err) -> Error(err) + } +} + +fn decode_message_data_row_rec( + binary, + count, + result, +) -> Result(List(BitArray), MessageDecodingError) { + case count, binary { + 0, _ -> Ok(list.reverse(result)) + _, <> -> + decode_message_data_row_rec(rest, count - 1, [value, ..result]) + _, _ -> + dec_err( + "failed to parse data row at 
count " <> int.to_string(count), + binary, + ) + } +} + +pub type RowDescriptionField { + RowDescriptionField( + // The field name. + name: String, + // If the field can be identified as a column of a specific table, the + // object ID of the table; otherwise zero. + table_oid: Int, + // If the field can be identified as a column of a specific table, the + // attribute number of the column; otherwise zero. + attr_number: Int, + // The object ID of the field's data type. + data_type_oid: Int, + // The data type size (see pg_type.typlen). Note that negative values denote + // variable-width types. + data_type_size: Int, + // The type modifier (see pg_attribute.atttypmod). The meaning of the + // modifier is type-specific. + type_modifier: Int, + // The format code being used for the field. Currently will be zero (text) + // or one (binary). In a RowDescription returned from the statement variant + // of Describe, the format code is not yet known and will always be zero. + format_code: Int, + ) +} + +fn read_row_descriptions(count, binary, result) { + case count, binary { + 0, <<>> -> Ok(list.reverse(result)) + _, <<>> -> dec_err("row description count mismatch", binary) + _, _ -> + case read_row_description_field(binary) { + Ok(#(field, tail)) -> + read_row_descriptions(count - 1, tail, [field, ..result]) + Error(err) -> Error(err) + } + } +} + +fn read_row_description_field(binary) { + case read_string(binary) { + Ok(#( + name, + << + table_oid:32, + attr_number:16, + data_type_oid:32, + data_type_size:16, + type_modifier:32, + format_code:16, + tail:bytes, + >>, + )) -> + Ok(#( + RowDescriptionField( + name: name, + table_oid: table_oid, + attr_number: attr_number, + data_type_oid: data_type_oid, + data_type_size: data_type_size, + type_modifier: type_modifier, + format_code: format_code, + ), + tail, + )) + Ok(#(_, tail)) -> dec_err("failed to parse row description field", tail) + Error(_) -> dec_err("failed to decode row description field name", binary) + } +} + 
+fn decode_strings(binary) { + let length = bit_array.byte_size(binary) - 1 + case binary { + <<>> -> Ok([]) + <> -> { + binary_split(head, <<0>>, [Global]) + |> list.map(bit_array.to_string) + |> result.all() + |> result.replace_error(msg_dec_err("invalid strings encoding", binary)) + } + _ -> dec_err("string size didn't match", binary) + } +} + +fn read_strings(binary, count, result) { + case count { + 0 -> Ok(list.reverse(result)) + _ -> { + case read_string(binary) { + Ok(#(value, rest)) -> read_strings(rest, count - 1, [value, ..result]) + Error(err) -> Error(err) + } + } + } +} + +fn read_string(binary) { + case binary_split(binary, <<0>>, []) { + [<<>>, <<>>] -> Ok(#("", <<>>)) + [head, tail] -> { + bit_array.to_string(head) + |> result.replace_error(msg_dec_err("invalid string encoding", head)) + |> result.map(fn(s) { #(s, tail) }) + } + _ -> dec_err("invalid string", binary) + } +} + +fn decode_notice_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeNoticeResponse(fields)) +} + +fn decode_error_response(binary) { + use fields <- try(decode_fields(binary)) + Ok(BeErrorResponse(fields)) +} + +pub type ErrorOrNoticeField { + Code(String) + Detail(String) + File(String) + Hint(String) + Line(String) + Message(String) + Position(String) + Routine(String) + SeverityLocalized(String) + Severity(String) + Where(String) + Column(String) + DataType(String) + Constraint(String) + InternalPosition(String) + InternalQuery(String) + Schema(String) + Table(String) + Unknown(key: BitArray, value: String) +} + +fn decode_fields(binary) { + case decode_fields_rec(binary, []) { + Ok(fields) -> + fields + |> list.map(fn(key_value_raw) { + let #(key, value) = key_value_raw + case key { + <<"S":utf8>> -> Severity(value) + <<"V":utf8>> -> SeverityLocalized(value) + <<"C":utf8>> -> Code(value) + <<"M":utf8>> -> Message(value) + <<"D":utf8>> -> Detail(value) + <<"H":utf8>> -> Hint(value) + <<"P":utf8>> -> Position(value) + <<"p":utf8>> -> 
InternalPosition(value) + <<"q":utf8>> -> InternalQuery(value) + <<"W":utf8>> -> Where(value) + <<"s":utf8>> -> Schema(value) + <<"t":utf8>> -> Table(value) + <<"c":utf8>> -> Column(value) + <<"d":utf8>> -> DataType(value) + <<"n":utf8>> -> Constraint(value) + <<"F":utf8>> -> File(value) + <<"L":utf8>> -> Line(value) + <<"R":utf8>> -> Routine(value) + _ -> Unknown(key, value) + } + }) + |> set.from_list() + |> Ok + Error(err) -> Error(err) + } +} + +fn decode_fields_rec(binary, result) { + case binary { + <<0>> | <<>> -> Ok(result) + <> -> { + case binary_split(rest, <<0>>, []) { + [head, tail] -> { + case bit_array.to_string(head) { + Ok(value) -> + decode_fields_rec(tail, [#(field_type, value), ..result]) + Error(Nil) -> dec_err("invalid field encoding", binary) + } + } + _ -> dec_err("invalid field separator", binary) + } + } + _ -> dec_err("invalid field", binary) + } +} + +type BinarySplitOption { + Global +} + +@external(erlang, "binary", "split") +fn binary_split( + subject: BitArray, + pattern: BitArray, + options: List(BinarySplitOption), +) -> List(BitArray) diff --git a/src/squirrel/internal/error.gleam b/src/squirrel/internal/error.gleam new file mode 100644 index 0000000..04925d6 --- /dev/null +++ b/src/squirrel/internal/error.gleam @@ -0,0 +1,66 @@ +import gleam/option.{type Option} +import gleam/string +import simplifile + +pub type Error { + + // --- POSTGRES RELATED ERRORS ----------------------------------------------- + /// When there's an error with the underlying socket and I cannot send + /// messages to the server. + /// + PgCannotSendMessage(reason: String) + + /// This comes from the `postgres_protocol` module, it happens if the server + /// sends back a malformed message or there's a type of messages decoding is + /// not implemented for. + /// This should never happen and warrants a bug report! 
+  ///
+  PgCannotDecodeReceivedMessage(reason: String)
+
+  /// When there's an error with the underlying socket and I cannot receive
+  /// messages from the server.
+  ///
+  PgCannotReceiveMessage(reason: String)
+
+  // --- OTHER GENERIC ERRORS --------------------------------------------------
+  /// When I cannot read a file containing queries.
+  ///
+  CannotReadFile(file: String, reason: simplifile.FileError)
+
+  /// If a query has a name that is not a valid Gleam identifier. Instead of
+  /// trying to magically come up with a name we fail and report the error.
+  ///
+  QueryHasInvalidName(file: String, name: String, reason: ValueIdentifierError)
+
+  /// If a query returns a column that is not a valid Gleam identifier. Instead
+  /// of trying to magically come up with a name we fail and report the error.
+  ///
+  QueryHasInvalidColumn(
+    file: String,
+    query_name: String,
+    column_name: String,
+    reason: ValueIdentifierError,
+  )
+
+  /// If the query contains an error and cannot be parsed by the DBMS.
+ /// + CannotParseQuery( + name: String, + file: String, + content: String, + error_code: Option(String), + message: Option(String), + hint: Option(String), + position: Option(Int), + ) +} + +pub type ValueIdentifierError { + DoesntStartWithLowercaseLetter + ContainsInvalidGrapheme(at: Int, grapheme: String) + IsEmpty +} + +pub fn explain(error: Error) -> String { + string.inspect(error) +} diff --git a/src/squirrel/internal/eval_extra.gleam b/src/squirrel/internal/eval_extra.gleam new file mode 100644 index 0000000..70a7c72 --- /dev/null +++ b/src/squirrel/internal/eval_extra.gleam @@ -0,0 +1,61 @@ +import eval.{type Eval} +import gleam/list + +pub fn try_map( + list: List(a), + fun: fn(a) -> Eval(b, _, _), +) -> Eval(List(b), _, _) { + try_fold(list, [], fn(acc, item) { + use mapped_item <- eval.try(fun(item)) + eval.return([mapped_item, ..acc]) + }) + |> eval.map(list.reverse) +} + +pub fn try_index_map( + list: List(a), + fun: fn(a, Int) -> Eval(b, _, _), +) -> Eval(List(b), _, _) { + try_index_fold(list, [], fn(acc, item, i) { + use mapped_item <- eval.try(fun(item, i)) + eval.return([mapped_item, ..acc]) + }) + |> eval.map(list.reverse) +} + +pub fn try_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first)) + try_fold(rest, acc, fun) + } + } +} + +pub fn try_index_fold( + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + do_try_index_fold(0, over: list, from: acc, with: fun) +} + +fn do_try_index_fold( + index: Int, + over list: List(a), + from acc: b, + with fun: fn(b, a, Int) -> Eval(b, _, _), +) -> Eval(b, _, _) { + case list { + [] -> eval.return(acc) + [first, ..rest] -> { + use acc <- eval.try(fun(acc, first, index)) + do_try_index_fold(index + 1, rest, acc, fun) + } + } +} diff --git a/src/squirrel/internal/gleam.gleam 
b/src/squirrel/internal/gleam.gleam new file mode 100644 index 0000000..64357be --- /dev/null +++ b/src/squirrel/internal/gleam.gleam @@ -0,0 +1,138 @@ +import gleam/list +import gleam/string +import justin +import squirrel/internal/error.{ + type ValueIdentifierError, ContainsInvalidGrapheme, IsEmpty, +} + +pub type Type { + List(Type) + Option(Type) + Int + Float + Bool + String +} + +pub type Field { + Field(label: ValueIdentifier, type_: Type) +} + +pub opaque type ValueIdentifier { + ValueIdentifier(String) +} + +/// Returns true if the given string is a valid Gleam identifier (that is not +/// a discard identifier, that is starting with an '_'). +/// +/// > 💡 A valid identifier can be described by the following regex: +/// > `[a-z][a-z0-9_]*`. +pub fn identifier( + from name: String, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // A valid identifier needs to start with a lowercase letter. + // We do not accept _discard identifier as valid. + case name { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest -> to_identifier_rest(name, rest, 1) + _ -> + case string.pop_grapheme(name) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(0, g)) + Error(_) -> Error(IsEmpty) + } + } +} + +fn to_identifier_rest( + name: String, + rest: String, + position: Int, +) -> Result(ValueIdentifier, ValueIdentifierError) { + // The rest of an identifier can only contain lowercase letters, _, numbers, + // or be empty. In all other cases it's not valid. 
+ case rest { + "a" <> rest + | "b" <> rest + | "c" <> rest + | "d" <> rest + | "e" <> rest + | "f" <> rest + | "g" <> rest + | "h" <> rest + | "i" <> rest + | "j" <> rest + | "k" <> rest + | "l" <> rest + | "m" <> rest + | "n" <> rest + | "o" <> rest + | "p" <> rest + | "q" <> rest + | "r" <> rest + | "s" <> rest + | "t" <> rest + | "u" <> rest + | "v" <> rest + | "w" <> rest + | "x" <> rest + | "y" <> rest + | "z" <> rest + | "_" <> rest + | "0" <> rest + | "1" <> rest + | "2" <> rest + | "3" <> rest + | "4" <> rest + | "5" <> rest + | "6" <> rest + | "7" <> rest + | "8" <> rest + | "9" <> rest -> to_identifier_rest(name, rest, position + 1) + "" -> Ok(ValueIdentifier(name)) + _ -> + case string.pop_grapheme(rest) { + Ok(#(g, _)) -> Error(ContainsInvalidGrapheme(position, g)) + Error(_) -> panic as "unreachable: empty identifier rest should be ok" + } + } +} + +pub fn identifier_to_string(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + name +} + +pub fn identifier_to_type_name(identifier: ValueIdentifier) -> String { + let ValueIdentifier(name) = identifier + + justin.pascal_case(name) + |> string.to_graphemes + // We want to remove any leftover "_" that might still be present after the + // conversion if the identifier had consecutive "_". 
+  |> list.filter(keeping: fn(c) { c != "_" })
+  |> string.join(with: "")
+}
diff --git a/src/squirrel/internal/query.gleam b/src/squirrel/internal/query.gleam
new file mode 100644
index 0000000..5e49d30
--- /dev/null
+++ b/src/squirrel/internal/query.gleam
@@ -0,0 +1,438 @@
+import glam/doc.{type Document}
+import gleam/bit_array
+import gleam/int
+import gleam/list
+import gleam/result
+import gleam/string
+import simplifile
+import squirrel/internal/error.{
+  type Error, type ValueIdentifierError, CannotReadFile, QueryHasInvalidName,
+}
+import squirrel/internal/gleam
+
+pub type Query {
+  Query(file: String, name: gleam.ValueIdentifier, content: String)
+}
+
+pub type TypedQuery {
+  TypedQuery(
+    file: String,
+    name: gleam.ValueIdentifier,
+    content: String,
+    params: List(gleam.Type),
+    returns: List(gleam.Field),
+  )
+}
+
+pub fn add_types(
+  to query: Query,
+  params params: List(gleam.Type),
+  returns returns: List(gleam.Field),
+) -> TypedQuery {
+  let Query(file: file, name: name, content: content) = query
+  TypedQuery(
+    file: file,
+    name: name,
+    content: content,
+    params: params,
+    returns: returns,
+  )
+}
+
+// --- PARSING -----------------------------------------------------------------
+
+pub fn from_file(file: String) -> Result(List(Query), Error) {
+  let read_file =
+    simplifile.read(file)
+    |> result.map_error(CannotReadFile(file, _))
+
+  use content <- result.try(read_file)
+  let content = <<content:utf8>>
+  parse_queries(file, content, content, 0, [])
+}
+
+fn parse_queries(
+  file: String,
+  original: BitArray,
+  string: BitArray,
+  position: Int,
+  queries: List(Query),
+) -> Result(List(Query), Error) {
+  case string {
+    <<>> -> Ok(list.reverse(queries))
+    <<"--":utf8, rest:bits>> ->
+      case name_comment(rest) {
+        NotANameComment ->
+          parse_queries(file, original, rest, position + 2, queries)
+
+        InvalidNameComment(name: name, reason: reason) ->
+          Error(QueryHasInvalidName(file: file, name: name, reason: reason))
+
+        ValidNameComment(identifier) ->
+ 
parse_query( + file, + original, + rest, + identifier, + position, + position + 2, + queries, + ) + } + <<_, rest:bits>> -> + parse_queries(file, original, rest, position + 1, queries) + _ -> panic as "non byte aligned Gleam String" + } +} + +type NameCommentResult { + NotANameComment + InvalidNameComment(name: String, reason: ValueIdentifierError) + ValidNameComment(identifier: gleam.ValueIdentifier) +} + +fn name_comment(string: BitArray) -> NameCommentResult { + case string { + // We ignore all whitespace between the start of the comment and its + // content. + <<" ":utf8, rest:bits>> + | <<"\t":utf8, rest:bits>> + | <<"\r":utf8, rest:bits>> -> name_comment(rest) + + <<"name:":utf8, rest:bits>> -> take_name(rest, rest, 0) + <<_, _:bits>> | <<>> -> NotANameComment + _ -> panic as "non byte aligned Gleam String" + } +} + +fn take_name( + original: BitArray, + string: BitArray, + size: Int, +) -> NameCommentResult { + case string { + <<>> | <<"\n":utf8, _:bits>> -> { + let assert Ok(name) = bit_array.slice(original, 0, size) + let assert Ok(name) = bit_array.to_string(name) + case string.trim(name) |> gleam.identifier { + Ok(identifier) -> ValidNameComment(identifier) + Error(reason) -> InvalidNameComment(name: name, reason: reason) + } + } + <<_, rest:bits>> -> take_name(original, rest, size + 1) + _ -> panic as "non byte aligned Gleam String" + } +} + +fn parse_query( + file: String, + original: BitArray, + string: BitArray, + name: gleam.ValueIdentifier, + start: Int, + position: Int, + queries: List(Query), +) -> Result(List(Query), Error) { + case string { + <<>> -> { + let assert Ok(query) = bit_array.slice(original, start, position - start) + let assert Ok(content) = bit_array.to_string(query) + let query = Query(file: file, name: name, content: content) + Ok(list.reverse([query, ..queries])) + } + + <<"--":utf8, rest:bits>> -> + case name_comment(rest) { + NotANameComment -> + parse_query(file, original, rest, name, start, position + 2, queries) + + 
ValidNameComment(_) | InvalidNameComment(_, _) -> { + let assert Ok(query) = + bit_array.slice(original, start, position - start) + let assert Ok(content) = bit_array.to_string(query) + let query = Query(file: file, name: name, content: content) + parse_queries(file, original, string, position, [query, ..queries]) + } + } + + <<_, rest:bits>> -> + parse_query(file, original, rest, name, start, position + 1, queries) + + _ -> panic as "non byte aligned Gleam String" + } +} + +// --- CODE GENERATION --------------------------------------------------------- + +pub fn to_function(query: TypedQuery) { + let TypedQuery( + file: file, + name: name, + content: content, + params: params, + returns: returns, + ) = query + + let arg_name = fn(i) { "arg_" <> int.to_string(i + 1) } + let inputs = list.index_map(params, fn(_, i) { arg_name(i) }) + let inputs_encoders = + list.index_map(params, fn(p, i) { + gleam_type_to_encoder(p, arg_name(i)) |> doc.from_string + }) + + let function_name = gleam.identifier_to_string(name) + let constructor_name = gleam.identifier_to_type_name(name) <> "Row" + + let record_doc = + "/// A row you get from running the `" <> function_name <> "` query +/// defined in `" <> file <> "`. +/// +/// > 🐿️ This type definition was generated automatically using v1.0.0 of the +/// > [squirrel package](link!!). +///" + + let fun_doc = "/// Runs the `" <> function_name <> "` query +/// defined in `" <> file <> "`. +/// +/// > 🐿️ This function was generated automatically using v1.0.0 of the +/// > [squirrel package](link!!). 
+///"
+
+  [
+    doc.from_string(record_doc),
+    doc.line,
+    record(constructor_name, returns),
+    doc.lines(2),
+    doc.from_string(fun_doc),
+    doc.line,
+    fun(function_name, ["db", ..inputs], [
+      var("decoder", decoder(constructor_name, returns)),
+      pipe_call("pgo.execute", string(content), [
+        doc.from_string("db"),
+        list(inputs_encoders),
+        doc.from_string("decode.from(decoder, _)"),
+      ]),
+    ]),
+  ]
+  |> doc.concat
+  |> doc.to_string(80)
+}
+
+fn gleam_type_to_decoder(type_: gleam.Type) -> String {
+  case type_ {
+    gleam.List(type_) -> "decode.list(" <> gleam_type_to_decoder(type_) <> ")"
+    gleam.Int -> "decode.int"
+    gleam.Float -> "decode.float"
+    gleam.Bool -> "decode.bool"
+    gleam.String -> "decode.string"
+    gleam.Option(type_) ->
+      "decode.option(" <> gleam_type_to_decoder(type_) <> ")"
+  }
+}
+
+fn gleam_type_to_encoder(type_: gleam.Type, name: String) {
+  case type_ {
+    gleam.List(type_) ->
+      "pgo.array(list.map(" <> name <> ", fn(a) {"
+      <> gleam_type_to_encoder(type_, "a")
+      <> "}))"
+    gleam.Option(type_) ->
+      "pgo.nullable(fn(a) {"
+      <> gleam_type_to_encoder(type_, "a")
+      <> "}, "
+      <> name
+      <> ")"
+    gleam.Int -> "pgo.int(" <> name <> ")"
+    gleam.Float -> "pgo.float(" <> name <> ")"
+    gleam.Bool -> "pgo.bool(" <> name <> ")"
+    gleam.String -> "pgo.text(" <> name <> ")"
+  }
+}
+
+// --- CODE GENERATION PRETTY PRINTING -----------------------------------------
+
+const indent = 2
+
+pub fn record(name: String, fields: List(gleam.Field)) -> Document {
+  let fields =
+    list.map(fields, fn(field) {
+      let label = gleam.identifier_to_string(field.label)
+
+      [doc.from_string(label <> ": "), pretty_gleam_type(field.type_)]
+      |> doc.concat
+      |> doc.group
+    })
+
+  [
+    doc.from_string("pub type " <> name <> " {"),
+    [doc.line, call(name, fields)]
+    |> doc.concat
+    |> doc.nest(by: indent),
+    doc.line,
+    doc.from_string("}"),
+  ]
+  |> doc.concat
+  |> doc.group
+}
+
+fn pretty_gleam_type(type_: gleam.Type) -> Document {
+  case type_ {
+    gleam.List(type_) -> call("List",
[pretty_gleam_type(type_)]) + gleam.Option(type_) -> call("Option", [pretty_gleam_type(type_)]) + gleam.Int -> doc.from_string("Int") + gleam.Float -> doc.from_string("Float") + gleam.Bool -> doc.from_string("Bool") + gleam.String -> doc.from_string("String") + } +} + +/// A pretty printed public function definition. +/// +pub fn fun(name: String, args: List(String), body: List(Document)) -> Document { + let args = list.map(args, doc.from_string) + + [ + doc.from_string("pub fn " <> name), + comma_list("(", args, ") "), + block([body |> doc.join(with: doc.lines(2))]), + doc.line, + ] + |> doc.concat + |> doc.group +} + +/// A pretty printed let assignment. +/// +pub fn var(name: String, body: Document) -> Document { + [ + doc.from_string("let " <> name <> " ="), + [doc.space, body] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + ] + |> doc.concat +} + +/// A pretty printed Gleam string. +/// +/// > ⚠️ This function escapes all `\` and `"` inside the original string to +/// > avoid generating invalid Gleam code. +/// +pub fn string(content: String) -> Document { + let escaped_string = + content + |> string.replace(each: "\\", with: "\\\\") + |> string.replace(each: "\"", with: "\\\"") + |> doc.from_string + + [doc.from_string("\""), escaped_string, doc.from_string("\"")] + |> doc.concat +} + +/// A pretty printed Gleam list. +/// +pub fn list(elems: List(Document)) -> Document { + comma_list("[", elems, "]") +} + +/// A pretty printed decoder that decodes an n-item dynamic tuple using the +/// `decode` package. 
+/// +pub fn decoder(constructor: String, returns: List(gleam.Field)) -> Document { + let parameters = + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string("use " <> label <> " <- decode.parameter") + }) + + let pipes = + list.index_map(returns, fn(field, i) { + let position = int.to_string(i) |> doc.from_string + let decoder = gleam_type_to_decoder(field.type_) |> doc.from_string + call("|> decode.field", [position, decoder]) + }) + + let labelled_names = + // TODO)) remove the redundant label once zed updates syntax highlight for gleam + list.map(returns, fn(field) { + let label = gleam.identifier_to_string(field.label) + doc.from_string(label <> ": " <> label) + }) + + [ + call_block("decode.into", [ + doc.join(parameters, with: doc.line), + doc.line, + call(constructor, labelled_names), + ]), + doc.line, + doc.join(pipes, with: doc.line), + ] + |> doc.concat() + |> doc.group +} + +/// A pretty printed function call where the first argument is piped into +/// the function. +/// +pub fn pipe_call( + function: String, + first: Document, + rest: List(Document), +) -> Document { + [first, doc.line, call("|> " <> function, rest)] + |> doc.concat +} + +/// A pretty printed function call. +/// +fn call(function: String, args: List(Document)) -> Document { + [doc.from_string(function), comma_list("(", args, ")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed function call where the only argument is a single block. +/// +fn call_block(function: String, body: List(Document)) -> Document { + [doc.from_string(function <> "("), block(body), doc.from_string(")")] + |> doc.concat + |> doc.group +} + +/// A pretty printed Gleam block. 
+/// +fn block(body: List(Document)) -> Document { + [ + doc.from_string("{"), + [doc.line, ..body] + |> doc.concat + |> doc.nest(by: indent), + doc.line, + doc.from_string("}"), + ] + |> doc.concat + |> doc.force_break +} + +/// A comma separated list of items with some given open and closed delimiters. +/// +fn comma_list(open: String, content: List(Document), close: String) -> Document { + [ + doc.from_string(open), + [ + // We want the first break to be nested + // in case the group is broken. + doc.soft_break, + doc.join(content, doc.break(", ", ",")), + ] + |> doc.concat + |> doc.group + |> doc.nest(by: indent), + doc.break("", ","), + doc.from_string(close), + ] + |> doc.concat + |> doc.group +} diff --git a/test/squirrel_test.gleam b/test/squirrel_test.gleam new file mode 100644 index 0000000..3831e7a --- /dev/null +++ b/test/squirrel_test.gleam @@ -0,0 +1,12 @@ +import gleeunit +import gleeunit/should + +pub fn main() { + gleeunit.main() +} + +// gleeunit test functions end in `_test` +pub fn hello_world_test() { + 1 + |> should.equal(1) +}