diff --git a/.github/actions/poetry-setup/action.yml b/.github/actions/poetry-setup/action.yml new file mode 100644 index 0000000..df04e1e --- /dev/null +++ b/.github/actions/poetry-setup/action.yml @@ -0,0 +1,88 @@ +# An action for setting up poetry install with caching. +# Using a custom action since the default action does not +# take poetry install groups into account. +# Action code from: +# https://github.com/actions/setup-python/issues/505#issuecomment-1273013236 +name: poetry-install-with-caching +description: Poetry install with support for caching of dependency groups. + +inputs: + python-version: + description: Python version, supporting MAJOR.MINOR only + required: true + + poetry-version: + description: Poetry version + required: true + + cache-key: + description: Cache key to use for manual handling of caching + required: true + +runs: + using: composite + steps: + - uses: actions/setup-python@v5 + name: Setup python ${{ inputs.python-version }} + id: setup-python + with: + python-version: ${{ inputs.python-version }} + + - uses: actions/cache@v3 + id: cache-bin-poetry + name: Cache Poetry binary - Python ${{ inputs.python-version }} + env: + SEGMENT_DOWNLOAD_TIMEOUT_MIN: "1" + with: + path: | + /opt/pipx/venvs/poetry + # This step caches the poetry installation, so make sure it's keyed on the poetry version as well. + key: bin-poetry-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-${{ inputs.poetry-version }} + + - name: Refresh shell hashtable and fixup softlinks + if: steps.cache-bin-poetry.outputs.cache-hit == 'true' + shell: bash + env: + POETRY_VERSION: ${{ inputs.poetry-version }} + PYTHON_VERSION: ${{ inputs.python-version }} + run: | + set -eux + + # Refresh the shell hashtable, to ensure correct `which` output. + hash -r + + # `actions/cache@v3` doesn't always seem able to correctly unpack softlinks. + # Delete and recreate the softlinks pipx expects to have. + rm /opt/pipx/venvs/poetry/bin/python + cd /opt/pipx/venvs/poetry/bin + ln -s "$(which "python$PYTHON_VERSION")" python + chmod +x python + cd /opt/pipx_bin/ + ln -s /opt/pipx/venvs/poetry/bin/poetry poetry + chmod +x poetry + + # Ensure everything got set up correctly. + /opt/pipx/venvs/poetry/bin/python --version + /opt/pipx_bin/poetry --version + + - name: Install poetry + if: steps.cache-bin-poetry.outputs.cache-hit != 'true' + shell: bash + env: + POETRY_VERSION: ${{ inputs.poetry-version }} + PYTHON_VERSION: ${{ inputs.python-version }} + # Install poetry using the python version installed by setup-python step. + run: pipx install "poetry==$POETRY_VERSION" --python '${{ steps.setup-python.outputs.python-path }}' --verbose + + - name: Restore pip and poetry cached dependencies + uses: actions/cache@v3 + env: + SEGMENT_DOWNLOAD_TIMEOUT_MIN: "4" + with: + path: | + ~/.cache/pip + ~/.cache/pypoetry/virtualenvs + ~/.cache/pypoetry/cache + ~/.cache/pypoetry/artifacts + ./.venv + key: py-deps-${{ runner.os }}-${{ runner.arch }}-py-${{ inputs.python-version }}-poetry-${{ inputs.poetry-version }}-${{ inputs.cache-key }}-${{ hashFiles('./poetry.lock') }} diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml new file mode 100644 index 0000000..eebbde7 --- /dev/null +++ b/.github/workflows/_lint.yml @@ -0,0 +1,112 @@ +name: lint + +on: + workflow_call + +env: + POETRY_VERSION: "1.7.1" + + # This env var allows us to get inline annotations when ruff has complaints. 
+ RUFF_OUTPUT_FORMAT: github + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + # Only lint on the min and max supported Python versions. + # It's extremely unlikely that there's a lint issue on any version in between + # that doesn't show up on the min or max versions. + # + # GitHub rate-limits how many jobs can be running at any one time. + # Starting new jobs is also relatively slow, + # so linting on fewer versions makes CI faster. + python-version: + - "3.12" + name: "lint #${{ matrix.python-version }}" + steps: + - uses: actions/checkout@v4 + - name: Get changed files + id: changed-files + uses: Ana06/get-changed-files@v2.3.0 + - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} + if: steps.changed-files.outputs.all + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ matrix.python-version }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: lint + + - name: Check Poetry File + if: steps.changed-files.outputs.all + shell: bash + run: poetry check + + - name: Check lock file + if: steps.changed-files.outputs.all + shell: bash + run: poetry lock --check + + - name: Install dependencies + if: steps.changed-files.outputs.all + # Also installs dev/lint/test/typing dependencies, to ensure we have + # type hints for as many of our libraries as possible. + # This helps catch errors that require dependencies to be spotted, for example: + # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 + # + # If you change this configuration, make sure to change the `cache-key` + # in the `poetry_setup` action above to stop using the old cache. + # It doesn't matter how you change it, any change will cause a cache-bust. + run: poetry install --with dev + + - name: Get .mypy_cache to speed up mypy + if: steps.changed-files.outputs.all + uses: actions/cache@v3 + env: + SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" + with: + path: | + .mypy_cache + key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ hashFiles('poetry.lock') }} + + - name: Analysing package code with our lint + if: steps.changed-files.outputs.all + run: | + if make lint_package > /dev/null 2>&1; then + make lint_package + else + echo "lint_package command not found, using lint instead" + make lint + fi + + - name: Install test dependencies + if: steps.changed-files.outputs.all + # Also installs dev/lint/test/typing dependencies, to ensure we have + # type hints for as many of our libraries as possible. + # This helps catch errors that require dependencies to be spotted, for example: + # https://github.com/langchain-ai/langchain/pull/10249/files#diff-935185cd488d015f026dcd9e19616ff62863e8cde8c0bee70318d3ccbca98341 + # + # If you change this configuration, make sure to change the `cache-key` + # in the `poetry_setup` action above to stop using the old cache. + # It doesn't matter how you change it, any change will cause a cache-bust. 
+ run: | + poetry install --with dev + + - name: Get .mypy_cache_test to speed up mypy + if: steps.changed-files.outputs.all + uses: actions/cache@v3 + env: + SEGMENT_DOWNLOAD_TIMEOUT_MIN: "2" + with: + path: | + .mypy_cache_test + key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ hashFiles('poetry.lock') }} + + - name: Analysing tests with our lint + if: steps.changed-files.outputs.all + run: | + if make lint_tests > /dev/null 2>&1; then + make lint_tests + else + echo "lint_tests command not found, skipping step" + fi diff --git a/.github/workflows/_test.yml b/.github/workflows/_test.yml new file mode 100644 index 0000000..c7ad8ff --- /dev/null +++ b/.github/workflows/_test.yml @@ -0,0 +1,50 @@ +name: test + +on: + workflow_call + +env: + POETRY_VERSION: "1.7.1" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: + - "3.9" + - "3.10" + - "3.11" + - "3.12" + + name: "test #${{ matrix.python-version }}" + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ matrix.python-version }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: test + + - name: Install dependencies + shell: bash + run: | + poetry install --with dev + + - name: Run tests + shell: bash + run: | + make test + + - name: Ensure the tests did not create any additional files + shell: bash + run: | + set -eu + + STATUS="$(git status)" + echo "$STATUS" + + # grep will exit non-zero if the target message isn't found, + # and `set -e` above will cause the step to fail. + echo "$STATUS" | grep 'nothing to commit, working tree clean' diff --git a/.github/workflows/_test_release.yml b/.github/workflows/_test_release.yml new file mode 100644 index 0000000..8a2ec03 --- /dev/null +++ b/.github/workflows/_test_release.yml @@ -0,0 +1,87 @@ +name: test-release + +on: + workflow_call + +env: + POETRY_VERSION: "1.7.1" + PYTHON_VERSION: "3.10" + +jobs: + build: + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + + outputs: + pkg-name: ${{ steps.check-version.outputs.pkg-name }} + version: ${{ steps.check-version.outputs.version }} + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ env.PYTHON_VERSION }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: release + + # We want to keep this build stage *separate* from the release stage, + # so that there's no sharing of permissions between them. + # The release stage has trusted publishing and GitHub repo contents write access, + # and we want to keep the scope of that access limited just to the release job. + # Otherwise, a malicious `build` step (e.g. via a compromised dependency) + # could get access to our GitHub or PyPI credentials. + # + # Per the trusted publishing GitHub Action: + # > It is strongly advised to separate jobs for building [...] + # > from the publish job. 
+ # https://github.com/pypa/gh-action-pypi-publish#non-goals + - name: Build project for distribution + run: poetry build + + - name: Upload build + uses: actions/upload-artifact@v4 + with: + name: test-dist + path: dist/ + + - name: Check Version + id: check-version + shell: bash + run: | + echo pkg-name="$(poetry version | cut -d ' ' -f 1)" >> $GITHUB_OUTPUT + echo version="$(poetry version --short)" >> $GITHUB_OUTPUT + + publish: + needs: + - build + runs-on: ubuntu-latest + permissions: + # This permission is used for trusted publishing: + # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ + # + # Trusted publishing has to also be configured on PyPI for each package: + # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ + id-token: write + + steps: + - uses: actions/checkout@v4 + + - uses: actions/download-artifact@v4 + with: + name: test-dist + path: dist/ + + - name: Publish to test PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ + verbose: true + print-hash: true + repository-url: https://test.pypi.org/legacy/ + + # We overwrite any existing distributions with the same name and version. + # This is *only for CI use* and is *extremely dangerous* otherwise! + # https://github.com/pypa/gh-action-pypi-publish#tolerating-release-package-file-duplicates + skip-existing: true diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..979b243 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +--- +name: CI + +on: + push: + branches: [main] + pull_request: + +# If another push to the same PR or branch happens while this workflow is still running, +# cancel the earlier run in favor of the next run. +# +# There's no point in testing an outdated version of the code. GitHub only allows +# a limited number of job runners to be active at the same time, so it's better to cancel +# pointless jobs early so that more useful jobs can run sooner. 
+concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + POETRY_VERSION: "1.7.1" + +jobs: + lint: + uses: ./.github/workflows/_lint.yml + secrets: inherit + + test: + uses: ./.github/workflows/_test.yml + secrets: inherit + + ci_success: + name: "CI Success" + needs: [lint, test] + if: | + always() + runs-on: ubuntu-latest + env: + JOBS_JSON: ${{ toJSON(needs) }} + RESULTS_JSON: ${{ toJSON(needs.*.result) }} + EXIT_CODE: ${{!contains(needs.*.result, 'failure') && !contains(needs.*.result, 'cancelled') && '0' || '1'}} + steps: + - name: "CI Success" + run: | + echo $JOBS_JSON + echo $RESULTS_JSON + echo "Exiting with $EXIT_CODE" + exit $EXIT_CODE diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..e0cd91e --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,281 @@ +name: release +run-name: Release langgraph-checkpoint-mysql by @${{ github.actor }} +on: + workflow_dispatch + +env: + PYTHON_VERSION: "3.11" + POETRY_VERSION: "1.7.1" + +jobs: + build: + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + + outputs: + pkg-name: ${{ steps.check-version.outputs.pkg-name }} + short-pkg-name: ${{ steps.check-version.outputs.short-pkg-name }} + version: ${{ steps.check-version.outputs.version }} + tag: ${{ steps.check-version.outputs.tag }} + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ env.PYTHON_VERSION }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: release + + # We want to keep this build stage *separate* from the release stage, + # so that there's no sharing of permissions between them. + # The release stage has trusted publishing and GitHub repo contents write access, + # and we want to keep the scope of that access limited just to the release job. + # Otherwise, a malicious `build` step (e.g. via a compromised dependency) + # could get access to our GitHub or PyPI credentials. + # + # Per the trusted publishing GitHub Action: + # > It is strongly advised to separate jobs for building [...] + # > from the publish job. 
+ # https://github.com/pypa/gh-action-pypi-publish#non-goals + - name: Build project for distribution + run: poetry build + + - name: Upload build + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + + - name: Check Version + id: check-version + shell: bash + run: | + PKG_NAME="$(poetry version | cut -d ' ' -f 1)" + VERSION="$(poetry version --short)" + SHORT_PKG_NAME="$(echo "$PKG_NAME" | sed -e 's/langgraph//g' -e 's/-//g')" + if [ -z $SHORT_PKG_NAME ]; then + TAG="$VERSION" + else + TAG="${SHORT_PKG_NAME}==${VERSION}" + fi + echo pkg-name="$PKG_NAME" >> $GITHUB_OUTPUT + echo short-pkg-name="$SHORT_PKG_NAME" >> $GITHUB_OUTPUT + echo version="$VERSION" >> $GITHUB_OUTPUT + echo tag="$TAG" >> $GITHUB_OUTPUT + + release-notes: + needs: + - build + runs-on: ubuntu-latest + outputs: + release-body: ${{ steps.generate-release-body.outputs.release-body }} + steps: + - uses: actions/checkout@v4 + with: + repository: tjni/langgraph-checkpoint-mysql + path: langgraph-checkpoint-mysql + ref: main # this scopes to just master branch + fetch-depth: 0 # this fetches entire commit history + - name: Check Tags + id: check-tags + shell: bash + working-directory: langgraph-checkpoint-mysql + env: + PKG_NAME: ${{ needs.build.outputs.pkg-name }} + SHORT_PKG_NAME: ${{ needs.build.outputs.short-pkg-name }} + VERSION: ${{ needs.build.outputs.version }} + TAG: ${{ needs.build.outputs.tag }} + run: | + if [ -z $SHORT_PKG_NAME ]; then + REGEX="^\\d+\\.\\d+\\.\\d+((a|b|rc)\\d+)?\$" + else + REGEX="^$SHORT_PKG_NAME==\\d+\\.\\d+\\.\\d+((a|b|rc)\\d+)?\$" + fi + echo $REGEX + PREV_TAG=$(git tag --sort=-creatordate | grep -P $REGEX | head -1 || echo "") + echo $PREV_TAG + if [ "$TAG" == "$PREV_TAG" ]; then + echo "No new version to release" + exit 1 + fi + echo prev-tag="$PREV_TAG" >> $GITHUB_OUTPUT + - name: Generate release body + id: generate-release-body + working-directory: langgraph + env: + PKG_NAME: ${{ needs.build.outputs.pkg-name }} + TAG: ${{ needs.build.outputs.tag }} + PREV_TAG: ${{ steps.check-tags.outputs.prev-tag }} + run: | + { + echo 'release-body<> "$GITHUB_OUTPUT" + + test-pypi-publish: + needs: + - build + - release-notes + permissions: write-all + uses: ./.github/workflows/_test_release.yml + secrets: inherit + + pre-release-checks: + needs: + - build + - release-notes + - test-pypi-publish + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + # We explicitly *don't* set up caching here. This ensures our tests are + # maximally sensitive to catching breakage. + # + # For example, here's a way that caching can cause a falsely-passing test: + # - Make the langchain package manifest no longer list a dependency package + # as a requirement. This means it won't be installed by `pip install`, + # and attempting to use it would cause a crash. + # - That dependency used to be required, so it may have been cached. + # When restoring the venv packages from cache, that dependency gets included. + # - Tests pass, because the dependency is present even though it wasn't specified. + # - The package is published, and it breaks on the missing dependency when + # used in the real world. 
+ + - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ env.PYTHON_VERSION }} + poetry-version: ${{ env.POETRY_VERSION }} + + - name: Import published package + shell: bash + env: + PKG_NAME: ${{ needs.build.outputs.pkg-name }} + VERSION: ${{ needs.build.outputs.version }} + # Here we use: + # - The default regular PyPI index as the *primary* index, meaning + # that it takes priority (https://pypi.org/simple) + # - The test PyPI index as an extra index, so that any dependencies that + # are not found on test PyPI can be resolved and installed anyway. + # (https://test.pypi.org/simple). This will include the PKG_NAME==VERSION + # package because VERSION will not have been uploaded to regular PyPI yet. + # - attempt install again after 5 seconds if it fails because there is + # sometimes a delay in availability on test pypi + run: | + poetry run pip install \ + --extra-index-url https://test.pypi.org/simple/ \ + "$PKG_NAME==$VERSION" || \ + ( \ + sleep 5 && \ + poetry run pip install \ + --extra-index-url https://test.pypi.org/simple/ \ + "$PKG_NAME==$VERSION" \ + ) + + # Replace all dashes in the package name with underscores, + # since that's how Python imports packages with dashes in the name. + IMPORT_NAME="$(echo "$PKG_NAME" | sed s/-/_/g)" + + poetry run python -c "import $IMPORT_NAME; print(dir($IMPORT_NAME))" + + - name: Import test dependencies + run: poetry install --with dev + + # Overwrite the local version of the package with the test PyPI version. + - name: Import published package (again) + shell: bash + env: + PKG_NAME: ${{ needs.build.outputs.pkg-name }} + VERSION: ${{ needs.build.outputs.version }} + run: | + poetry run pip install \ + --extra-index-url https://test.pypi.org/simple/ \ + "$PKG_NAME==$VERSION" + + - name: Run unit tests + run: make test + + publish: + needs: + - build + - release-notes + - test-pypi-publish + - pre-release-checks + runs-on: ubuntu-latest + permissions: + # This permission is used for trusted publishing: + # https://blog.pypi.org/posts/2023-04-20-introducing-trusted-publishers/ + # + # Trusted publishing has to also be configured on PyPI for each package: + # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ + id-token: write + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ env.PYTHON_VERSION }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: release + + - uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + packages-dir: dist/ + verbose: true + print-hash: true + + mark-release: + needs: + - build + - release-notes + - test-pypi-publish + - pre-release-checks + - publish + runs-on: ubuntu-latest + permissions: + # This permission is needed by `ncipollo/release-action` to + # create the GitHub release. 
+ contents: write + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + Poetry ${{ env.POETRY_VERSION }} + uses: "./.github/actions/poetry_setup" + with: + python-version: ${{ env.PYTHON_VERSION }} + poetry-version: ${{ env.POETRY_VERSION }} + cache-key: release + + - uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Create Tag + uses: ncipollo/release-action@v1 + with: + artifacts: "dist/*" + token: ${{ secrets.GITHUB_TOKEN }} + generateReleaseNotes: false + tag: ${{needs.build.outputs.tag}} + body: ${{ needs.release-notes.outputs.release-body }} + commit: ${{ github.sha }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..eb176dc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.mypy_cache/ +__pycache__/ diff --git a/Makefile b/Makefile index aadb22c..b912768 100644 --- a/Makefile +++ b/Makefile @@ -4,24 +4,24 @@ # TESTING AND COVERAGE ###################### -start-postgres: - docker compose -f tests/compose-postgres.yml up -V --force-recreate --wait +start-mysql: + docker compose -f tests/compose-mysql.yml up -V --force-recreate --wait -stop-postgres: - docker compose -f tests/compose-postgres.yml down +stop-mysql: + docker compose -f tests/compose-mysql.yml down test: - make start-postgres; \ + make start-mysql; \ poetry run pytest; \ EXIT_CODE=$$?; \ - make stop-postgres; \ + make stop-mysql; \ exit $$EXIT_CODE test_watch: - make start-postgres; \ + make start-mysql; \ poetry run ptw .; \ EXIT_CODE=$$?; \ - make stop-postgres; \ + make stop-mysql; \ exit $$EXIT_CODE ###################### diff --git a/README.md b/README.md index cf6beab..c52232d 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,32 @@ -# LangGraph Checkpoint Postgres +# LangGraph Checkpoint MySQL -Implementation of LangGraph CheckpointSaver that uses Postgres. +Implementation of LangGraph CheckpointSaver that uses MySQL. + +> [!TIP] +> The code in this repository tries to mimic the code in [langgraph-checkpoint-postgres](https://github.com/langchain-ai/langgraph/tree/main/libs/checkpoint-postgres) as much as possible to enable keeping in sync with the official checkpointer implementation. ## Dependencies -By default `langgraph-checkpoint-postgres` installs `psycopg` (Psycopg 3) without any extras. However, you can choose a specific installation that best suits your needs [here](https://www.psycopg.org/psycopg3/docs/basic/install.html) (for example, `psycopg[binary]`). +To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`. + +There is currently no support for other drivers. ## Usage > [!IMPORTANT] -> When using Postgres checkpointers for the first time, make sure to call `.setup()` method on them to create required tables. See example below. +> When using MySQL checkpointers for the first time, make sure to call `.setup()` method on them to create required tables. See example below. > [!IMPORTANT] -> When manually creating Postgres connections and passing them to `PostgresSaver` or `AsyncPostgresSaver`, make sure to include `autocommit=True` and `row_factory=dict_row` (`from psycopg.rows import dict_row`). See a full example in this [how-to guide](https://langchain-ai.github.io/langgraph/how-tos/persistence_postgres/). +> When manually creating MySQL connections and passing them to `PyMySQLSaver` or `AIOMySQLSaver`, make sure to include `autocommit=True`. 
```python -from langgraph.checkpoint.postgres import PostgresSaver +from langgraph.checkpoint.mysql import PyMySQLSaver write_config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} read_config = {"configurable": {"thread_id": "1"}} -DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" -with PostgresSaver.from_conn_string(DB_URI) as checkpointer: +DB_URI = "mysql://mysql:mysql@localhost:3306/mysql" +with PyMySQLSaver.from_conn_string(DB_URI) as checkpointer: # call .setup() the first time you're using the checkpointer checkpointer.setup() checkpoint = { @@ -63,9 +68,9 @@ with PostgresSaver.from_conn_string(DB_URI) as checkpointer: ### Async ```python -from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver +from langgraph.checkpoint.mysql.aio import AIOMySQLSaver -async with AsyncPostgresSaver.from_conn_string(DB_URI) as checkpointer: +async with AIOMySQLSaver.from_conn_string(DB_URI) as checkpointer: checkpoint = { "v": 1, "ts": "2024-07-31T20:14:19.804150+00:00", diff --git a/langgraph/checkpoint/postgres/__init__.py b/langgraph/checkpoint/mysql/__init__.py similarity index 74% rename from langgraph/checkpoint/postgres/__init__.py rename to langgraph/checkpoint/mysql/__init__.py index a2274cd..dd9dab7 100644 --- a/langgraph/checkpoint/postgres/__init__.py +++ b/langgraph/checkpoint/mysql/__init__.py @@ -1,13 +1,13 @@ +import json import threading +import urllib.parse from contextlib import contextmanager -from typing import Any, Iterator, Optional, Sequence, Union +from typing import Any, Iterator, Optional, Protocol, Sequence, Union +import pymysql +import pymysql.constants.ER +import pymysql.cursors from langchain_core.runnables import RunnableConfig -from psycopg import Connection, Cursor, Pipeline -from psycopg.errors import UndefinedTable -from psycopg.rows import DictRow, dict_row -from psycopg.types.json import Jsonb -from psycopg_pool import ConnectionPool from langgraph.checkpoint.base import ( WRITES_IDX_MAP, @@ -17,84 +17,93 @@ CheckpointTuple, get_checkpoint_id, ) -from langgraph.checkpoint.postgres.base import ( - BasePostgresSaver, +from langgraph.checkpoint.mysql.base import ( + BaseMySQLSaver, ) from langgraph.checkpoint.serde.base import SerializerProtocol -Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]] + +class ConnectionPool(Protocol): + """Protocol that a MySQL connection pool should implement.""" + + def get_connection(self) -> pymysql.Connection: + """Gets a connection from the connection pool.""" + ... + + +Conn = Union[pymysql.Connection, ConnectionPool] @contextmanager -def _get_connection(conn: Conn) -> Iterator[Connection[DictRow]]: - if isinstance(conn, Connection): +def _get_connection(conn: Conn) -> Iterator[pymysql.Connection]: + if isinstance(conn, pymysql.Connection): yield conn - elif isinstance(conn, ConnectionPool): - with conn.connection() as conn: + elif hasattr(conn, "get_connection"): + with conn.get_connection() as conn: yield conn else: raise TypeError(f"Invalid connection type: {type(conn)}") -class PostgresSaver(BasePostgresSaver): +class PyMySQLSaver(BaseMySQLSaver): lock: threading.Lock def __init__( self, conn: Conn, - pipe: Optional[Pipeline] = None, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) - if isinstance(conn, ConnectionPool) and pipe is not None: - raise ValueError( - "Pipeline should be used only with a single Connection, not ConnectionPool." 
- ) self.conn = conn - self.pipe = pipe self.lock = threading.Lock() @classmethod @contextmanager def from_conn_string( - cls, conn_string: str, *, pipeline: bool = False - ) -> Iterator["PostgresSaver"]: - """Create a new PostgresSaver instance from a connection string. + cls, + conn_string: str, + ) -> Iterator["PyMySQLSaver"]: + """Create a new PyMySQLSaver instance from a connection string. Args: - conn_string (str): The Postgres connection info string. - pipeline (bool): whether to use Pipeline + conn_string (str): The MySQL connection info string. Returns: - PostgresSaver: A new PostgresSaver instance. + PyMySQLSaver: A new PyMySQLSaver instance. """ - with Connection.connect( - conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row + parsed = urllib.parse.urlparse(conn_string) + + with pymysql.connect( + host=parsed.hostname, + user=parsed.username, + password=parsed.password or "", + database=parsed.path[1:], + port=parsed.port or 3306, + autocommit=True, ) as conn: - if pipeline: - with conn.pipeline() as pipe: - yield PostgresSaver(conn, pipe) - else: - yield PostgresSaver(conn) + yield PyMySQLSaver(conn) def setup(self) -> None: """Set up the checkpoint database asynchronously. - This method creates the necessary tables in the Postgres database if they don't + This method creates the necessary tables in the MySQL database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time checkpointer is used. """ with self._cursor() as cur: try: - row = cur.execute( + cur.execute( "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" - ).fetchone() + ) + row = cur.fetchone() if row is None: version = -1 else: version = row["v"] - except UndefinedTable: + except pymysql.ProgrammingError as e: + if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE: + raise version = -1 for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), @@ -102,8 +111,6 @@ def setup(self) -> None: ): cur.execute(migration) cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") - if self.pipe: - self.pipe.sync() def list( self, @@ -115,7 +122,7 @@ def list( ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. - This method retrieves a list of checkpoint tuples from the Postgres database based + This method retrieves a list of checkpoint tuples from the MySQL database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: @@ -128,9 +135,9 @@ def list( Iterator[CheckpointTuple]: An iterator of checkpoint tuples. Examples: - >>> from langgraph.checkpoint.postgres import PostgresSaver - >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" - >>> with PostgresSaver.from_conn_string(DB_URI) as memory: + >>> from langgraph.checkpoint.mysql import PyMySQLSaver + >>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql" + >>> with PyMySQLSaver.from_conn_string(DB_URI) as memory: ... # Run a graph, then list the checkpoints >>> config = {"configurable": {"thread_id": "1"}} >>> checkpoints = list(memory.list(config, limit=2)) @@ -139,7 +146,7 @@ def list( >>> config = {"configurable": {"thread_id": "1"}} >>> before = {"configurable": {"checkpoint_id": "1ef4f797-8335-6428-8001-8a1503f9b875"}} - >>> with PostgresSaver.from_conn_string(DB_URI) as memory: + >>> with PyMySQLSaver.from_conn_string(DB_URI) as memory: ... 
# Run a graph, then list the checkpoints >>> checkpoints = list(memory.list(config, before=before)) >>> print(checkpoints) @@ -151,7 +158,7 @@ def list( query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor with self._cursor() as cur: - cur.execute(query, args, binary=True) + cur.execute(query, args) for value in cur: yield CheckpointTuple( { @@ -162,7 +169,7 @@ def list( } }, self._load_checkpoint( - value["checkpoint"], + json.loads(value["checkpoint"]), value["channel_values"], value["pending_sends"], ), @@ -182,7 +189,7 @@ def list( def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. - This method retrieves a checkpoint tuple from the Postgres database based on the + This method retrieves a checkpoint tuple from the MySQL database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. @@ -228,7 +235,6 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: cur.execute( self.SELECT_SQL + where, args, - binary=True, ) for value in cur: @@ -241,7 +247,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: } }, self._load_checkpoint( - value["checkpoint"], + json.loads(value["checkpoint"]), value["channel_values"], value["pending_sends"], ), @@ -267,7 +273,7 @@ def put( ) -> RunnableConfig: """Save a checkpoint to the database. - This method saves a checkpoint to the Postgres database. The checkpoint is associated + This method saves a checkpoint to the MySQL database. The checkpoint is associated with the provided config and its parent config (if any). Args: @@ -281,9 +287,9 @@ def put( Examples: - >>> from langgraph.checkpoint.postgres import PostgresSaver - >>> DB_URI = "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable" - >>> with PostgresSaver.from_conn_string(DB_URI) as memory: + >>> from langgraph.checkpoint.mysql import PyMySQLSaver + >>> DB_URI = "mysql://mysql:mysql@localhost:5432/mysql" + >>> with PyMySQLSaver.from_conn_string(DB_URI) as memory: >>> config = {"configurable": {"thread_id": "1", "checkpoint_ns": ""}} >>> checkpoint = {"ts": "2024-05-04T06:32:42.235444+00:00", "id": "1ef4f797-8335-6428-8001-8a1503f9b875", "data": {"key": "value"}} >>> saved_config = memory.put(config, checkpoint, {"source": "input", "step": 1, "writes": {"key": "value"}}, {}) @@ -306,7 +312,7 @@ def put( } } - with self._cursor(pipeline=True) as cur: + with self._cursor() as cur: cur.executemany( self.UPSERT_CHECKPOINT_BLOBS_SQL, self._dump_blobs( @@ -323,7 +329,7 @@ def put( checkpoint_ns, checkpoint["id"], checkpoint_id, - Jsonb(self._dump_checkpoint(copy)), + json.dumps(self._dump_checkpoint(copy)), self._dump_metadata(metadata), ), ) @@ -337,7 +343,7 @@ def put_writes( ) -> None: """Store intermediate writes linked to a checkpoint. - This method saves intermediate writes associated with a checkpoint to the Postgres database. + This method saves intermediate writes associated with a checkpoint to the MySQL database. Args: config (RunnableConfig): Configuration of the related checkpoint. 
@@ -349,7 +355,7 @@ def put_writes( if all(w[0] in WRITES_IDX_MAP for w in writes) else self.INSERT_CHECKPOINT_WRITES_SQL ) - with self._cursor(pipeline=True) as cur: + with self._cursor() as cur: cur.executemany( query, self._dump_writes( @@ -362,25 +368,7 @@ def put_writes( ) @contextmanager - def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]: + def _cursor(self) -> Iterator[pymysql.cursors.DictCursor]: with _get_connection(self.conn) as conn: - if self.pipe: - # a connection in pipeline mode can be used concurrently - # in multiple threads/coroutines, but only one cursor can be - # used at a time - try: - with conn.cursor(binary=True, row_factory=dict_row) as cur: - yield cur - finally: - if pipeline: - self.pipe.sync() - elif pipeline: - # a connection not in pipeline mode can only be used by one - # thread/coroutine at a time, so we acquire a lock - with self.lock, conn.pipeline(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: - yield cur - else: - with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur: - yield cur + with self.lock, conn.cursor(pymysql.cursors.DictCursor) as cur: + yield cur diff --git a/langgraph/checkpoint/postgres/aio.py b/langgraph/checkpoint/mysql/aio.py similarity index 79% rename from langgraph/checkpoint/postgres/aio.py rename to langgraph/checkpoint/mysql/aio.py index 59ee7cb..00c6fb8 100644 --- a/langgraph/checkpoint/postgres/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -1,13 +1,14 @@ import asyncio +import json +import urllib.parse from contextlib import asynccontextmanager from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union +import aiomysql # type: ignore +import pymysql +import pymysql.connections +import pymysql.constants.ER from langchain_core.runnables import RunnableConfig -from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline -from psycopg.errors import UndefinedTable -from psycopg.rows import DictRow, dict_row -from psycopg.types.json import Jsonb -from psycopg_pool import AsyncConnectionPool from langgraph.checkpoint.base import ( WRITES_IDX_MAP, @@ -17,42 +18,36 @@ CheckpointTuple, get_checkpoint_id, ) -from langgraph.checkpoint.postgres.base import BasePostgresSaver +from langgraph.checkpoint.mysql.base import BaseMySQLSaver from langgraph.checkpoint.serde.base import SerializerProtocol -Conn = Union[AsyncConnection[DictRow], AsyncConnectionPool[AsyncConnection[DictRow]]] +Conn = Union[aiomysql.Connection, aiomysql.Pool] @asynccontextmanager async def _get_connection( conn: Conn, -) -> AsyncIterator[AsyncConnection[DictRow]]: - if isinstance(conn, AsyncConnection): +) -> AsyncIterator[aiomysql.Connection]: + if isinstance(conn, aiomysql.Connection): yield conn - elif isinstance(conn, AsyncConnectionPool): - async with conn.connection() as conn: - yield conn + elif isinstance(conn, aiomysql.Pool): + async with conn.acquire() as _conn: + yield _conn else: raise TypeError(f"Invalid connection type: {type(conn)}") -class AsyncPostgresSaver(BasePostgresSaver): +class AIOMySQLSaver(BaseMySQLSaver): lock: asyncio.Lock def __init__( self, conn: Conn, - pipe: Optional[AsyncPipeline] = None, serde: Optional[SerializerProtocol] = None, ) -> None: super().__init__(serde=serde) - if isinstance(conn, AsyncConnectionPool) and pipe is not None: - raise ValueError( - "Pipeline should be used only with a single AsyncConnection, not AsyncConnectionPool." 
- ) self.conn = conn - self.pipe = pipe self.lock = asyncio.Lock() self.loop = asyncio.get_running_loop() @@ -62,45 +57,52 @@ async def from_conn_string( cls, conn_string: str, *, - pipeline: bool = False, serde: Optional[SerializerProtocol] = None, - ) -> AsyncIterator["AsyncPostgresSaver"]: - """Create a new PostgresSaver instance from a connection string. + ) -> AsyncIterator["AIOMySQLSaver"]: + """Create a new AIOMySQLSaver instance from a connection string. Args: - conn_string (str): The Postgres connection info string. - pipeline (bool): whether to use AsyncPipeline + conn_string (str): The MySQL connection info string. Returns: - AsyncPostgresSaver: A new AsyncPostgresSaver instance. + AIOMySQLSaver: A new AIOMySQLSaver instance. """ - async with await AsyncConnection.connect( - conn_string, autocommit=True, prepare_threshold=0, row_factory=dict_row + parsed = urllib.parse.urlparse(conn_string) + + async with aiomysql.connect( + host=parsed.hostname or "localhost", + user=parsed.username, + password=parsed.password or "", + db=parsed.path[1:], + port=parsed.port or 3306, + autocommit=True, ) as conn: - if pipeline: - async with conn.pipeline() as pipe: - yield AsyncPostgresSaver(conn=conn, pipe=pipe, serde=serde) - else: - yield AsyncPostgresSaver(conn=conn, serde=serde) + # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 + # is merged into aiomysql. + await conn.set_charset(pymysql.connections.DEFAULT_CHARSET) + + yield AIOMySQLSaver(conn=conn, serde=serde) async def setup(self) -> None: """Set up the checkpoint database asynchronously. - This method creates the necessary tables in the Postgres database if they don't + This method creates the necessary tables in the MySQL database if they don't already exist and runs database migrations. It MUST be called directly by the user the first time checkpointer is used. """ async with self._cursor() as cur: try: - results = await cur.execute( + await cur.execute( "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" ) - row = await results.fetchone() + row = await cur.fetchone() if row is None: version = -1 else: version = row["v"] - except UndefinedTable: + except pymysql.ProgrammingError as e: + if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE: + raise version = -1 for v, migration in zip( range(version + 1, len(self.MIGRATIONS)), @@ -108,8 +110,6 @@ async def setup(self) -> None: ): await cur.execute(migration) await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") - if self.pipe: - await self.pipe.sync() async def alist( self, @@ -121,7 +121,7 @@ async def alist( ) -> AsyncIterator[CheckpointTuple]: """List checkpoints from the database asynchronously. - This method retrieves a list of checkpoint tuples from the Postgres database based + This method retrieves a list of checkpoint tuples from the MySQL database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). 
Args: @@ -139,7 +139,7 @@ async def alist( query += f" LIMIT {limit}" # if we change this to use .stream() we need to make sure to close the cursor async with self._cursor() as cur: - await cur.execute(query, args, binary=True) + await cur.execute(query, args) async for value in cur: yield CheckpointTuple( { @@ -151,7 +151,7 @@ async def alist( }, await asyncio.to_thread( self._load_checkpoint, - value["checkpoint"], + json.loads(value["checkpoint"]), value["channel_values"], value["pending_sends"], ), @@ -171,7 +171,7 @@ async def alist( async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database asynchronously. - This method retrieves a checkpoint tuple from the Postgres database based on the + This method retrieves a checkpoint tuple from the MySQL database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. @@ -196,7 +196,6 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: await cur.execute( self.SELECT_SQL + where, args, - binary=True, ) async for value in cur: @@ -210,7 +209,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: }, await asyncio.to_thread( self._load_checkpoint, - value["checkpoint"], + json.loads(value["checkpoint"]), value["channel_values"], value["pending_sends"], ), @@ -236,7 +235,7 @@ async def aput( ) -> RunnableConfig: """Save a checkpoint to the database asynchronously. - This method saves a checkpoint to the Postgres database. The checkpoint is associated + This method saves a checkpoint to the MySQL database. The checkpoint is associated with the provided config and its parent config (if any). Args: @@ -264,7 +263,7 @@ async def aput( } } - async with self._cursor(pipeline=True) as cur: + async with self._cursor() as cur: await cur.executemany( self.UPSERT_CHECKPOINT_BLOBS_SQL, await asyncio.to_thread( @@ -282,7 +281,7 @@ async def aput( checkpoint_ns, checkpoint["id"], checkpoint_id, - Jsonb(self._dump_checkpoint(copy)), + json.dumps(self._dump_checkpoint(copy)), self._dump_metadata(metadata), ), ) @@ -316,36 +315,14 @@ async def aput_writes( task_id, writes, ) - async with self._cursor(pipeline=True) as cur: + async with self._cursor() as cur: await cur.executemany(query, params) @asynccontextmanager - async def _cursor( - self, *, pipeline: bool = False - ) -> AsyncIterator[AsyncCursor[DictRow]]: + async def _cursor(self) -> AsyncIterator[aiomysql.DictCursor]: async with _get_connection(self.conn) as conn: - if self.pipe: - # a connection in pipeline mode can be used concurrently - # in multiple threads/coroutines, but only one cursor can be - # used at a time - try: - async with conn.cursor(binary=True, row_factory=dict_row) as cur: - yield cur - finally: - if pipeline: - await self.pipe.sync() - elif pipeline: - # a connection not in pipeline mode can only be used by one - # thread/coroutine at a time, so we acquire a lock - async with self.lock, conn.pipeline(), conn.cursor( - binary=True, row_factory=dict_row - ) as cur: - yield cur - else: - async with self.lock, conn.cursor( - binary=True, row_factory=dict_row - ) as cur: - yield cur + async with self.lock, conn.cursor(aiomysql.DictCursor) as cur: + yield cur def list( self, @@ -357,7 +334,7 @@ def list( ) -> Iterator[CheckpointTuple]: """List checkpoints from the database. 
- This method retrieves a list of checkpoint tuples from the Postgres database based + This method retrieves a list of checkpoint tuples from the MySQL database based on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). Args: @@ -382,7 +359,7 @@ def list( def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: """Get a checkpoint tuple from the database. - This method retrieves a checkpoint tuple from the Postgres database based on the + This method retrieves a checkpoint tuple from the MySQL database based on the provided config. If the config contains a "checkpoint_id" key, the checkpoint with the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint for the given thread ID is retrieved. @@ -398,7 +375,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: # we don't check in other methods to avoid the overhead if asyncio.get_running_loop() is self.loop: raise asyncio.InvalidStateError( - "Synchronous calls to AsyncPostgresSaver are only allowed from a " + "Synchronous calls to AIOMySQLSaver are only allowed from a " "different thread. From the main thread, use the async interface." "For example, use `await checkpointer.aget_tuple(...)` or `await " "graph.ainvoke(...)`." @@ -418,7 +395,7 @@ def put( ) -> RunnableConfig: """Save a checkpoint to the database. - This method saves a checkpoint to the Postgres database. The checkpoint is associated + This method saves a checkpoint to the MySQL database. The checkpoint is associated with the provided config and its parent config (if any). Args: diff --git a/langgraph/checkpoint/postgres/base.py b/langgraph/checkpoint/mysql/base.py similarity index 76% rename from langgraph/checkpoint/postgres/base.py rename to langgraph/checkpoint/mysql/base.py index 76232e3..32ec95a 100644 --- a/langgraph/checkpoint/postgres/base.py +++ b/langgraph/checkpoint/mysql/base.py @@ -1,8 +1,8 @@ +import json import random from typing import Any, List, Optional, Sequence, Tuple, cast from langchain_core.runnables import RunnableConfig -from psycopg.types.json import Jsonb from langgraph.checkpoint.base import ( WRITES_IDX_MAP, @@ -26,36 +26,36 @@ v INTEGER PRIMARY KEY );""", """CREATE TABLE IF NOT EXISTS checkpoints ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - parent_checkpoint_id TEXT, - type TEXT, - checkpoint JSONB NOT NULL, - metadata JSONB NOT NULL DEFAULT '{}', + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(150) NOT NULL DEFAULT '', + checkpoint_id VARCHAR(150) NOT NULL, + parent_checkpoint_id VARCHAR(150), + type VARCHAR(150), + checkpoint JSON NOT NULL, + metadata JSON NOT NULL DEFAULT ('{}'), PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id) );""", """CREATE TABLE IF NOT EXISTS checkpoint_blobs ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - channel TEXT NOT NULL, - version TEXT NOT NULL, - type TEXT NOT NULL, - blob BYTEA, + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(150) NOT NULL DEFAULT '', + channel VARCHAR(150) NOT NULL, + version VARCHAR(150) NOT NULL, + type VARCHAR(150) NOT NULL, + `blob` LONGBLOB, PRIMARY KEY (thread_id, checkpoint_ns, channel, version) );""", """CREATE TABLE IF NOT EXISTS checkpoint_writes ( - thread_id TEXT NOT NULL, - checkpoint_ns TEXT NOT NULL DEFAULT '', - checkpoint_id TEXT NOT NULL, - task_id TEXT NOT NULL, + thread_id VARCHAR(150) NOT NULL, + checkpoint_ns VARCHAR(150) NOT NULL DEFAULT 
'', + checkpoint_id VARCHAR(150) NOT NULL, + task_id VARCHAR(150) NOT NULL, idx INTEGER NOT NULL, - channel TEXT NOT NULL, - type TEXT, - blob BYTEA NOT NULL, + channel VARCHAR(150) NOT NULL, + type VARCHAR(150), + `blob` LONGBLOB NOT NULL, PRIMARY KEY (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) );""", - "ALTER TABLE checkpoint_blobs ALTER COLUMN blob DROP not null;", + "ALTER TABLE checkpoint_blobs MODIFY COLUMN `blob` LONGBLOB;", ] SELECT_SQL = f""" @@ -67,24 +67,30 @@ parent_checkpoint_id, metadata, ( - select array_agg(array[bl.channel::bytea, bl.type::bytea, bl.blob]) - from jsonb_each_text(checkpoint -> 'channel_versions') + select json_arrayagg(json_array(bl.channel, bl.type, bl.blob)) + from json_table( + checkpoint, + '$.channel_versions[*]' columns ( + `key` VARCHAR(255) PATH '$.key', + value VARCHAR(255) PATH '$.value' + ) + ) as channel_versions inner join checkpoint_blobs bl on bl.thread_id = checkpoints.thread_id and bl.checkpoint_ns = checkpoints.checkpoint_ns - and bl.channel = jsonb_each_text.key - and bl.version = jsonb_each_text.value + and bl.channel = channel_versions.key + and bl.version = channel_versions.value ) as channel_values, ( select - array_agg(array[cw.task_id::text::bytea, cw.channel::bytea, cw.type::bytea, cw.blob] order by cw.task_id, cw.idx) + json_arrayagg(json_array(cw.task_id, cw.channel, cw.type, cw.blob)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns and cw.checkpoint_id = checkpoints.checkpoint_id ) as pending_writes, ( - select array_agg(array[cw.type::bytea, cw.blob] order by cw.idx) + select json_arrayagg(json_array(cw.type, cw.blob)) from checkpoint_writes cw where cw.thread_id = checkpoints.thread_id and cw.checkpoint_ns = checkpoints.checkpoint_ns @@ -94,37 +100,34 @@ from checkpoints """ UPSERT_CHECKPOINT_BLOBS_SQL = """ - INSERT INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, blob) + INSERT IGNORE INTO checkpoint_blobs (thread_id, checkpoint_ns, channel, version, type, `blob`) VALUES (%s, %s, %s, %s, %s, %s) - ON CONFLICT (thread_id, checkpoint_ns, channel, version) DO NOTHING """ UPSERT_CHECKPOINTS_SQL = """ INSERT INTO checkpoints (thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, checkpoint, metadata) - VALUES (%s, %s, %s, %s, %s, %s) - ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id) - DO UPDATE SET - checkpoint = EXCLUDED.checkpoint, - metadata = EXCLUDED.metadata; + VALUES (%s, %s, %s, %s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + checkpoint = new.checkpoint, + metadata = new.metadata; """ UPSERT_CHECKPOINT_WRITES_SQL = """ - INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s) - ON CONFLICT (thread_id, checkpoint_ns, checkpoint_id, task_id, idx) DO UPDATE SET - channel = EXCLUDED.channel, - type = EXCLUDED.type, - blob = EXCLUDED.blob; + INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, `blob`) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) AS new + ON DUPLICATE KEY UPDATE + channel = new.channel, + type = new.type, + `blob` = new.blob; """ INSERT_CHECKPOINT_WRITES_SQL = """ - INSERT INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, blob) + INSERT IGNORE INTO checkpoint_writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, `blob`) VALUES (%s, %s, %s, %s, %s, %s, %s, %s) - ON CONFLICT (thread_id, 
checkpoint_ns, checkpoint_id, task_id, idx) DO NOTHING """ -class BasePostgresSaver(BaseCheckpointSaver[str]): +class BaseMySQLSaver(BaseCheckpointSaver[str]): SELECT_SQL = SELECT_SQL MIGRATIONS = MIGRATIONS UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL @@ -224,8 +227,8 @@ def _dump_writes( for idx, (channel, value) in enumerate(writes) ] - def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata: - return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata)) + def _load_metadata(self, metadata: str) -> CheckpointMetadata: + return self.jsonplus_serde.loads(metadata.encode()) def _dump_metadata(self, metadata: CheckpointMetadata) -> str: serialized_metadata = self.jsonplus_serde.dumps(metadata) @@ -273,8 +276,8 @@ def _search_where( # construct predicate for metadata filter if filter: - wheres.append("metadata @> %s ") - param_values.append(Jsonb(filter)) + wheres.append("json_contains(metadata, %s) ") + param_values.append(json.dumps(filter)) # construct predicate for `before` if before is not None: diff --git a/poetry.lock b/poetry.lock index 473bd42..76f3925 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,23 @@ # This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +[[package]] +name = "aiomysql" +version = "0.2.0" +description = "MySQL driver for asyncio." +optional = false +python-versions = ">=3.7" +files = [ + {file = "aiomysql-0.2.0-py3-none-any.whl", hash = "sha256:b7c26da0daf23a5ec5e0b133c03d20657276e4eae9b73e040b72787f6f6ade0a"}, + {file = "aiomysql-0.2.0.tar.gz", hash = "sha256:558b9c26d580d08b8c5fd1be23c5231ce3aeff2dadad989540fee740253deb67"}, +] + +[package.dependencies] +PyMySQL = ">=1.0" + +[package.extras] +rsa = ["PyMySQL[rsa] (>=1.0)"] +sa = ["sqlalchemy (>=1.3,<1.4)"] + [[package]] name = "annotated-types" version = "0.7.0" @@ -327,17 +345,15 @@ name = "langgraph-checkpoint" version = "1.0.11" description = "Library with base interfaces for LangGraph checkpoint savers." 
optional = false -python-versions = "^3.9.0,<4.0" -files = [] -develop = true +python-versions = "<4.0.0,>=3.9.0" +files = [ + {file = "langgraph_checkpoint-1.0.11-py3-none-any.whl", hash = "sha256:9644bd61e3ab5b03fc0422aa5e625061ad14aa2012d046bf4bb306451da95371"}, + {file = "langgraph_checkpoint-1.0.11.tar.gz", hash = "sha256:156af1666272a0be3cda4a2c4ffe6b2e2f5af8ead7d450d345cbb39828ce4b05"}, +] [package.dependencies] langchain-core = ">=0.2.38,<0.4" -msgpack = "^1.1.0" - -[package.source] -type = "directory" -url = "../checkpoint" +msgpack = ">=1.1.0,<2.0.0" [[package]] name = "langsmith" @@ -578,106 +594,6 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] -[[package]] -name = "psycopg" -version = "3.2.1" -description = "PostgreSQL database adapter for Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "psycopg-3.2.1-py3-none-any.whl", hash = "sha256:ece385fb413a37db332f97c49208b36cf030ff02b199d7635ed2fbd378724175"}, - {file = "psycopg-3.2.1.tar.gz", hash = "sha256:dc8da6dc8729dacacda3cc2f17d2c9397a70a66cf0d2b69c91065d60d5f00cb7"}, -] - -[package.dependencies] -psycopg-binary = {version = "3.2.1", optional = true, markers = "implementation_name != \"pypy\" and extra == \"binary\""} -typing-extensions = ">=4.4" -tzdata = {version = "*", markers = "sys_platform == \"win32\""} - -[package.extras] -binary = ["psycopg-binary (==3.2.1)"] -c = ["psycopg-c (==3.2.1)"] -dev = ["ast-comments (>=1.1.2)", "black (>=24.1.0)", "codespell (>=2.2)", "dnspython (>=2.1)", "flake8 (>=4.0)", "mypy (>=1.6)", "types-setuptools (>=57.4)", "wheel (>=0.37)"] -docs = ["Sphinx (>=5.0)", "furo (==2022.6.21)", "sphinx-autobuild (>=2021.3.14)", "sphinx-autodoc-typehints (>=1.12)"] -pool = ["psycopg-pool"] -test = ["anyio (>=4.0)", "mypy (>=1.6)", "pproxy (>=2.7)", "pytest (>=6.2.5)", "pytest-cov (>=3.0)", "pytest-randomly (>=3.5)"] - -[[package]] -name = "psycopg-binary" -version = "3.2.1" -description = "PostgreSQL database adapter for Python -- C optimisation distribution" -optional = false -python-versions = ">=3.8" -files = [ - {file = "psycopg_binary-3.2.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:cad2de17804c4cfee8640ae2b279d616bb9e4734ac3c17c13db5e40982bd710d"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:592b27d6c46a40f9eeaaeea7c1fef6f3c60b02c634365eb649b2d880669f149f"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a997efbaadb5e1a294fb5760e2f5643d7b8e4e3fe6cb6f09e6d605fd28e0291"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1d2b6438fb83376f43ebb798bf0ad5e57bc56c03c9c29c85bc15405c8c0ac5a"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b1f087bd84bdcac78bf9f024ebdbfacd07fc0a23ec8191448a50679e2ac4a19e"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:415c3b72ea32119163255c6504085f374e47ae7345f14bc3f0ef1f6e0976a879"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:f092114f10f81fb6bae544a0ec027eb720e2d9c74a4fcdaa9dd3899873136935"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:06a7aae34edfe179ddc04da005e083ff6c6b0020000399a2cbf0a7121a8a22ea"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = 
"sha256:0b018631e5c80ce9bc210b71ea885932f9cca6db131e4df505653d7e3873a938"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f8a509aeaac364fa965454e80cd110fe6d48ba2c80f56c9b8563423f0b5c3cfd"}, - {file = "psycopg_binary-3.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:413977d18412ff83486eeb5875eb00b185a9391c57febac45b8993bf9c0ff489"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:62b1b7b07e00ee490afb39c0a47d8282a9c2822c7cfed9553a04b0058adf7e7f"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:f8afb07114ea9b924a4a0305ceb15354ccf0ef3c0e14d54b8dbeb03e50182dd7"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40bb515d042f6a345714ec0403df68ccf13f73b05e567837d80c886c7c9d3805"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6418712ba63cebb0c88c050b3997185b0ef54173b36568522d5634ac06153040"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:101472468d59c74bb8565fab603e032803fd533d16be4b2d13da1bab8deb32a3"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa3931f308ab4a479d0ee22dc04bea867a6365cac0172e5ddcba359da043854b"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dc314a47d44fe1a8069b075a64abffad347a3a1d8652fed1bab5d3baea37acb2"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:cc304a46be1e291031148d9d95c12451ffe783ff0cc72f18e2cc7ec43cdb8c68"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6f9e13600647087df5928875559f0eb8f496f53e6278b7da9511b4b3d0aff960"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b140182830c76c74d17eba27df3755a46442ce8d4fb299e7f1cf2f74a87c877b"}, - {file = "psycopg_binary-3.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:3c838806eeb99af39f934b7999e35f947a8e577997cc892c12b5053a97a9057f"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7066d3dca196ed0dc6172f9777b2d62e4f138705886be656cccff2d555234d60"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:28ada5f610468c57d8a4a055a8ea915d0085a43d794266c4f3b9d02f4288f4db"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e8213bf50af073b1aa8dc3cff123bfeedac86332a16c1b7274910bc88a847c7"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74d623261655a169bc84a9669890975c229f2fa6e19a7f2d10a77675dcf1a707"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:42781ba94e8842ee98bca5a7d0c44cc9d067500fedca2d6a90fa3609b6d16b42"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33e6669091d09f8ba36e10ce678a6d9916e110446236a9b92346464a3565635e"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b09e8a576a2ac69d695032ee76f31e03b30781828b5dd6d18c6a009e5a3d1c35"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8f28ff0cb9f1defdc4a6f8c958bf6787274247e7dfeca811f6e2f56602695fb1"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = 
"sha256:4c84fcac8a3a3479ac14673095cc4e1fdba2935499f72c436785ac679bec0d1a"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:950fd666ec9e9fe6a8eeb2b5a8f17301790e518953730ad44d715b59ffdbc67f"}, - {file = "psycopg_binary-3.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:334046a937bb086c36e2c6889fe327f9f29bfc085d678f70fac0b0618949f674"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:1d6833f607f3fc7b22226a9e121235d3b84c0eda1d3caab174673ef698f63788"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d353e028b8f848b9784450fc2abf149d53a738d451eab3ee4c85703438128b9"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f34e369891f77d0738e5d25727c307d06d5344948771e5379ea29c76c6d84555"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0ab58213cc976a1666f66bc1cb2e602315cd753b7981a8e17237ac2a185bd4a1"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0104a72a17aa84b3b7dcab6c84826c595355bf54bb6ea6d284dcb06d99c6801"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:059cbd4e6da2337e17707178fe49464ed01de867dc86c677b30751755ec1dc51"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:73f9c9b984be9c322b5ec1515b12df1ee5896029f5e72d46160eb6517438659c"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:af0469c00f24c4bec18c3d2ede124bf62688d88d1b8a5f3c3edc2f61046fe0d7"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:463d55345f73ff391df8177a185ad57b552915ad33f5cc2b31b930500c068b22"}, - {file = "psycopg_binary-3.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:302b86f92c0d76e99fe1b5c22c492ae519ce8b98b88d37ef74fda4c9e24c6b46"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-macosx_12_0_x86_64.whl", hash = "sha256:0879b5d76b7d48678d31278242aaf951bc2d69ca4e4d7cef117e4bbf7bfefda9"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f99e59f8a5f4dcd9cbdec445f3d8ac950a492fc0e211032384d6992ed3c17eb7"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84837e99353d16c6980603b362d0f03302d4b06c71672a6651f38df8a482923d"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ce965caf618061817f66c0906f0452aef966c293ae0933d4fa5a16ea6eaf5bb"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78c2007caf3c90f08685c5378e3ceb142bafd5636be7495f7d86ec8a977eaeef"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:7a84b5eb194a258116154b2a4ff2962ea60ea52de089508db23a51d3d6b1c7d1"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4a42b8f9ab39affcd5249b45cac763ac3cf12df962b67e23fd15a2ee2932afe5"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:788ffc43d7517c13e624c83e0e553b7b8823c9655e18296566d36a829bfb373f"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:21927f41c4d722ae8eb30d62a6ce732c398eac230509af5ba1749a337f8a63e2"}, - {file = "psycopg_binary-3.2.1-cp39-cp39-win_amd64.whl", hash = 
"sha256:921f0c7f39590763d64a619de84d1b142587acc70fd11cbb5ba8fa39786f3073"}, -] - -[[package]] -name = "psycopg-pool" -version = "3.2.2" -description = "Connection Pool for Psycopg" -optional = false -python-versions = ">=3.8" -files = [ - {file = "psycopg_pool-3.2.2-py3-none-any.whl", hash = "sha256:273081d0fbfaced4f35e69200c89cb8fbddfe277c38cc86c235b90a2ec2c8153"}, - {file = "psycopg_pool-3.2.2.tar.gz", hash = "sha256:9e22c370045f6d7f2666a5ad1b0caf345f9f1912195b0b25d0d3bcc4f3a7389c"}, -] - -[package.dependencies] -typing-extensions = ">=4.4" - [[package]] name = "pydantic" version = "2.8.2" @@ -801,6 +717,21 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pymysql" +version = "1.1.1" +description = "Pure Python MySQL Driver" +optional = false +python-versions = ">=3.7" +files = [ + {file = "PyMySQL-1.1.1-py3-none-any.whl", hash = "sha256:4de15da4c61dc132f4fb9ab763063e693d521a80fd0e87943b9a453dd4c19d6c"}, + {file = "pymysql-1.1.1.tar.gz", hash = "sha256:e127611aaf2b417403c60bf4dc570124aeb4a57f5f37b8e95ae399a42f904cd0"}, +] + +[package.extras] +ed25519 = ["PyNaCl (>=1.4.0)"] +rsa = ["cryptography"] + [[package]] name = "pytest" version = "7.4.4" @@ -1020,25 +951,25 @@ files = [ ] [[package]] -name = "typing-extensions" -version = "4.12.2" -description = "Backported and Experimental Type Hints for Python 3.8+" +name = "types-pymysql" +version = "1.1.0.20240524" +description = "Typing stubs for PyMySQL" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, - {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, + {file = "types-PyMySQL-1.1.0.20240524.tar.gz", hash = "sha256:93058fef2077c407e29bdcd1a7dfbbf06a59324a5440df30dd002f572199ac17"}, + {file = "types_PyMySQL-1.1.0.20240524-py3-none-any.whl", hash = "sha256:8be5be228bf6376f9055ec03bec0dfa6f1a84163f9a89305db446f0b31f87be3"}, ] [[package]] -name = "tzdata" -version = "2024.1" -description = "Provider of IANA time zone data" +name = "typing-extensions" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=2" +python-versions = ">=3.8" files = [ - {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"}, - {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -1102,7 +1033,11 @@ files = [ [package.extras] watchmedo = ["PyYAML (>=3.10)"] +[extras] +aiomysql = ["aiomysql"] +pymysql = ["pymysql"] + [metadata] lock-version = "2.0" python-versions = "^3.9.0,<4.0" -content-hash = "4b174f142964238b985de1a2209680bc83e9a310e6797be14dc26f61fe1352e2" +content-hash = "33ab0df9936643f7d819eacef378f4fa9e1cd91f0640c877e45f97b8cda0098e" diff --git a/pyproject.toml b/pyproject.toml index 0385fdd..0224010 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,22 @@ [tool.poetry] -name = "langgraph-checkpoint-postgres" -version = "1.0.8" -description = "Library with a Postgres implementation of LangGraph 
checkpoint saver."
-authors = []
+name = "langgraph-checkpoint-mysql"
+version = "1.0.0"
+description = "Library with a MySQL implementation of LangGraph checkpoint saver."
+authors = ["Theodore Ni "]
 license = "MIT"
 readme = "README.md"
-repository = "https://www.github.com/langchain-ai/langgraph"
+repository = "https://www.github.com/tjni/langgraph-checkpoint-mysql"
 packages = [{ include = "langgraph" }]
 
 [tool.poetry.dependencies]
 python = "^3.9.0,<4.0"
 langgraph-checkpoint = "^1.0.11"
-orjson = ">=3.10.1"
-psycopg = "^3.0.0"
-psycopg-pool = "^3.0.0"
+pymysql = { version = "^1.1.1", optional = true }
+aiomysql = { version = "^0.2.0", optional = true }
+
+[tool.poetry.extras]
+pymysql = ["pymysql"]
+aiomysql = ["aiomysql"]
 
 [tool.poetry.group.dev.dependencies]
 ruff = "^0.6.2"
@@ -24,8 +27,9 @@ pytest-asyncio = "^0.21.1"
 pytest-mock = "^3.11.1"
 pytest-watch = "^4.2.0"
 mypy = "^1.10.0"
-psycopg = {extras = ["binary"], version = ">=3.0.0"}
-langgraph-checkpoint = {path = "../checkpoint", develop = true}
+pymysql = "^1.1.1"
+aiomysql = "^0.2.0"
+types-PyMySQL = "^1.1.0"
 
 [tool.pytest.ini_options]
 # --strict-markers will raise errors on unknown marks.
diff --git a/tests/compose-mysql.yml b/tests/compose-mysql.yml
new file mode 100644
index 0000000..6fd6340
--- /dev/null
+++ b/tests/compose-mysql.yml
@@ -0,0 +1,17 @@
+services:
+  mysql-test:
+    image: mysql:8
+    ports:
+      - "5441:3306"
+    environment:
+      MYSQL_ROOT_PASSWORD: mysql
+      MYSQL_DATABASE: mysql
+      MYSQL_USER: mysql
+      MYSQL_PASSWORD: mysql
+    healthcheck:
+      test: mysqladmin -h 127.0.0.1 ping -P 3306 -u mysql -pmysql | grep "mysqld is alive"
+      start_period: 10s
+      timeout: 1s
+      retries: 5
+      interval: 60s
+      start_interval: 1s
diff --git a/tests/compose-postgres.yml b/tests/compose-postgres.yml
deleted file mode 100644
index 42d8c37..0000000
--- a/tests/compose-postgres.yml
+++ /dev/null
@@ -1,16 +0,0 @@
-services:
-  postgres-test:
-    image: postgres:16
-    ports:
-      - "5441:5432"
-    environment:
-      POSTGRES_DB: postgres
-      POSTGRES_USER: postgres
-      POSTGRES_PASSWORD: postgres
-    healthcheck:
-      test: pg_isready -U postgres
-      start_period: 10s
-      timeout: 1s
-      retries: 5
-      interval: 60s
-      start_interval: 1s
diff --git a/tests/conftest.py b/tests/conftest.py
index 3633664..b51416c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,27 +1,35 @@
+import urllib.parse
 from typing import AsyncIterator
 
+import aiomysql  # type: ignore
+import pymysql
+import pymysql.constants.ER
 import pytest
-from psycopg import AsyncConnection
-from psycopg.errors import UndefinedTable
-from psycopg.rows import DictRow, dict_row
 
-DEFAULT_URI = "postgres://postgres:postgres@localhost:5441/postgres?sslmode=disable"
+DEFAULT_URI = "mysql://mysql:mysql@localhost:5441/mysql"
 
 
 @pytest.fixture(scope="function")
-async def conn() -> AsyncIterator[AsyncConnection[DictRow]]:
-    async with await AsyncConnection.connect(
-        DEFAULT_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
+async def conn() -> AsyncIterator[aiomysql.Connection]:
+    parsed = urllib.parse.urlparse(DEFAULT_URI)
+    async with await aiomysql.connect(
+        user=parsed.username,
+        password=parsed.password or "",
+        db=parsed.path[1:],
+        port=parsed.port or 3306,
+        autocommit=True,
     ) as conn:
         yield conn
 
 
 @pytest.fixture(scope="function", autouse=True)
-async def clear_test_db(conn: AsyncConnection[DictRow]) -> None:
+async def clear_test_db(conn: aiomysql.Connection) -> None:
     """Delete all tables before each test."""
     try:
-        await conn.execute("DELETE FROM checkpoints")
-        await conn.execute("DELETE FROM checkpoint_blobs")
-        await conn.execute("DELETE FROM checkpoint_writes")
-    except UndefinedTable:
-        pass
+        async with conn.cursor() as cursor:
+            await cursor.execute("DELETE FROM checkpoints")
+            await cursor.execute("DELETE FROM checkpoint_blobs")
+            await cursor.execute("DELETE FROM checkpoint_writes")
+    except pymysql.ProgrammingError as e:
+        if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE:
+            raise
diff --git a/tests/test_async.py b/tests/test_async.py
index 6b3e359..beecd11 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -1,19 +1,19 @@
 from typing import Any
 
 import pytest
+from conftest import DEFAULT_URI  # type: ignore
 from langchain_core.runnables import RunnableConfig
+
 from langgraph.checkpoint.base import (
     Checkpoint,
     CheckpointMetadata,
     create_checkpoint,
     empty_checkpoint,
 )
-from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
-
-from conftest import DEFAULT_URI  # type: ignore
+from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
 
 
-class TestAsyncPostgresSaver:
+class TestAIOMySQLSaver:
     @pytest.fixture(autouse=True)
     async def setup(self) -> None:
         # objects for test setup
@@ -57,11 +57,11 @@ async def setup(self) -> None:
             "score": None,
         }
         self.metadata_3: CheckpointMetadata = {}
-        async with AsyncPostgresSaver.from_conn_string(DEFAULT_URI) as saver:
+        async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
             await saver.setup()
 
     async def test_asearch(self) -> None:
-        async with AsyncPostgresSaver.from_conn_string(DEFAULT_URI) as saver:
+        async with AIOMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
             await saver.aput(self.config_1, self.chkpnt_1, self.metadata_1, {})
             await saver.aput(self.config_2, self.chkpnt_2, self.metadata_2, {})
             await saver.aput(self.config_3, self.chkpnt_3, self.metadata_3, {})
diff --git a/tests/test_sync.py b/tests/test_sync.py
index 56b0117..5429793 100644
--- a/tests/test_sync.py
+++ b/tests/test_sync.py
@@ -1,19 +1,19 @@
 from typing import Any
 
 import pytest
+from conftest import DEFAULT_URI  # type: ignore
 from langchain_core.runnables import RunnableConfig
+
 from langgraph.checkpoint.base import (
     Checkpoint,
     CheckpointMetadata,
     create_checkpoint,
     empty_checkpoint,
 )
-from langgraph.checkpoint.postgres import PostgresSaver
-
-from conftest import DEFAULT_URI  # type: ignore
+from langgraph.checkpoint.mysql import PyMySQLSaver
 
 
-class TestPostgresSaver:
+class TestPyMySQLSaver:
     @pytest.fixture(autouse=True)
     def setup(self) -> None:
         # objects for test setup
@@ -57,11 +57,11 @@ def setup(self) -> None:
             "score": None,
         }
         self.metadata_3: CheckpointMetadata = {}
-        with PostgresSaver.from_conn_string(DEFAULT_URI) as saver:
+        with PyMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
             saver.setup()
 
     def test_search(self) -> None:
-        with PostgresSaver.from_conn_string(DEFAULT_URI) as saver:
+        with PyMySQLSaver.from_conn_string(DEFAULT_URI) as saver:
            # save checkpoints
             saver.put(self.config_1, self.chkpnt_1, self.metadata_1, {})
             saver.put(self.config_2, self.chkpnt_2, self.metadata_2, {})