Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move sql-only functions to plpgsql instead of using SPI #361

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lantern_extras/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lantern_extras"
version = "0.5.0"
version = "0.6.0"
edition = "2021"

[lib]
Expand Down
4 changes: 2 additions & 2 deletions lantern_extras/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ To add a new embedding job, use the `add_embedding_job` function:

```sql
SELECT add_embedding_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column for embeddings
dst_column => 'content_embedding', -- Destination column for embeddings (will be created automatically)
model => 'text-embedding-3-small', -- Model for runtime to use (default: 'text-embedding-3-small')
Expand Down Expand Up @@ -224,7 +224,7 @@ To add a new completion job, use the `add_completion_job` function:

```sql
SELECT add_completion_job(
table => 'articles', -- Name of the table
table_name => 'articles', -- Name of the table
src_column => 'content', -- Source column for embeddings
dst_column => 'content_summary', -- Destination column for llm response (will be created automatically)
system_prompt => 'Provide short summary for the given text', -- System prompt for LLM (default: '')
Expand Down
298 changes: 178 additions & 120 deletions lantern_extras/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,126 +226,177 @@ fn add_completion_job<'a>(
Ok(id.unwrap())
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn get_embedding_job_status<'a>(
job_id: i32,
) -> Result<
TableIterator<
'static,
(
name!(status, Option<String>),
name!(progress, Option<i16>),
name!(error, Option<String>),
),
>,
anyhow::Error,
> {
let tuple = Spi::get_three_with_args(
r#"
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=$1;
"#,
vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())],
);

if tuple.is_err() {
return Ok(TableIterator::once((None, None, None)));
}

Ok(TableIterator::once(tuple.unwrap()))
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn get_completion_job_failures<'a>(
job_id: i32,
) -> Result<
TableIterator<'static, (name!(row_id, Option<i32>), name!(value, Option<String>))>,
anyhow::Error,
> {
Spi::connect(|client| {
client.select("SELECT row_id, value FROM _lantern_extras_internal.embedding_failure_info WHERE job_id=$1", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]))?
.map(|row| Ok((row["row_id"].value()?, row["value"].value()?)))
.collect::<Result<Vec<_>, _>>()
}).map(TableIterator::new)
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn get_embedding_jobs<'a>() -> Result<
TableIterator<
'static,
(
name!(id, Option<i32>),
name!(status, Option<String>),
name!(progress, Option<i16>),
name!(error, Option<String>),
),
>,
anyhow::Error,
> {
Spi::connect(|client| {
client.select("SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'embedding_generation'", None, None)?
.map(|row| Ok((row["id"].value()?, row["status"].value()?, row["progress"].value()?, row["error"].value()?)))
.collect::<Result<Vec<_>, _>>()
}).map(TableIterator::new)
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn get_completion_jobs<'a>() -> Result<
TableIterator<
'static,
(
name!(id, Option<i32>),
name!(status, Option<String>),
name!(progress, Option<i16>),
name!(error, Option<String>),
),
>,
anyhow::Error,
> {
Spi::connect(|client| {
client.select("SELECT id, (get_embedding_job_status(id)).* FROM _lantern_extras_internal.embedding_generation_jobs WHERE job_type = 'completion'", None, None)?
.map(|row| Ok((row["id"].value()?, row["status"].value()?, row["progress"].value()?, row["error"].value()?)))
.collect::<Result<Vec<_>, _>>()
}).map(TableIterator::new)
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn cancel_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
Spi::run_with_args(
r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=$1;
"#,
Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
)?;

Ok(())
}

#[pg_extern(immutable, parallel_safe, security_definer)]
fn resume_embedding_job<'a>(job_id: i32) -> AnyhowVoidResult {
Spi::run_with_args(
r#"
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=$1;
"#,
Some(vec![(PgBuiltInOids::INT4OID.oid(), job_id.into_datum())]),
)?;

Ok(())
}
extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_embedding_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT IMMUTABLE PARALLEL SAFE
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT
CASE
WHEN init_failed_at IS NOT NULL THEN 'failed'
WHEN canceled_at IS NOT NULL THEN 'canceled'
WHEN init_finished_at IS NOT NULL THEN 'enabled'
WHEN init_started_at IS NOT NULL THEN 'in_progress'
ELSE 'queued'
END AS status,
init_progress as progress,
init_failure_reason as error
FROM _lantern_extras_internal.embedding_generation_jobs
WHERE id=job_id;
END
$$;
"#,
name = "get_embedding_job_status"
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_job_status(job_id INT)
RETURNS TABLE (status TEXT, progress SMALLINT, error TEXT)
STRICT IMMUTABLE PARALLEL SAFE
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT * FROM get_embedding_job_status(job_id);
END
$$;
"#,
name = "get_completion_job_status",
requires = ["get_embedding_job_status"]
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_job_failures(job_id INT)
RETURNS TABLE (row_id INT, value TEXT)
STRICT IMMUTABLE PARALLEL SAFE
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT info.row_id, info.value
FROM _lantern_extras_internal.embedding_failure_info info
WHERE info.job_id=get_completion_job_failures.job_id;
END
$$;
"#,
name = "get_completion_job_failures",
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_embedding_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT IMMUTABLE PARALLEL SAFE
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_embedding_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'embedding_generation';
END
$$;
"#,
name = "get_embedding_jobs",
requires = ["get_embedding_job_status"]
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION get_completion_jobs()
RETURNS TABLE (id INT, status TEXT, progress SMALLINT, error TEXT)
STRICT IMMUTABLE PARALLEL SAFE
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY
SELECT jobs.id, (get_completion_job_status(jobs.id)).*
FROM _lantern_extras_internal.embedding_generation_jobs jobs
WHERE jobs.job_type = 'completion';
END
$$;
"#,
name = "get_completion_jobs",
requires = ["get_completion_job_status"]
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_embedding_job",
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION cancel_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NOW()
WHERE id=job_id;
END
$$;
"#,
name = "cancel_completion_job",
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_embedding_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_embedding_job",
);

extension_sql!(
r#"
CREATE OR REPLACE FUNCTION resume_completion_job(job_id INT)
RETURNS VOID
STRICT VOLATILE
SECURITY DEFINER
LANGUAGE plpgsql
AS $$
BEGIN
UPDATE _lantern_extras_internal.embedding_generation_jobs
SET canceled_at=NULL
WHERE id=job_id;
END
$$;
"#,
name = "resume_completion_job",
);

#[cfg(any(test, feature = "pg_test"))]
#[pg_schema]
Expand Down Expand Up @@ -472,6 +523,8 @@ pub mod tests {
"
CREATE TABLE t1 (id serial primary key, title text);
SET lantern_extras.openai_token='test';
CREATE ROLE test_role1;
SET ROLE test_role1;
",
None,
None,
Expand Down Expand Up @@ -513,6 +566,8 @@ pub mod tests {
(1, 1, '1test1'),
(1, 2, '1test2'),
(2, 1, '2test1');
CREATE ROLE test_role1;
SET ROLE test_role1;
",
None,
None,
Expand Down Expand Up @@ -601,6 +656,7 @@ pub mod tests {
client.update(
"
CREATE TABLE t1 (id serial primary key, title text);
CREATE ROLE test_role1;
",
None,
None,
Expand All @@ -610,7 +666,9 @@ pub mod tests {
let id: i32 = id.first().get(1)?.unwrap();

// queued
client.update("SET ROLE test_role1;", None, None,)?;
let rows = client.select("SELECT status, progress, error FROM get_embedding_job_status($1)", None, Some(vec![(PgBuiltInOids::INT4OID.oid(), id.into_datum())]))?;
client.update("RESET ROLE;", None, None,)?;
let job = rows.first();

let status: &str = job.get(1)?.unwrap();
Expand Down
Loading