diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 40b5c02..ac0a3a5 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,2 @@ libraries/dagster-delta/ @ion-elgreco -libraries/dagster-delta-polars/ @ion-elgreco @sverbruggen .github/ @ion-elgreco \ No newline at end of file diff --git a/.github/validate-release-version.py b/.github/validate-release-version.py new file mode 100644 index 0000000..6a7469c --- /dev/null +++ b/.github/validate-release-version.py @@ -0,0 +1,38 @@ +"""Ensures that the files in `dist/` are prefixed with ${{ github.ref_name }} + +Tag must adhere to naming convention of distributed files. For example, the tag +`dagster_delta-0.1.2` must match the prefix of the files in the `dist/` folder: + + -rw-r--r--@ 2.0K Oct 23 14:06 dagster_delta-0.1.2-py3-none-any.whl + -rw-r--r--@ 1.6K Oct 23 14:06 dagster_delta-0.1.2.tar.gz + +USAGE + + $ python .github/validate-release-version.py libraries/dagster-delta/dist dagster_delta-0.1.3 + +""" + +import sys +import os + + +if len(sys.argv) != 3: + print("Requires positional arguments: ") + sys.exit(1) + +dist_path = sys.argv[1] +github_ref_name = sys.argv[2] + +if not os.path.exists(dist_path): + print("Release directory `dist/` must exist") + sys.exit(1) + +for filename in os.listdir(dist_path): + if filename.startswith("."): + continue + if not filename.startswith(github_ref_name): + print(f"{filename} does not start with prefix {github_ref_name}") + sys.exit(1) + + +print(f"Success: all files in `dist/` are prefixed with {github_ref_name}") diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml deleted file mode 100644 index b7c1b6b..0000000 --- a/.github/workflows/CI.yml +++ /dev/null @@ -1,147 +0,0 @@ -name: dagster-delta-[polars] CI/CD - -on: - push: - branches: - - main - - master - tags: - - "*" - pull_request: - workflow_dispatch: - -permissions: - contents: read - -jobs: - lint: - name: Linting checks - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.9" - - name: Install pypa/build - run: pip install uv && make .venv && VIRTUAL_ENV=./.venv - - name: CI-check - run: source .venv/bin/activate && make ci-check - - tests: - name: Tests - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.9" - - name: Install pypa/build - run: pip install uv && make .venv && VIRTUAL_ENV=./.venv - - name: Execute tests - run: source .venv/bin/activate && pytest . 
- - build: - name: Build wheels - runs-on: ubuntu-latest - needs: - - lint - - tests - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.9" - - name: Install pypa/build - run: pip install uv && make .venv && VIRTUAL_ENV=./.venv - - name: Build a binary wheel - run: source .venv/bin/activate && python3 -m build -w "libraries/dagster-delta" && python3 -m build -w "libraries/dagster-delta-polars" && python3 -m build -w "libraries/dagster-unity-catalog-polars" - - name: Store the distribution packages - uses: actions/upload-artifact@v3 - with: - name: python-package-distributions - path: libraries/**/dist/ - - publish-dagster-delta: - name: >- - Publish Python 🐍 distribution πŸ“¦ to PyPI - if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: - - build - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/dagster-delta - - steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: my_dists/ - - name: list files - run: ls my_dists/dagster-delta/dist/ - - name: Publish distribution πŸ“¦ to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: my_dists/dagster-delta/dist/ - password: ${{ secrets.PYPI_API_TOKEN_DD }} - verbose: true - - - publish-dagster-delta-polars: - name: >- - Publish Python 🐍 distribution πŸ“¦ to PyPI - if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: - - build - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/dagster-delta-polars - - steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: my_dists/ - - name: list files - run: ls my_dists/dagster-delta-polars/dist/ - - name: Publish distribution πŸ“¦ to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: my_dists/dagster-delta-polars/dist/ - password: ${{ secrets.PYPI_API_TOKEN_DDP }} - verbose: true - - publish-dagster-unity-catalog-polars: - name: >- - Publish Python 🐍 distribution πŸ“¦ to PyPI - if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes - needs: - - build - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/dagster-unity-catalog-polars - - steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: my_dists/ - - name: list files - run: ls my_dists/dagster-unity-catalog-polars/dist/ - - name: Publish distribution πŸ“¦ to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - packages-dir: my_dists/dagster-unity-catalog-polars/dist/ - password: ${{ secrets.PYPI_API_TOKEN_DDUC }} - verbose: true - diff --git a/.github/workflows/quality-dagster-delta.yml b/.github/workflows/quality-dagster-delta.yml new file mode 100644 index 0000000..05827d9 --- /dev/null +++ b/.github/workflows/quality-dagster-delta.yml @@ -0,0 +1,12 @@ +name: quality-check-dagster-delta +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'libraries/dagster-delta/**' + +jobs: + check: + uses: ./.github/workflows/template-quality-check.yml + with: + working_directory: ./libraries/dagster-delta \ No newline at end of file diff --git a/.github/workflows/quality-dagster-unity-catalog-polars.yml b/.github/workflows/quality-dagster-unity-catalog-polars.yml new file mode 100644 index 0000000..e7b8563 --- /dev/null +++ 
b/.github/workflows/quality-dagster-unity-catalog-polars.yml @@ -0,0 +1,12 @@ +name: quality-check-dagster-unity-catalog-polars +on: + pull_request: + types: [opened, synchronize, reopened] + paths: + - 'libraries/dagster-unity-catalog-polars/**' + +jobs: + check: + uses: ./.github/workflows/template-quality-check.yml + with: + working_directory: ./libraries/dagster-unity-catalog-polars \ No newline at end of file diff --git a/.github/workflows/release-dagster-delta-polars.yml b/.github/workflows/release-dagster-delta-polars.yml new file mode 100644 index 0000000..deb4829 --- /dev/null +++ b/.github/workflows/release-dagster-delta-polars.yml @@ -0,0 +1,15 @@ +name: build-and-release-dagster-delta-polars + +on: + push: + tags: + - 'dagster_delta_polars-*.*.*' + +jobs: + build-and-release-dagster-delta-polars: + uses: ./.github/workflows/template-release.yml + with: + library_name: dagster-delta-polars + working_directory: ./libraries/dagster-delta-polars + secrets: + pypi_token: ${{ secrets.PYPI_API_TOKEN_DDP }} \ No newline at end of file diff --git a/.github/workflows/release-dagster-delta.yml b/.github/workflows/release-dagster-delta.yml new file mode 100644 index 0000000..1f6bc09 --- /dev/null +++ b/.github/workflows/release-dagster-delta.yml @@ -0,0 +1,15 @@ +name: build-and-release-dagster-delta + +on: + push: + tags: + - 'dagster_delta-*.*.*' + +jobs: + build-and-release-dagster-delta: + uses: ./.github/workflows/template-release.yml + with: + library_name: dagster-delta + working_directory: ./libraries/dagster-delta + secrets: + pypi_token: ${{ secrets.PYPI_API_TOKEN_DD }} \ No newline at end of file diff --git a/.github/workflows/release-dagster-unity-catalog-polars.yml b/.github/workflows/release-dagster-unity-catalog-polars.yml new file mode 100644 index 0000000..62737c0 --- /dev/null +++ b/.github/workflows/release-dagster-unity-catalog-polars.yml @@ -0,0 +1,15 @@ +name: build-and-release-dagster-unity-catalog-polars + +on: + push: + tags: + - 'dagster_unity_catalog_polars-*.*.*' + +jobs: + build-and-release-dagster-unity-catalog-polars: + uses: ./.github/workflows/template-release.yml + with: + library_name: dagster-unity-catalog-polars + working_directory: ./libraries/dagster-unity-catalog-polars + secrets: + pypi_token: ${{ secrets.PYPI_API_TOKEN_DDUC }} \ No newline at end of file diff --git a/.github/workflows/template-quality-check.yml b/.github/workflows/template-quality-check.yml new file mode 100644 index 0000000..354a1f5 --- /dev/null +++ b/.github/workflows/template-quality-check.yml @@ -0,0 +1,43 @@ +name: quality-check + +on: + workflow_call: + inputs: + working_directory: + required: true + type: string + +jobs: + check: + runs-on: ubuntu-latest + steps: + + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v3 + + - name: Install python + working-directory: ${{ inputs.working_directory }} + run: uv python install 3.12 + + - name: Sync dependencies + working-directory: ${{ inputs.working_directory }} + run: uv sync + + - name: Ruff (lint) + working-directory: ${{ inputs.working_directory }} + run: uv run ruff check + + - name: Ruff (formatting) + working-directory: ${{ inputs.working_directory }} + run: uv run ruff format --check . 
+ + - name: Pyright + working-directory: ${{ inputs.working_directory }} + run: uv run pyright + + - name: Pytest + working-directory: ${{ inputs.working_directory }} + run: uv run pytest \ No newline at end of file diff --git a/.github/workflows/template-release.yml b/.github/workflows/template-release.yml new file mode 100644 index 0000000..892d4b2 --- /dev/null +++ b/.github/workflows/template-release.yml @@ -0,0 +1,50 @@ +# References +# +# https://docs.astral.sh/uv/guides/integration/github/ +# https://docs.astral.sh/uv/guides/publish/#preparing-your-project-for-packaging +# https://docs.pypi.org/trusted-publishers/adding-a-publisher/ +# + +name: build-and-release + +on: + workflow_call: + inputs: + library_name: + required: true + type: string + working_directory: + required: true + type: string + secrets: + pypi_token: + required: true +jobs: + build: + name: python + runs-on: ubuntu-latest + environment: production + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v3 + + - name: Install Python + working-directory: ${{ inputs.working_directory }} + run: uv python install + + - name: Build + working-directory: ${{ inputs.working_directory }} + run: uv build + + - name: Validate release version + run: python .github/validate-release-version.py ${{ inputs.working_directory }}/dist ${{ github.ref_name }} + + - name: Publish + working-directory: ${{ inputs.working_directory }} + run: uv publish + env: + UV_PUBLISH_TOKEN: ${{ secrets.pypi_token }} \ No newline at end of file diff --git a/Makefile b/Makefile index 481dd49..c5319c3 100644 --- a/Makefile +++ b/Makefile @@ -18,35 +18,15 @@ else RESET := "" endif -.venv: ## Set up virtual environment and install requirements - uv venv - $(MAKE) requirements -.PHONY: requirements -requirements: .venv ## Install/refresh all project requirements - uv pip install -r requirements.txt \ - -e libraries/dagster-delta \ - -e libraries/dagster-delta-polars \ - -e libraries/dagster-unity-catalog-polars \ - --config-settings editable_mode=compat - .PHONY: pre-commit -pre-commit: .venv ## Run autoformatting and linting +pre-commit: @echo "${GREEN}Formatting with ruff...${RESET}" - $(VENV_BIN)/ruff format . + uv run ruff format . @echo "${GREEN}Linting with ruff...${RESET}" - $(VENV_BIN)/ruff check . + uv run ruff check . @echo "${GREEN}Running static type checks...${RESET}" - $(VENV_BIN)/pyright . - -.PHONY: ci-check -ci-check: .venv ## Checks autoformatting and linting - @echo "${GREEN}Checking formatting with ruff...${RESET}" - $(VENV_BIN)/ruff format --check . - @echo "${GREEN}Linting with ruff...${RESET}" - $(VENV_BIN)/ruff check . - @echo "${GREEN}Running static type checks...${RESET}" - $(VENV_BIN)/pyright . + uv run pyright . .PHONY: clean clean: ## Remove environment and the caches diff --git a/README.md b/README.md index a904e78..2c78dc7 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,217 @@ # dagster-delta -Dagster deltalake implementation for Polars with a LakeFS IO manager, forked from dagster-deltalake with customizations +Dagster deltalake implementation for Pyarrow & Polars. Originally forked from dagster-deltalake with customizations. + +The IO Managers support partition mapping, custom write modes, special metadata configuration for advanced use cases. 
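As a quick orientation, here is a minimal wiring sketch for the PyArrow IO manager. It is an illustrative example rather than code from this repo, and it assumes the PyArrow manager accepts the same `root_uri`/`storage_options`/`mode` arguments that the test fixtures pass to the Polars manager (`DeltaLakePyarrowIOManager`, `LocalConfig`, and `WriteMode` are all exported from `dagster_delta`):

```python
import dagster as dg
from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig, WriteMode

# Sketch only: `root_uri` is a placeholder path; assets returning pyarrow
# Tables would be written to Delta tables under this root by the IO manager.
defs = dg.Definitions(
    assets=[],  # add your assets here
    resources={
        "io_manager": DeltaLakePyarrowIOManager(
            root_uri="path/to/deltalake",   # parent folder for all tables
            storage_options=LocalConfig(),  # local filesystem storage config
            mode=WriteMode.overwrite,       # default write mode
        ),
    },
)
```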
+ +The supported write modes: + +- **error** +- **append** +- **overwrite** +- **ignore** +- **merge** +- **create_or_replace** + +## Merge + +dagster-delta supports MERGE execution with a couple of pre-defined MERGE types (dagster_delta.config.MergeType): + +- **deduplicate_insert** <- Deduplicates on write +- **update_only** <- updates only the matched records +- **upsert** <- updates existing matches and inserts non-matched records +- **replace_and_delete_unmatched** <- updates existing matches and deletes unmatched records + +Example: +```python +import polars as pl +from dagster import Definitions, asset +from dagster_delta import DeltaLakePolarsIOManager, WriteMode, MergeConfig, MergeType + +@asset( + key_prefix=["my_schema"] # will be used as the schema (parent folder) in Delta Lake +) +def my_table() -> pl.DataFrame: # the name of the asset will be the table name + ... + +defs = Definitions( + assets=[my_table], + resources={"io_manager": DeltaLakePolarsIOManager( + root_uri="s3://bucket", + mode=WriteMode.merge, # or just "merge" + merge_config=MergeConfig( + merge_type=MergeType.upsert, + predicate="s.a = t.a", + source_alias="s", + target_alias="t", + ) + )} +) +``` + +## Special metadata configurations + +### **Add** additional `table_configuration` +Specify additional table configurations for `configuration` in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"table_configuration": { + "delta.enableChangeDataFeed": "true" + }}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the write `mode` +Override the write `mode` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"mode": "append"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `custom_metadata` +Override the `custom_metadata` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"custom_metadata": {"owner":"John Doe"}}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the write `schema_mode` +Override the `schema_mode` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"schema_mode": "merge"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `writer_properties` +Override the `writer_properties` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"writer_properties": { + "compression": "SNAPPY", + }}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `merge_predicate` +Override the `merge_predicate` to be used with `merge` execution. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"merge_predicate": "s.foo = t.foo AND s.bar = t.bar"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `schema` +Override the `schema` where the table will be saved. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"schema": "custom_db_schema"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Set** the `columns` that need to be read +Override the `columns` to load only these columns. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + ins = { + "upstream_asset": dg.AssetIn(metadata={"columns":["foo","bar"]}) + } +) +def my_asset(upstream_asset) -> pl.DataFrame: + ...
+ +``` + +### **Override** table name using `root_name` + +Instead of using the asset_name for the table name it's possible to set a custom table name using the `root_name` in the asset defintion metadata. + +This is useful where you have two or multiple assets who have the same table structure, but each asset is a subset of the full table partition_definition, and it wasn't possible to combine this into a single asset due to requiring different underlying Op logic and/or upstream assets: + +```python +import polars as pl +import dagster as dg + +@dg.asset( + io_manager_key = "deltalake_io_manager", + partitions_def=dg.StaticPartitionsDefinition(["a", "b"]), + metadata={ + "partition_expr": "foo", + "root_name": "asset_partitioned", + }, +) +def asset_partitioned_1(upstream_1: pl.DataFrame, upstream_2: pl.DataFrame) -> pl.DataFrame: + ... + +@dg.asset( + partitions_def=dg.StaticPartitionsDefinition(["c", "d"]), + metadata={ + "partition_expr": "foo", + "root_name": "asset_partitioned", + }, +) +def asset_partitioned_2(upstream_3: pl.DataFrame, upstream_4: pl.DataFrame) -> pl.DataFrame: + ... + +``` + +Effectively this would be the flow: + +``` + + {static_partition_def: [a,b]} +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚upstream 1 β”œβ”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ write to storage on partition (a,b) +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” └─► asset_partitioned_1 β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚upstream 2 β”œβ”€β”€β”€β–Ί β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ partitions β”‚ + β”‚ asset_partitioned: β”‚ + β”‚ [a,b,c,d] β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–²β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +β”‚upstream 3 β”œβ”€β”€β”β”‚ β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β–Ί asset_partitioned_2 β”‚ β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β–Ί β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +β”‚upstream 4 β”œβ”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ write to storage on partition (c,d) +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + {static_partition_def: [c,d]} + +``` \ No newline at end of file diff --git a/libraries/dagster-delta-polars/README.md b/libraries/dagster-delta-polars/README.md index abd350c..8a814ab 100644 --- a/libraries/dagster-delta-polars/README.md +++ b/libraries/dagster-delta-polars/README.md @@ -1,4 +1,5 @@ # dagster-delta-polars -A fork of the `dagster-deltalake-polars` library with additional improvements. The docs for `dagster-delta-polars` can be found -[here](...). +!! DEPRECATED !! + +Polars integration has moved into `dagster-delta`. Moving forward use `dagster-delta[polars]`. 
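To make the migration concrete, a small hedged sketch; the replacement import path is the one given by the deprecation warning added to `dagster_delta_polars/__init__.py` below:

```python
# Before (deprecated package):
# from dagster_delta_polars import DeltaLakePolarsIOManager

# After installing `dagster-delta[polars]`:
from dagster_delta.io_manager import DeltaLakePolarsIOManager
```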
\ No newline at end of file diff --git a/libraries/dagster-delta-polars/dagster_delta_polars/__init__.py b/libraries/dagster-delta-polars/dagster_delta_polars/__init__.py index bcff267..213b648 100644 --- a/libraries/dagster-delta-polars/dagster_delta_polars/__init__.py +++ b/libraries/dagster-delta-polars/dagster_delta_polars/__init__.py @@ -1,3 +1,5 @@ +from warnings import warn + from .deltalake_polars_type_handler import ( DeltaLakePolarsIOManager, DeltaLakePolarsTypeHandler, @@ -7,3 +9,15 @@ "DeltaLakePolarsIOManager", "DeltaLakePolarsTypeHandler", ] + + +warn( + """This library has been deprecated. Polars integration has moved into `dagster-delta`. + This can be installed through `pip install dagster-delta[polars]` + + from dagster_delta.io_manager import DeltaLakePolarsIOManager + + """, + DeprecationWarning, + stacklevel=2, +) diff --git a/libraries/dagster-delta-polars/dagster_delta_polars/lakefs_io/deltalake_polars_lakefs_type_handler.py b/libraries/dagster-delta-polars/dagster_delta_polars/lakefs_io/deltalake_polars_lakefs_type_handler.py index d29482b..b7fa9e9 100644 --- a/libraries/dagster-delta-polars/dagster_delta_polars/lakefs_io/deltalake_polars_lakefs_type_handler.py +++ b/libraries/dagster-delta-polars/dagster_delta_polars/lakefs_io/deltalake_polars_lakefs_type_handler.py @@ -99,7 +99,7 @@ def handle_output( logger = logging.getLogger() logger.setLevel("DEBUG") - step_branch_name = f"{self.source_branch_name}-step-jobid-{context.run_id}-asset-{context.asset_key.to_user_string().replace('/','-').replace('_','-')}"[ + step_branch_name = f"{self.source_branch_name}-step-jobid-{context.run_id}-asset-{context.asset_key.to_user_string().replace('/', '-').replace('_', '-')}"[ 0:256 ] self.repository.branch(step_branch_name).create(source_reference=self.source_branch_name) @@ -139,6 +139,7 @@ def handle_output( # Since we don't care in the logging it got branched out there. 
metadata = {**context.consume_logged_metadata()} metadata["table_uri"] = MetadataValue.path(connection.table_uri) + # metadata["dagster/uri"] = MetadataValue.path(connection.table_uri) # noqa: ERA001 if self.lakefs_base_url is not None: metadata["lakefs_link"] = MetadataValue.url( _convert_s3_uri_to_lakefs_link(connection.table_uri, self.lakefs_base_url), diff --git a/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler.py b/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler.py index b59aa05..306a8d6 100644 --- a/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler.py +++ b/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler.py @@ -21,8 +21,7 @@ op, ) from dagster._check import CheckError -from dagster_delta import DELTA_DATE_FORMAT, LocalConfig -from dagster_delta.io_manager import WriteMode +from dagster_delta import DELTA_DATE_FORMAT, LocalConfig, WriteMode from deltalake import DeltaTable from dagster_delta_polars import DeltaLakePolarsIOManager @@ -30,7 +29,7 @@ warnings.filterwarnings("ignore", category=ExperimentalWarning) -@pytest.fixture() +@pytest.fixture def io_manager(tmp_path) -> DeltaLakePolarsIOManager: return DeltaLakePolarsIOManager( root_uri=str(tmp_path), diff --git a/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler_save_modes.py b/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler_save_modes.py index 980649c..fca20df 100644 --- a/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler_save_modes.py +++ b/libraries/dagster-delta-polars/dagster_delta_polars_tests/test_type_handler_save_modes.py @@ -7,14 +7,13 @@ graph, op, ) -from dagster_delta import LocalConfig -from dagster_delta.io_manager import WriteMode +from dagster_delta import LocalConfig, WriteMode from deltalake import DeltaTable from dagster_delta_polars import DeltaLakePolarsIOManager -@pytest.fixture() +@pytest.fixture def io_manager(tmp_path) -> DeltaLakePolarsIOManager: return DeltaLakePolarsIOManager( root_uri=str(tmp_path), @@ -23,7 +22,7 @@ def io_manager(tmp_path) -> DeltaLakePolarsIOManager: ) -@pytest.fixture() +@pytest.fixture def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager: return DeltaLakePolarsIOManager( root_uri=str(tmp_path), @@ -32,7 +31,7 @@ def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager: ) -@pytest.fixture() +@pytest.fixture def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager: return DeltaLakePolarsIOManager( root_uri=str(tmp_path), diff --git a/libraries/dagster-delta-polars/pyproject.toml b/libraries/dagster-delta-polars/pyproject.toml index 038efc8..a51effb 100644 --- a/libraries/dagster-delta-polars/pyproject.toml +++ b/libraries/dagster-delta-polars/pyproject.toml @@ -1,11 +1,12 @@ [project] name = "dagster-delta-polars" -version = "0.2.1" +version = "0.2.2" description = "Polars deltalake IO Managers for Dagster with optional LakeFS support" readme = "README.md" requires-python = ">=3.9" dependencies = [ - "dagster-delta", + "dagster-delta<=0.2.1", + "polars", "Deprecated", ] authors = [ @@ -113,6 +114,9 @@ lint.ignore = [ # Allow autofix for all enabled rules (when `--fix`) is provided. lint.fixable = ["ALL"] +# Allow unused variables when underscore-prefixed. +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + # Exclude a variety of commonly ignored directories. exclude = [ ".bzr", @@ -141,9 +145,6 @@ exclude = [ # Same as Black. 
line-length = 100 -# Allow unused variables when underscore-prefixed. -lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - # Assume Python 3.9 target-version = "py39" @@ -157,3 +158,12 @@ suppress-none-returning = true [tool.ruff.lint.pydocstyle] # Use Google-style docstrings. convention = "google" + +[dependency-groups] +dev = [ + "lakefs>=0.8.0", + "pyarrow<=18.0.0", + "pyright>=1.1.393", + "pytest>=8.3.4", + "ruff>=0.9.5", +] diff --git a/libraries/dagster-delta/README.md b/libraries/dagster-delta/README.md index 17315cd..2c78dc7 100644 --- a/libraries/dagster-delta/README.md +++ b/libraries/dagster-delta/README.md @@ -1,4 +1,217 @@ # dagster-delta +Dagster deltalake implementation for Pyarrow & Polars. Originally forked from dagster-deltalake with customizations. -A fork of the `dagster-deltalake` library with additional improvements. The docs for `dagster-delta` can be found -[here](...). +The IO Managers support partition mapping, custom write modes, and special metadata configuration for advanced use cases. + +The supported write modes: + +- **error** +- **append** +- **overwrite** +- **ignore** +- **merge** +- **create_or_replace** + +## Merge + +dagster-delta supports MERGE execution with a couple of pre-defined MERGE types (dagster_delta.config.MergeType): + +- **deduplicate_insert** <- Deduplicates on write +- **update_only** <- updates only the matched records +- **upsert** <- updates existing matches and inserts non-matched records +- **replace_and_delete_unmatched** <- updates existing matches and deletes unmatched records + +Example: +```python +import polars as pl +from dagster import Definitions, asset +from dagster_delta import DeltaLakePolarsIOManager, WriteMode, MergeConfig, MergeType + +@asset( + key_prefix=["my_schema"] # will be used as the schema (parent folder) in Delta Lake +) +def my_table() -> pl.DataFrame: # the name of the asset will be the table name + ... + +defs = Definitions( + assets=[my_table], + resources={"io_manager": DeltaLakePolarsIOManager( + root_uri="s3://bucket", + mode=WriteMode.merge, # or just "merge" + merge_config=MergeConfig( + merge_type=MergeType.upsert, + predicate="s.a = t.a", + source_alias="s", + target_alias="t", + ) + )} +) +``` + +## Special metadata configurations + +### **Add** additional `table_configuration` +Specify additional table configurations for `configuration` in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"table_configuration": { + "delta.enableChangeDataFeed": "true" + }}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the write `mode` +Override the write `mode` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"mode": "append"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `custom_metadata` +Override the `custom_metadata` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"custom_metadata": {"owner":"John Doe"}}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the write `schema_mode` +Override the `schema_mode` to be used in `write_deltalake`. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"schema_mode": "merge"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `writer_properties` +Override the `writer_properties` to be used in `write_deltalake`.
+ +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"writer_properties": { + "compression": "SNAPPY", + }}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `merge_predicate` +Override the `merge_predicate` to be used with `merge` execution. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"merge_predicate": "s.foo = t.foo AND s.bar = t.bar"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Overwrite** the `schema` +Override the `schema` where the table will be saved. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + metadata={"schema": "custom_db_schema"}, +) +def my_asset() -> pl.DataFrame: + ... + +``` + +### **Set** the `columns` that need to be read +Override the `columns` to load only these columns. + +```python +@dg.asset( + io_manager_key = "deltalake_io_manager", + ins = { + "upstream_asset": dg.AssetIn(metadata={"columns":["foo","bar"]}) + } +) +def my_asset(upstream_asset) -> pl.DataFrame: + ... + +``` + +### **Override** table name using `root_name` + +Instead of using the asset_name for the table name, it's possible to set a custom table name using the `root_name` in the asset definition metadata. + +This is useful when you have two or more assets that share the same table structure, but each asset is a subset of the full table partition_definition, and it wasn't possible to combine them into a single asset because they require different underlying Op logic and/or upstream assets: + +```python +import polars as pl +import dagster as dg + +@dg.asset( + io_manager_key = "deltalake_io_manager", + partitions_def=dg.StaticPartitionsDefinition(["a", "b"]), + metadata={ + "partition_expr": "foo", + "root_name": "asset_partitioned", + }, +) +def asset_partitioned_1(upstream_1: pl.DataFrame, upstream_2: pl.DataFrame) -> pl.DataFrame: + ... + +@dg.asset( + partitions_def=dg.StaticPartitionsDefinition(["c", "d"]), + metadata={ + "partition_expr": "foo", + "root_name": "asset_partitioned", + }, +) +def asset_partitioned_2(upstream_3: pl.DataFrame, upstream_4: pl.DataFrame) -> pl.DataFrame: + ...
+ +``` + +Effectively this would be the flow: + +``` + + {static_partition_def: [a,b]} +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚upstream 1 β”œβ”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ β”‚ β”‚ write to storage on partition (a,b) +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” └─► asset_partitioned_1 β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚upstream 2 β”œβ”€β”€β”€β–Ί β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + β”‚ partitions β”‚ + β”‚ asset_partitioned: β”‚ + β”‚ [a,b,c,d] β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–²β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +β”‚upstream 3 β”œβ”€β”€β”β”‚ β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β–Ί asset_partitioned_2 β”‚ β”‚ +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β–Ί β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +β”‚upstream 4 β”œβ”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ write to storage on partition (c,d) +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + {static_partition_def: [c,d]} + +``` \ No newline at end of file diff --git a/libraries/dagster-delta/dagster_delta/__init__.py b/libraries/dagster-delta/dagster_delta/__init__.py index 2ad55ab..d2dd47b 100644 --- a/libraries/dagster-delta/dagster_delta/__init__.py +++ b/libraries/dagster-delta/dagster_delta/__init__.py @@ -1,53 +1,47 @@ -from collections.abc import Sequence +from dagster_delta.config import ( + AzureConfig, + BackoffConfig, + ClientConfig, + GcsConfig, + LocalConfig, + MergeConfig, + MergeType, + S3Config, +) +from dagster_delta.io_manager.arrow import DeltaLakePyarrowIOManager +from dagster_delta.io_manager.base import ( + BaseDeltaLakeIOManager, + SchemaMode, + WriteMode, + WriterEngine, +) +from dagster_delta.resources import DeltaTableResource -from dagster._core.storage.db_io_manager import DbTypeHandler +__all__ = [ + "AzureConfig", + "ClientConfig", + "GcsConfig", + "S3Config", + "LocalConfig", + "BackoffConfig", + "MergeConfig", + "MergeType", + "WriteMode", + "WriterEngine", + "SchemaMode", + "DeltaTableResource", + "BaseDeltaLakeIOManager", + "DeltaLakePyarrowIOManager", +] -from .config import ( - AzureConfig as AzureConfig, -) -from .config import ( - ClientConfig as ClientConfig, -) -from .config import ( - GcsConfig as GcsConfig, -) -from .config import ( - LocalConfig as LocalConfig, -) -from .config import ( - MergeConfig as MergeConfig, -) -from .config import ( - MergeType as MergeType, -) -from .config import ( - S3Config as S3Config, -) -from .handler import ( - DeltalakeBaseArrowTypeHandler as DeltalakeBaseArrowTypeHandler, -) -from .handler import ( - DeltaLakePyArrowTypeHandler as DeltaLakePyArrowTypeHandler, -) -from .io_manager import ( - DELTA_DATE_FORMAT as DELTA_DATE_FORMAT, -) -from .io_manager import ( - DELTA_DATETIME_FORMAT as DELTA_DATETIME_FORMAT, -) -from .io_manager import ( - DeltaLakeIOManager as DeltaLakeIOManager, -) -from .io_manager import ( - WriteMode as WriteMode, -) -from .io_manager import ( - WriterEngine as WriterEngine, -) -from .resource import DeltaTableResource as DeltaTableResource +try: + from 
dagster_delta.io_manager.polars import DeltaLakePolarsIOManager # noqa + + __all__.extend(["DeltaLakePolarsIOManager"]) -class DeltaLakePyarrowIOManager(DeltaLakeIOManager): # noqa: D101 - @staticmethod - def type_handlers() -> Sequence[DbTypeHandler]: # noqa: D102 - return [DeltaLakePyArrowTypeHandler()] +except ImportError as e: + if "polars" in str(e): + pass + else: + raise e diff --git a/libraries/dagster-delta/dagster_delta/_db_io_manager/__init__.py b/libraries/dagster-delta/dagster_delta/_db_io_manager/__init__.py new file mode 100644 index 0000000..a510853 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_db_io_manager/__init__.py @@ -0,0 +1,5 @@ +from dagster_delta._db_io_manager.custom_db_io_manager import CustomDbIOManager + +__all__ = [ + "CustomDbIOManager", +] diff --git a/libraries/dagster-delta/dagster_delta/dbiomanager_fixed.py b/libraries/dagster-delta/dagster_delta/_db_io_manager/custom_db_io_manager.py similarity index 55% rename from libraries/dagster-delta/dagster_delta/dbiomanager_fixed.py rename to libraries/dagster-delta/dagster_delta/_db_io_manager/custom_db_io_manager.py index 5f1d5e1..e5bb9bb 100644 --- a/libraries/dagster-delta/dagster_delta/dbiomanager_fixed.py +++ b/libraries/dagster-delta/dagster_delta/_db_io_manager/custom_db_io_manager.py @@ -6,7 +6,6 @@ ) from dagster._core.definitions.multi_dimensional_partitions import ( - MultiPartitionKey, MultiPartitionsDefinition, ) from dagster._core.definitions.time_window_partitions import ( @@ -16,13 +15,38 @@ from dagster._core.execution.context.output import OutputContext from dagster._core.storage.db_io_manager import DbIOManager, TablePartitionDimension, TableSlice +from dagster_delta._db_io_manager.utils import ( + generate_multi_partitions_dimension, + generate_single_partition_dimension, +) + T = TypeVar("T") -class DbIOManagerFixed(DbIOManager): # noqa +class CustomDbIOManager(DbIOManager): + """Works exactly like the DbIOManager, but overrides the _get_table_slice method + to provide support for partition mapping. e.g. a mapping from partition A to partition B, + where A is partitioned on two dimensions and B is partitioned on only one dimension. + + Additionally, gives ability to override + the table name using `root_name` in the metadata. 
+ + Example: + ``` + @dg.asset( + partitions_def=dg.StaticPartitionsDefinition(["a", "b"]), + metadata={ + "partition_expr": "foo", + "root_name": "asset_partitioned", + }, + ) + def asset_partitioned_1(upstream_1, upstream_2): + ``` + """ + def _get_table_slice( self, - context: Union[OutputContext, InputContext], # noqa + context: Union[OutputContext, InputContext], output_context: OutputContext, ) -> TableSlice: output_context_definition_metadata = output_context.definition_metadata or {} @@ -33,6 +57,7 @@ def _get_table_slice( if context.has_asset_key: asset_key_path = context.asset_key.path + ## Override the if output_context_definition_metadata.get("root_name"): table = output_context_definition_metadata["root_name"] else: @@ -58,54 +83,27 @@ def _get_table_slice( ) if isinstance(context.asset_partitions_def, MultiPartitionsDefinition): - multi_partition_key_mappings = [ - cast(MultiPartitionKey, partition_key).keys_by_dimension - for partition_key in context.asset_partition_keys - ] - for part in context.asset_partitions_def.partitions_defs: - partitions = [] - for multi_partition_key_mapping in multi_partition_key_mappings: - partition_key = multi_partition_key_mapping[part.name] - if isinstance(part.partitions_def, TimeWindowPartitionsDefinition): - partitions.append( - part.partitions_def.time_window_for_partition_key( - partition_key, # type: ignore - ), - ) - else: - partitions.append(partition_key) - - partition_expr_str = cast(Mapping[str, str], partition_expr).get(part.name) - if partition_expr is None: - raise ValueError( - f"Asset '{context.asset_key}' has partition {part.name}, but the" - f" 'partition_expr' metadata does not contain a {part.name} entry," - " so we don't know what column to filter it on. Specify which" - " column of the database contains data for the" - f" {part.name} partition.", - ) - partition_dimensions.append( - TablePartitionDimension( - partition_expr=cast(str, partition_expr_str), - partitions=partitions, - ), - ) - elif isinstance(context.asset_partitions_def, TimeWindowPartitionsDefinition): - partition_dimensions.append( - TablePartitionDimension( - partition_expr=cast(str, partition_expr), - partitions=( - context.asset_partitions_time_window - if context.asset_partition_keys - else [] - ), + partition_dimensions.extend( + generate_multi_partitions_dimension( + asset_partition_keys=context.asset_partition_keys, + asset_partitions_def=context.asset_partitions_def, + partition_expr=cast(Mapping[str, str], partition_expr), + asset_key=context.asset_key, ), ) else: partition_dimensions.append( - TablePartitionDimension( + generate_single_partition_dimension( partition_expr=cast(str, partition_expr), - partitions=context.asset_partition_keys, + asset_partition_keys=context.asset_partition_keys, + asset_partitions_time_window=( + context.asset_partitions_time_window + if isinstance( + context.asset_partitions_def, + TimeWindowPartitionsDefinition, + ) + else None + ), ), ) else: diff --git a/libraries/dagster-delta/dagster_delta/_db_io_manager/utils.py b/libraries/dagster-delta/dagster_delta/_db_io_manager/utils.py new file mode 100644 index 0000000..7702d96 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_db_io_manager/utils.py @@ -0,0 +1,149 @@ +import datetime as dt +from typing import List, Mapping, Sequence, Union, cast # noqa + +import pendulum +from dagster import ( + AssetKey, + MultiPartitionKey, + MultiPartitionsDefinition, + TimeWindowPartitionsDefinition, +) +from dagster._core.definitions.time_window_partitions import 
TimeWindow +from dagster._core.storage.db_io_manager import TablePartitionDimension +from pendulum import instance as pdi + + +def generate_multi_partitions_dimension( + asset_partition_keys: Sequence[str], + asset_partitions_def: MultiPartitionsDefinition, + partition_expr: Mapping[str, str], + asset_key: AssetKey, +) -> list[TablePartitionDimension]: + """Generates multi partition dimensions.""" + partition_dimensions: list[TablePartitionDimension] = [] + multi_partition_key_mappings = [ + cast(MultiPartitionKey, partition_key).keys_by_dimension + for partition_key in asset_partition_keys + ] + for part in asset_partitions_def.partitions_defs: + partitions: list[Union[TimeWindow, str]] = [] + for multi_partition_key_mapping in multi_partition_key_mappings: + partition_key = multi_partition_key_mapping[part.name] + if isinstance(part.partitions_def, TimeWindowPartitionsDefinition): + partitions.append( + part.partitions_def.time_window_for_partition_key(partition_key), + ) + else: + partitions.append(partition_key) + + partition_expr_str = partition_expr.get(part.name) + if partition_expr_str is None: + raise ValueError( + f"Asset '{asset_key}' has partition {part.name}, but the" + f" 'partition_expr' metadata does not contain a {part.name} entry," + " so we don't know what column to filter it on. Specify which" + " column of the database contains data for the" + f" {part.name} partition.", + ) + partitions_: TimeWindow | Sequence[str] + if all(isinstance(partition, TimeWindow) for partition in partitions): + checker = MultiTimePartitionsChecker( + partitions=cast(list[TimeWindow], partitions), + ) + if not checker.is_consecutive(): + raise ValueError("Dates are not consecutive.") + partitions_ = TimeWindow( + start=checker.start, + end=checker.end, + ) + elif all(isinstance(partition, str) for partition in partitions): + partitions_ = list(set(cast(list[str], partitions))) + else: + raise ValueError("Unknown partition type") + partition_dimensions.append( + TablePartitionDimension( + partition_expr=cast(str, partition_expr_str), + partitions=partitions_, + ), + ) + return partition_dimensions + + +def generate_single_partition_dimension( + partition_expr: str, + asset_partition_keys: Sequence[str], + asset_partitions_time_window: TimeWindow | None, +) -> TablePartitionDimension: + """Given a single partition, generate a TablePartitionDimension object that can be used to create a TableSlice object. + + Args: + partition_expr (str): Partition expression for the asset partition + asset_partition_keys (Sequence[str]): Partition keys for the asset + asset_partitions_time_window (TimeWindow | None): TimeWindow object for the asset partition + + Returns: + TablePartitionDimension: TablePartitionDimension object + """ + partition_dimension: TablePartitionDimension + if isinstance(asset_partitions_time_window, TimeWindow): + partition_dimension = TablePartitionDimension( + partition_expr=partition_expr, + partitions=(asset_partitions_time_window if asset_partition_keys else []), + ) + else: + partition_dimension = TablePartitionDimension( + partition_expr=partition_expr, + partitions=asset_partition_keys, + ) + return partition_dimension + + +class MultiTimePartitionsChecker: + def __init__(self, partitions: list[TimeWindow]): + """Helper class that defines checks on a list of TimeWindow objects + most importantly, partitions should be consecutive. 
+ + Args: + partitions (list[TimeWindow]): List of TimeWindow objects + """ + self._partitions = partitions + + start_date = min([w.start for w in self._partitions]) + end_date = max([w.end for w in self._partitions]) + + if not isinstance(start_date, dt.datetime): + raise ValueError("Start date is not a datetime") + if not isinstance(end_date, dt.datetime): + raise ValueError("End date is not a datetime") + + self.start = start_date + self.end = end_date + + @property + def hourly_delta(self) -> int: + deltas = [date_diff(w.start, w.end).in_hours() for w in self._partitions] + if len(set(deltas)) != 1: + raise ValueError( + "TimeWindowPartitionsDefinition must have the same delta from start to end", + ) + return int(deltas[0]) + + def is_consecutive(self) -> bool: + """Checks whether the provided start dates of each partition timewindow is consecutive""" + return ( + len( + { + pdi(self.start).add(hours=self.hourly_delta * i) + for i in range(date_diff(self.start, self.end).in_days() + 1) + } + - {pdi(d.start) for d in self._partitions}, + ) + == 1 + ) + + +def date_diff(start: dt.datetime, end: dt.datetime) -> pendulum.Interval: + """Compute an interval between two dates""" + start_ = pendulum.instance(start) + end_ = pendulum.instance(end) + return end_ - start_ diff --git a/libraries/dagster-delta/dagster_delta/_handler/__init__.py b/libraries/dagster-delta/dagster_delta/_handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libraries/dagster-delta/dagster_delta/_handler/base.py b/libraries/dagster-delta/dagster_delta/_handler/base.py new file mode 100644 index 0000000..03800f4 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/base.py @@ -0,0 +1,281 @@ +import logging +from abc import abstractmethod +from typing import Any, Generic, Optional, TypeVar, Union, cast + +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.dataset as ds +from dagster import ( + InputContext, + MetadataValue, + OutputContext, + TableColumn, + TableSchema, +) +from dagster._core.storage.db_io_manager import DbTypeHandler, TableSlice +from deltalake import CommitProperties, DeltaTable, WriterProperties, write_deltalake +from deltalake.exceptions import TableNotFoundError +from deltalake.schema import Schema, _convert_pa_schema_to_delta +from deltalake.table import FilterLiteralType + +from dagster_delta._handler.merge import merge_execute +from dagster_delta._handler.utils import ( + create_predicate, + extract_date_format_from_partition_definition, + partition_dimensions_to_dnf, + read_table, +) +from dagster_delta.io_manager.base import ( + TableConnection, + _DeltaTableIOManagerResourceConfig, +) + +T = TypeVar("T") +ArrowTypes = Union[pa.Table, pa.RecordBatchReader, ds.Dataset] + + +class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]): # noqa: D101 + @abstractmethod + def from_arrow(self, obj: pa.RecordBatchReader, target_type: type) -> T: + """Abstract method to convert arrow to target type""" + pass + + @abstractmethod + def to_arrow(self, obj: T) -> tuple[ArrowTypes, dict[str, Any]]: + """Abstract method to convert type to arrow""" + pass + + @abstractmethod + def get_output_stats(self, obj: T) -> dict[str, MetadataValue]: + """Abstract method to return output stats""" + pass + + def handle_output( + self, + context: OutputContext, + table_slice: TableSlice, + obj: T, + connection: TableConnection, + ): + """Stores pyarrow types in Delta table.""" + logger = logging.getLogger() + logger.setLevel("DEBUG") + definition_metadata = 
context.definition_metadata or {} + merge_predicate_from_metadata = definition_metadata.get("merge_predicate") + additional_table_config = definition_metadata.get("table_configuration", {}) + if connection.table_config is not None: + table_config = additional_table_config | connection.table_config + else: + table_config = additional_table_config + resource_config = context.resource_config or {} + object_stats = self.get_output_stats(obj) + data, delta_params = self.to_arrow(obj=obj) + delta_schema = Schema.from_pyarrow(_convert_pa_schema_to_delta(data.schema)) + resource_config = cast(_DeltaTableIOManagerResourceConfig, context.resource_config) + engine = resource_config.get("writer_engine") + save_mode = definition_metadata.get("mode") + main_save_mode = resource_config.get("mode") + custom_metadata = definition_metadata.get("custom_metadata") or resource_config.get( + "custom_metadata", + ) + schema_mode = definition_metadata.get("schema_mode") or resource_config.get( + "schema_mode", + ) + writer_properties = resource_config.get("writer_properties") + writer_properties = ( + WriterProperties(**writer_properties) if writer_properties is not None else None # type: ignore + ) + + commit_properties = definition_metadata.get("commit_properties") or resource_config.get( + "commit_properties", + ) + commit_properties = ( + CommitProperties(**commit_properties) if commit_properties is not None else None # type: ignore + ) + merge_config = resource_config.get("merge_config") + + date_format = extract_date_format_from_partition_definition(context) + + if save_mode is not None: + logger.debug( + "IO manager mode overridden with the asset metadata mode, %s -> %s", + main_save_mode, + save_mode, + ) + main_save_mode = save_mode + logger.debug("Writing with mode: `%s`", main_save_mode) + + merge_stats = None + partition_filters = None + partition_columns = None + predicate = None + + if table_slice.partition_dimensions is not None: + partition_filters = partition_dimensions_to_dnf( + partition_dimensions=table_slice.partition_dimensions, + table_schema=delta_schema, + str_values=True, + date_format=date_format, + ) + if partition_filters is not None and engine == "rust": + ## Convert partition_filter to predicate + predicate = create_predicate(partition_filters) + partition_filters = None + else: + predicate = None + # TODO(): make robust and move to function + partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions] + + if main_save_mode not in ["merge", "create_or_replace"]: + if predicate is not None and engine == "rust": + logger.debug("Using explicit partition predicate: \n%s", predicate) + elif partition_filters is not None and engine == "pyarrow": + logger.debug("Using explicit partition_filter: \n%s", partition_filters) + write_deltalake( # type: ignore + table_or_uri=connection.table_uri, + data=data, + storage_options=connection.storage_options, + mode=main_save_mode, + partition_filters=partition_filters, + predicate=predicate, + partition_by=partition_columns, + engine=engine, + schema_mode=schema_mode, + configuration=table_config, + custom_metadata=custom_metadata, + writer_properties=writer_properties, + commit_properties=commit_properties, + **delta_params, + ) + elif main_save_mode == "create_or_replace": + DeltaTable.create( + table_uri=connection.table_uri, + schema=_convert_pa_schema_to_delta(data.schema), + mode="overwrite", + partition_by=partition_columns, + configuration=table_config, + storage_options=connection.storage_options, + 
custom_metadata=custom_metadata, + ) + else: + if merge_config is None: + raise ValueError( + "Merge Configuration should be provided when `mode = WriterMode.merge`", + ) + try: + dt = DeltaTable(connection.table_uri, storage_options=connection.storage_options) + except TableNotFoundError: + logger.debug("Creating a DeltaTable first before merging.") + dt = DeltaTable.create( + table_uri=connection.table_uri, + schema=_convert_pa_schema_to_delta(data.schema), + partition_by=partition_columns, + configuration=table_config, + storage_options=connection.storage_options, + custom_metadata=custom_metadata, + ) + merge_stats = merge_execute( + dt, + data, + merge_config, + writer_properties=writer_properties, + commit_properties=commit_properties, + custom_metadata=custom_metadata, + delta_params=delta_params, + merge_predicate_from_metadata=merge_predicate_from_metadata, + partition_filters=partition_filters, + ) + + dt = DeltaTable(connection.table_uri, storage_options=connection.storage_options) + try: + stats = _get_partition_stats(dt=dt, partition_filters=partition_filters) + except Exception as e: + context.log.warning(f"error while computing table stats: {e}") + stats = {} + + output_metadata = { + # "dagster/table_name": table_slice.table, + "table_uri": MetadataValue.path(connection.table_uri), + # "dagster/uri": MetadataValue.path(connection.table_uri), + "dagster/column_schema": MetadataValue.table_schema( + TableSchema( + columns=[ + TableColumn(name=name, type=str(dtype)) + for name, dtype in zip(data.schema.names, data.schema.types) + ], + ), + ), + "table_version": MetadataValue.int(dt.version()), + **stats, + **object_stats, + } + if merge_stats is not None: + output_metadata["num_output_rows"] = MetadataValue.int( + merge_stats.get("num_output_rows", 0), + ) + output_metadata["merge_stats"] = MetadataValue.json(merge_stats) + + context.add_output_metadata(output_metadata) + + def load_input( + self, + context: InputContext, + table_slice: TableSlice, + connection: TableConnection, + ) -> T: + """Loads the input as a pyarrow Table or RecordBatchReader.""" + parquet_read_options = None + if context.resource_config is not None: + parquet_read_options = context.resource_config.get("parquet_read_options", None) + parquet_read_options = ( + ds.ParquetReadOptions(**parquet_read_options) + if parquet_read_options is not None + else None + ) + + dataset = read_table(table_slice, connection, parquet_read_options=parquet_read_options) + + if context.dagster_type.typing_type == ds.Dataset: + if table_slice.columns is not None: + raise ValueError("Cannot select columns when loading as Dataset.") + return dataset + + scanner = dataset.scanner(columns=table_slice.columns) + return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type) + + +def _get_partition_stats( + dt: DeltaTable, + partition_filters: Optional[list[FilterLiteralType]] = None, +) -> dict[str, Any]: + """Gets the stats for a partition + + Args: + dt (DeltaTable): DeltaTable object + partition_filters (list[FilterLiteralType] | None, optional): filters to grabs stats with. Defaults to None. 
+ + Returns: + dict[str, MetadataValue]: Partition stats + """ + files = pa.array(dt.files(partition_filters=partition_filters)) + files_table = pa.Table.from_arrays([files], names=["path"]) + actions_table = pa.Table.from_batches([dt.get_add_actions(flatten=True)]) + actions_table = actions_table.select(["path", "size_bytes", "num_records"]) + table = files_table.join(actions_table, keys="path") + + stats: dict[str, Any] + + stats = { + "size_MB": MetadataValue.float( + pc.sum(table.column("size_bytes")).as_py() * 0.00000095367432, # type: ignore + ), + } + row_count = MetadataValue.int( + pc.sum(table.column("num_records")).as_py(), # type: ignore + ) + if partition_filters is not None: + stats["dagster/partition_row_count"] = row_count + else: + stats["dagster/row_count"] = row_count + + return stats diff --git a/libraries/dagster-delta/dagster_delta/_handler/merge.py b/libraries/dagster-delta/dagster_delta/_handler/merge.py new file mode 100644 index 0000000..6377d81 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/merge.py @@ -0,0 +1,68 @@ +import logging +from typing import Any, Optional, TypeVar, Union + +import pyarrow as pa +import pyarrow.dataset as ds +from deltalake import CommitProperties, DeltaTable, WriterProperties +from deltalake.table import FilterLiteralType + +from dagster_delta._handler.utils import create_predicate +from dagster_delta.config import MergeType + +T = TypeVar("T") +ArrowTypes = Union[pa.Table, pa.RecordBatchReader, ds.Dataset] + + +def merge_execute( + dt: DeltaTable, + data: Union[pa.RecordBatchReader, pa.Table], + merge_config: dict[str, Any], + writer_properties: Optional[WriterProperties], + commit_properties: Optional[CommitProperties], + custom_metadata: Optional[dict[str, str]], + delta_params: dict[str, Any], + merge_predicate_from_metadata: Optional[str], + partition_filters: Optional[list[FilterLiteralType]] = None, +) -> dict[str, Any]: + merge_type = merge_config.get("merge_type") + error_on_type_mismatch = merge_config.get("error_on_type_mismatch", True) + + if merge_predicate_from_metadata is not None: + predicate = str(merge_predicate_from_metadata) + elif merge_config.get("predicate") is not None: + predicate = str(merge_config.get("predicate")) + else: + raise Exception("merge predicate was not provided") + + target_alias = merge_config.get("target_alias") + + if partition_filters is not None: + partition_predicate = create_predicate(partition_filters, target_alias=target_alias) + + predicate = f"{predicate} AND {partition_predicate}" + logger = logging.getLogger() + logger.setLevel("DEBUG") + logger.debug("Using explicit MERGE partition predicate: \n%s", predicate) + + merger = dt.merge( + source=data, + predicate=predicate, + source_alias=merge_config.get("source_alias"), + target_alias=target_alias, + error_on_type_mismatch=error_on_type_mismatch, + writer_properties=writer_properties, + commit_properties=commit_properties, + custom_metadata=custom_metadata, + **delta_params, + ) + + if merge_type == MergeType.update_only: + return merger.when_matched_update_all().execute() + elif merge_type == MergeType.deduplicate_insert: + return merger.when_not_matched_insert_all().execute() + elif merge_type == MergeType.upsert: + return merger.when_matched_update_all().when_not_matched_insert_all().execute() + elif merge_type == MergeType.replace_delete_unmatched: + return merger.when_matched_update_all().when_not_matched_by_source_delete().execute() + else: + raise NotImplementedError diff --git 
a/libraries/dagster-delta/dagster_delta/_handler/utils/__init__.py b/libraries/dagster-delta/dagster_delta/_handler/utils/__init__.py new file mode 100644 index 0000000..3d47235 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/utils/__init__.py @@ -0,0 +1,11 @@ +from dagster_delta._handler.utils.date_format import extract_date_format_from_partition_definition +from dagster_delta._handler.utils.dnf import partition_dimensions_to_dnf +from dagster_delta._handler.utils.predicates import create_predicate +from dagster_delta._handler.utils.reader import read_table + +__all__ = [ + "create_predicate", + "read_table", + "extract_date_format_from_partition_definition", + "partition_dimensions_to_dnf", +] diff --git a/libraries/dagster-delta/dagster_delta/_handler/utils/date_format.py b/libraries/dagster-delta/dagster_delta/_handler/utils/date_format.py new file mode 100644 index 0000000..9f9388a --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/utils/date_format.py @@ -0,0 +1,73 @@ +from typing import Optional, Union + +from dagster import ( + InputContext, + MultiPartitionsDefinition, + OutputContext, +) +from dagster._core.definitions.time_window_partitions import ( + TimeWindowPartitionsDefinition, +) + + +def extract_date_format_from_partition_definition( + context: Union[OutputContext, InputContext], +) -> Optional[dict[str, str]]: + """Gets the date format from the partition definition if there is a TimeWindowPartitionsDefinition present (nested or not), to be used to properly compare with columns + in the delta table which are not a datetime object. Returns None if no TimeWindowPartitionsDefinition were present. + """ + if isinstance(context, InputContext): + if context.has_asset_partitions: + if context.upstream_output is not None: + partition_expr = context.upstream_output.definition_metadata["partition_expr"] # type: ignore + partitions_definition = context.asset_partitions_def + else: + raise ValueError( + "'partition_expr' should have been set in the metadata of the incoming asset since it has a partition definition.", + ) + else: + return None + elif isinstance(context, OutputContext): + if context.has_asset_partitions: + if ( + context.definition_metadata is not None + and "partition_expr" in context.definition_metadata + ): + partition_expr = context.definition_metadata["partition_expr"] + else: + raise ValueError( + "'partition_expr' should have been set in the metadata of the incoming asset since it has a partition definition.", + ) + partitions_definition = context.asset_partitions_def + else: + return None + if partition_expr is None or partitions_definition is None: + return None + + date_format: dict[str, str] = {} + if isinstance(partitions_definition, TimeWindowPartitionsDefinition): + if isinstance(partition_expr, str): + date_format[partition_expr] = partitions_definition.fmt # type: ignore + else: + raise ValueError( + "Single partition definition provided, so partion_expr needs to be a string", + ) + elif isinstance(partitions_definition, MultiPartitionsDefinition): + if isinstance(partition_expr, dict): + for partition_dims_definition in partitions_definition.partitions_defs: + if isinstance( + partition_dims_definition.partitions_def, + TimeWindowPartitionsDefinition, + ): + partition_expr_name = partition_expr.get(partition_dims_definition.name) + if partition_expr_name is None: + raise ValueError( + f"Partition_expr mapping is invalid. 
Partition_dimension :{partition_dims_definition.name} not found in partition_expr: {partition_expr}.", + ) + date_format[partition_expr_name] = partition_dims_definition.partitions_def.fmt + else: + raise ValueError( + "MultiPartitionsDefinition provided, so partion_expr needs to be a dictionary mapping of {dimension: column}", + ) + + return date_format if len(date_format) else None diff --git a/libraries/dagster-delta/dagster_delta/_handler/utils/dnf.py b/libraries/dagster-delta/dagster_delta/_handler/utils/dnf.py new file mode 100644 index 0000000..f591ec6 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/utils/dnf.py @@ -0,0 +1,154 @@ +from collections.abc import Iterable, Sequence +from typing import Optional, Union, cast + +from dagster._core.definitions.time_window_partitions import ( + TimeWindow, +) +from dagster._core.storage.db_io_manager import TablePartitionDimension +from deltalake.schema import Field as DeltaField +from deltalake.schema import PrimitiveType, Schema +from deltalake.table import FilterLiteralType + +from dagster_delta.io_manager.base import ( + DELTA_DATE_FORMAT, + DELTA_DATETIME_FORMAT, +) + + +def partition_dimensions_to_dnf( + partition_dimensions: Iterable[TablePartitionDimension], + table_schema: Schema, + str_values: bool = False, + input_dnf: bool = False, # during input we want to read a range when it's (un)-partitioned + date_format: Optional[dict[str, str]] = None, +) -> Optional[list[FilterLiteralType]]: + """Converts partition dimensions to dnf filters""" + parts = [] + for partition_dimension in partition_dimensions: + field = _field_from_schema(partition_dimension.partition_expr, table_schema) + if field is None: + raise ValueError( + f"Field {partition_dimension.partition_expr} is not part of table schema.", + "Currently only column names are supported as partition expressions", + ) + if isinstance(field.type, PrimitiveType): + if field.type.type in ["timestamp", "date"]: + filter_ = _time_window_partition_dnf( + partition_dimension, + field.type.type, + str_values, + input_dnf, + ) + if isinstance(filter_, list): + parts.extend(filter_) + else: + parts.append(filter_) + elif field.type.type in ["string", "integer"]: + field_date_format = date_format.get(field.name) if date_format is not None else None + filter_ = _value_dnf( + partition_dimension, + field_date_format, + field.type.type, + ) + if isinstance(filter_, list): + parts.extend(filter_) + else: + parts.append(filter_) + else: + raise ValueError(f"Unsupported partition type {field.type.type}") + else: + raise ValueError(f"Unsupported partition type {field.type}") + + return parts if len(parts) > 0 else None + + +def _value_dnf( + table_partition: TablePartitionDimension, + date_format: Optional[str] = None, + field_type: Optional[str] = None, +) -> Union[ + list[tuple[str, str, Union[int, str]]], + tuple[str, str, Sequence[str]], + tuple[str, str, str], +]: # noqa: ANN202 + # ", ".join(f"'{partition}'" for partition in table_partition.partitions) # noqa: ERA001 + if ( + isinstance(table_partition.partitions, list) + and all(isinstance(p, TimeWindow) for p in table_partition.partitions) + ) or isinstance(table_partition.partitions, TimeWindow): + if date_format is None: + raise Exception( + "Date format not set on time based partition definition, even though field is (str, int). 
Set date fmt on the partition_def, or change column type to date/datetime.", + ) + if isinstance(table_partition.partitions, list): + start_dts = [partition.start for partition in table_partition.partitions] # type: ignore + end_dts = [partition.end for partition in table_partition.partitions] # type: ignore + start_dt = min(start_dts) + end_dt = max(end_dts) + else: + start_dt = table_partition.partitions.start + end_dt = table_partition.partitions.end + + start_dt = start_dt.strftime(date_format) + end_dt = end_dt.strftime(date_format) + + if field_type == "integer": + start_dt = int(start_dt) + end_dt = int(end_dt) + return [ + (table_partition.partition_expr, ">=", start_dt), + (table_partition.partition_expr, "<", end_dt), + ] + + else: + partition = cast(Sequence[str], table_partition.partitions) + partition = list(set(partition)) + if len(partition) > 1: + return (table_partition.partition_expr, "in", partition) + else: + return (table_partition.partition_expr, "=", partition[0]) + + +def _time_window_partition_dnf( + table_partition: TablePartitionDimension, + data_type: str, + str_values: bool, + input_dnf: bool, +) -> Union[FilterLiteralType, list[FilterLiteralType]]: + if isinstance(table_partition.partitions, list): + raise Exception( + "For date primitive we shouldn't have received a sequence[str] but a TimeWindow", + ) + else: + partition = cast(TimeWindow, table_partition.partitions) + start_dt, end_dt = partition + start_dt, end_dt = start_dt.replace(tzinfo=None), end_dt.replace(tzinfo=None) + + if str_values: + if data_type == "timestamp": + start_dt, end_dt = ( + start_dt.strftime(DELTA_DATETIME_FORMAT), + end_dt.strftime(DELTA_DATETIME_FORMAT), + ) + elif data_type == "date": + start_dt, end_dt = ( + start_dt.strftime(DELTA_DATE_FORMAT), + end_dt.strftime(DELTA_DATE_FORMAT), + ) + else: + raise ValueError(f"Unknown primitive type: {data_type}") + + if input_dnf: + return [ + (table_partition.partition_expr, ">=", start_dt), + (table_partition.partition_expr, "<", end_dt), + ] + else: + return (table_partition.partition_expr, "=", start_dt) + + +def _field_from_schema(field_name: str, schema: Schema) -> Optional[DeltaField]: + for field in schema.fields: + if field.name == field_name: + return field + return None diff --git a/libraries/dagster-delta/dagster_delta/_handler/utils/predicates.py b/libraries/dagster-delta/dagster_delta/_handler/utils/predicates.py new file mode 100644 index 0000000..e5c85bb --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/utils/predicates.py @@ -0,0 +1,27 @@ +from datetime import datetime +from typing import Optional + +from deltalake.table import FilterLiteralType + + +def create_predicate( + partition_filters: list[FilterLiteralType], + target_alias: Optional[str] = None, +) -> str: + partition_predicates = [] + for part_filter in partition_filters: + column = f"{target_alias}.{part_filter[0]}" if target_alias is not None else part_filter[0] + value = part_filter[2] + if isinstance(value, (int, float, bool)): + value = str(value) + elif isinstance(value, str): + value = f"'{value}'" + elif isinstance(value, list): + value = str(tuple(v for v in value)) + elif isinstance(value, datetime): + value = str( + int(value.timestamp() * 1000 * 1000), + ) # convert to microseconds + partition_predicates.append(f"{column} {part_filter[1]} {value}") + + return " AND ".join(partition_predicates) diff --git a/libraries/dagster-delta/dagster_delta/_handler/utils/reader.py b/libraries/dagster-delta/dagster_delta/_handler/utils/reader.py 
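`create_predicate` is the piece that turns the DNF-style partition filters into the SQL string handed to `DeltaTable.merge()`: string values are single-quoted, lists become SQL tuples, datetimes are rendered as epoch microseconds, and each column is optionally prefixed with the target alias. A small self-contained illustration of its output (the filter values are made up):

```python
# Demonstrates the quoting/aliasing rules of create_predicate(); values are illustrative.
from dagster_delta._handler.utils import create_predicate

filters = [
    ("date", "=", "2022-01-01"),     # str -> single-quoted
    ("region", "in", ["eu", "us"]),  # list -> SQL tuple
]
print(create_predicate(filters, target_alias="t"))
# t.date = '2022-01-01' AND t.region in ('eu', 'us')
```

A `datetime` value would instead be emitted as an integer epoch-microseconds literal, per the `int(value.timestamp() * 1000 * 1000)` branch above.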
new file mode 100644 index 0000000..59ce794 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/_handler/utils/reader.py @@ -0,0 +1,53 @@ +import logging +from typing import Optional + +import pyarrow.dataset as ds +from dagster._core.storage.db_io_manager import TableSlice +from deltalake import DeltaTable + +from dagster_delta._handler.utils.dnf import partition_dimensions_to_dnf + +try: + from pyarrow.parquet import filters_to_expression # pyarrow >= 10.0.0 +except ImportError: + from pyarrow.parquet import _filters_to_expression as filters_to_expression + + +from dagster_delta.io_manager.base import ( + TableConnection, +) + + +def read_table( + table_slice: TableSlice, + connection: TableConnection, + version: Optional[int] = None, + date_format: Optional[dict[str, str]] = None, + parquet_read_options: Optional[ds.ParquetReadOptions] = None, +) -> ds.Dataset: + table = DeltaTable( + table_uri=connection.table_uri, + version=version, + storage_options=connection.storage_options, + ) + logger = logging.getLogger() + logger.setLevel("DEBUG") + logger.debug("Connection timeout duration %s", connection.storage_options.get("timeout")) + + partition_expr = None + if table_slice.partition_dimensions is not None: + partition_filters = partition_dimensions_to_dnf( + partition_dimensions=table_slice.partition_dimensions, + table_schema=table.schema(), + input_dnf=True, + date_format=date_format, + ) + if partition_filters is not None: + partition_expr = filters_to_expression([partition_filters]) + + logger.debug("Dataset input predicate %s", partition_expr) + dataset = table.to_pyarrow_dataset(parquet_read_options=parquet_read_options) + if partition_expr is not None: + dataset = dataset.filter(expression=partition_expr) + + return dataset diff --git a/libraries/dagster-delta/dagster_delta/config.py b/libraries/dagster-delta/dagster_delta/config.py index c2fba28..35cdb39 100644 --- a/libraries/dagster-delta/dagster_delta/config.py +++ b/libraries/dagster-delta/dagster_delta/config.py @@ -4,11 +4,20 @@ from dagster import Config +def _to_str_dict(dictionary: dict) -> dict[str, str]: + """Filters dict of None values and casts other values to str.""" + return {key: str(value) for key, value in dictionary.items() if value is not None} + + class LocalConfig(Config): """Storage configuration for local object store.""" provider: Literal["local"] = "local" + def str_dict(self) -> dict[str, str]: + """Storage options as str dict.""" + return _to_str_dict(self.model_dump()) + class AzureConfig(Config): """Storage configuration for Microsoft Azure Blob or ADLS Gen 2 object store.""" @@ -54,6 +63,10 @@ class AzureConfig(Config): container_name: Optional[str] = None """Storage container name""" + def str_dict(self) -> dict[str, str]: + """Storage options as str dict.""" + return _to_str_dict(self.model_dump()) + class S3Config(Config): """Storage configuration for Amazon Web Services (AWS) S3 object store.""" @@ -99,14 +112,9 @@ class S3Config(Config): https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html """ - copy_if_not_exists: Optional[str] = None - """Specify additional headers passed to storage backend, that enable 'if_not_exists' semantics. 
- - https://docs.rs/object_store/0.7.0/object_store/aws/enum.S3CopyIfNotExists.html#variant.Header - """ - - AWS_S3_ALLOW_UNSAFE_RENAME: Optional[bool] = None - """Allows tables writes that may conflict with concurrent writers.""" + def str_dict(self) -> dict[str, str]: + """Storage options as str dict.""" + return _to_str_dict(self.model_dump()) class GcsConfig(Config): @@ -126,6 +134,23 @@ class GcsConfig(Config): application_credentials: Optional[str] = None """Application credentials path""" + def str_dict(self) -> dict[str, str]: + """Storage options as str dict.""" + return _to_str_dict(self.model_dump()) + + +class BackoffConfig(Config): + """Configuration for exponential back off https://docs.rs/object_store/latest/object_store/struct.BackoffConfig.html""" + + init_backoff: Optional[str] = None + """The initial backoff duration""" + + max_backoff: Optional[str] = None + """The maximum backoff duration""" + + base: Optional[float] = None + """The multiplier to use for the next backoff duration""" + class ClientConfig(Config): """Configuration for http client interacting with storage APIs.""" @@ -178,7 +203,7 @@ class ClientConfig(Config): """HTTP proxy to use for requests""" timeout: Optional[str] = None - """Request timeout + """Request timeout, e.g. 10s, 60s The timeout is applied from when the request starts connecting until the response body has finished """ @@ -186,9 +211,48 @@ class ClientConfig(Config): user_agent: Optional[str] = None """User-Agent header to be used by this client""" + OBJECT_STORE_CONCURRENCY_LIMIT: Optional[str] = None + """The number of concurrent connections the underlying object store can create""" + + MOUNT_ALLOW_UNSAFE_RENAME: Optional[str] = None + """If set it will allow unsafe renames on mounted storage""" + + max_retries: Optional[int] = None + """The maximum number of times to retry a request. Set to 0 to disable retries""" + + retry_timeout: Optional[str] = None + """The maximum duration of time from the initial request after which no further retries will be attempted. e.g. 
10s, 60s""" + + backoff_config: Optional[BackoffConfig] = None + """Configuration for exponential back off """ + + def str_dict(self) -> dict[str, str]: + """Storage options as str dict.""" + model_dump = self.model_dump() + str_dict: dict[str, str] = {} + for key, value in model_dump.items(): + if value is not None: + if isinstance(value, BackoffConfig): + ## delta-rs uses custom config keys for the BackOffConfig + ## https://delta-io.github.io/delta-rs/integrations/object-storage/special_configuration/ + if value.base is not None: + str_dict["backoff_config.base"] = str(value.base) + if value.max_backoff is not None: + str_dict["backoff_config.max_backoff"] = str(value.max_backoff) + if value.init_backoff is not None: + str_dict["backoff_config.init_backoff"] = str(value.init_backoff) + else: + str_dict[key] = str(value) + return str_dict + class MergeType(str, Enum): - """Enum of the possible IO Manager merge types""" + """Enum of the possible IO Manager merge types + - "deduplicate_insert" <- Deduplicates on write + - "update_only" <- updates only the matches records + - "upsert" <- updates existing matches and inserts non matched records + - "replace_and_delete_unmatched" <- updates existing matches and deletes unmatched + """ deduplicate_insert = "deduplicate_insert" # Deduplicates on write update_only = "update_only" # updates only the records @@ -203,7 +267,9 @@ class MergeConfig(Config): """The type of MERGE to execute.""" predicate: Optional[str] = None - """SQL like predicate on how to merge, passed into DeltaTable.merge()""" + """SQL like predicate on how to merge, passed into DeltaTable.merge() + + This can also be set on the asset definition metadata using the `merge_predicate` key""" source_alias: Optional[str] = None """Alias for the source table""" diff --git a/libraries/dagster-delta/dagster_delta/handler.py b/libraries/dagster-delta/dagster_delta/handler.py deleted file mode 100644 index 1cef9fc..0000000 --- a/libraries/dagster-delta/dagster_delta/handler.py +++ /dev/null @@ -1,611 +0,0 @@ -import logging -from abc import abstractmethod -from collections.abc import Iterable, Sequence -from datetime import datetime -from typing import Any, Generic, Optional, TypeVar, Union, cast - -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.dataset as ds -from dagster import ( - InputContext, - MetadataValue, - MultiPartitionsDefinition, - OutputContext, - TableColumn, - TableSchema, -) -from dagster._core.definitions.time_window_partitions import ( - TimeWindow, - TimeWindowPartitionsDefinition, -) -from dagster._core.storage.db_io_manager import DbTypeHandler, TablePartitionDimension, TableSlice -from deltalake import CommitProperties, DeltaTable, WriterProperties, write_deltalake -from deltalake.exceptions import TableNotFoundError -from deltalake.schema import Field as DeltaField -from deltalake.schema import PrimitiveType, Schema, _convert_pa_schema_to_delta -from deltalake.table import FilterLiteralType - -try: - from pyarrow.parquet import filters_to_expression # pyarrow >= 10.0.0 -except ImportError: - from pyarrow.parquet import _filters_to_expression as filters_to_expression - - -from .config import MergeType -from .io_manager import ( - DELTA_DATE_FORMAT, - DELTA_DATETIME_FORMAT, - TableConnection, - _DeltaTableIOManagerResourceConfig, -) - -T = TypeVar("T") -ArrowTypes = Union[pa.Table, pa.RecordBatchReader] - - -def _create_predicate( - partition_filters: list[FilterLiteralType], - target_alias: Optional[str] = None, -) -> Optional[str]: - 
partition_predicates = [] - for part_filter in partition_filters: - column = f"{target_alias}.{part_filter[0]}" if target_alias is not None else part_filter[0] - value = part_filter[2] - if isinstance(value, (int, float, bool)): - value = str(value) - elif isinstance(value, str): - value = f"'{value}'" - elif isinstance(value, list): - value = str(tuple(v for v in value)) - elif isinstance(value, datetime): - value = str( - int(value.timestamp() * 1000 * 1000), - ) # convert to microseconds - partition_predicates.append(f"{column} {part_filter[1]} {value}") - - return " AND ".join(partition_predicates) - - -def _merge_execute( - dt: DeltaTable, - data: Union[pa.RecordBatchReader, pa.Table], - merge_config: dict[str, Any], - writer_properties: Optional[WriterProperties], - commit_properties: Optional[CommitProperties], - custom_metadata: Optional[dict[str, str]], - delta_params: dict[str, Any], - merge_predicate_from_metadata: Optional[str], - partition_filters: Optional[list[FilterLiteralType]] = None, -) -> dict[str, Any]: - merge_type = merge_config.get("merge_type") - error_on_type_mismatch = merge_config.get("error_on_type_mismatch", True) - - if merge_predicate_from_metadata is not None: - predicate = str(merge_predicate_from_metadata) - elif merge_config.get("predicate") is not None: - predicate = str(merge_config.get("predicate")) - else: - raise Exception("merge predicate was not provided") - - target_alias = merge_config.get("target_alias") - - if partition_filters is not None: - partition_predicate = _create_predicate(partition_filters, target_alias=target_alias) - - predicate = f"{predicate} AND {partition_predicate}" - logger = logging.getLogger() - logger.setLevel("DEBUG") - logger.debug("Using explicit MERGE partition predicate: \n%s", predicate) - - merger = dt.merge( - source=data, - predicate=predicate, - source_alias=merge_config.get("source_alias"), - target_alias=target_alias, - error_on_type_mismatch=error_on_type_mismatch, - writer_properties=writer_properties, - commit_properties=commit_properties, - custom_metadata=custom_metadata, - **delta_params, - ) - - if merge_type == MergeType.update_only: - return merger.when_matched_update_all().execute() - elif merge_type == MergeType.deduplicate_insert: - return merger.when_not_matched_insert_all().execute() - elif merge_type == MergeType.upsert: - return merger.when_matched_update_all().when_not_matched_insert_all().execute() - elif merge_type == MergeType.replace_delete_unmatched: - return merger.when_matched_update_all().when_not_matched_by_source_delete().execute() - else: - raise NotImplementedError - - -class DeltalakeBaseArrowTypeHandler(DbTypeHandler[T], Generic[T]): # noqa: D101 - @abstractmethod - def from_arrow(self, obj: pa.RecordBatchReader, target_type: type) -> T: - """Abstract method to convert arrow to target type""" - pass - - @abstractmethod - def to_arrow(self, obj: T) -> tuple[ArrowTypes, dict[str, Any]]: - """Abstract method to convert type to arrow""" - pass - - @abstractmethod - def get_output_stats(self, obj: T) -> dict[str, MetadataValue]: - """Abstract method to return output stats""" - pass - - def handle_output( - self, - context: OutputContext, - table_slice: TableSlice, - obj: T, - connection: TableConnection, - ): - """Stores pyarrow types in Delta table.""" - logger = logging.getLogger() - logger.setLevel("DEBUG") - definition_metadata = context.definition_metadata or {} - merge_predicate_from_metadata = definition_metadata.get("merge_predicate") - additional_table_config = 
definition_metadata.get("table_configuration", {}) - if connection.table_config is not None: - table_config = additional_table_config | connection.table_config - else: - table_config = additional_table_config - resource_config = context.resource_config or {} - object_stats = self.get_output_stats(obj) - data, delta_params = self.to_arrow(obj=obj) - delta_schema = Schema.from_pyarrow(_convert_pa_schema_to_delta(data.schema)) - resource_config = cast(_DeltaTableIOManagerResourceConfig, context.resource_config) - engine = resource_config.get("writer_engine") - save_mode = definition_metadata.get("mode") - main_save_mode = resource_config.get("mode") - custom_metadata = definition_metadata.get("custom_metadata") or resource_config.get( - "custom_metadata", - ) - schema_mode = definition_metadata.get("schema_mode") or resource_config.get( - "schema_mode", - ) - writer_properties = resource_config.get("writer_properties") - writer_properties = ( - WriterProperties(**writer_properties) if writer_properties is not None else None # type: ignore - ) - - commit_properties = definition_metadata.get("commit_properties") or resource_config.get( - "commit_properties", - ) - commit_properties = ( - CommitProperties(**commit_properties) if commit_properties is not None else None # type: ignore - ) - merge_config = resource_config.get("merge_config") - - date_format = extract_date_format_from_partition_definition(context) - - if save_mode is not None: - logger.debug( - "IO manager mode overridden with the asset metadata mode, %s -> %s", - main_save_mode, - save_mode, - ) - main_save_mode = save_mode - logger.debug("Writing with mode: `%s`", main_save_mode) - - merge_stats = None - partition_filters = None - partition_columns = None - predicate = None - - if table_slice.partition_dimensions is not None: - partition_filters = partition_dimensions_to_dnf( - partition_dimensions=table_slice.partition_dimensions, - table_schema=delta_schema, - str_values=True, - date_format=date_format, - ) - if partition_filters is not None and engine == "rust": - ## Convert partition_filter to predicate - predicate = _create_predicate(partition_filters) - partition_filters = None - else: - predicate = None - # TODO(): make robust and move to function - partition_columns = [dim.partition_expr for dim in table_slice.partition_dimensions] - - if main_save_mode not in ["merge", "create_or_replace"]: - if predicate is not None and engine == "rust": - logger.debug("Using explicit partition predicate: \n%s", predicate) - elif partition_filters is not None and engine == "pyarrow": - logger.debug("Using explicit partition_filter: \n%s", partition_filters) - write_deltalake( # type: ignore - table_or_uri=connection.table_uri, - data=data, - storage_options=connection.storage_options, - mode=main_save_mode, - partition_filters=partition_filters, - predicate=predicate, - partition_by=partition_columns, - engine=engine, - schema_mode=schema_mode, - configuration=table_config, - custom_metadata=custom_metadata, - writer_properties=writer_properties, - commit_properties=commit_properties, - **delta_params, - ) - elif main_save_mode == "create_or_replace": - DeltaTable.create( - table_uri=connection.table_uri, - schema=_convert_pa_schema_to_delta(data.schema), - mode="overwrite", - partition_by=partition_columns, - configuration=table_config, - storage_options=connection.storage_options, - custom_metadata=custom_metadata, - ) - else: - if merge_config is None: - raise ValueError( - "Merge Configuration should be provided when `mode = 
WriterMode.merge`", - ) - try: - dt = DeltaTable(connection.table_uri, storage_options=connection.storage_options) - except TableNotFoundError: - logger.debug("Creating a DeltaTable first before merging.") - dt = DeltaTable.create( - table_uri=connection.table_uri, - schema=_convert_pa_schema_to_delta(data.schema), - partition_by=partition_columns, - configuration=table_config, - storage_options=connection.storage_options, - custom_metadata=custom_metadata, - ) - merge_stats = _merge_execute( - dt, - data, - merge_config, - writer_properties=writer_properties, - commit_properties=commit_properties, - custom_metadata=custom_metadata, - delta_params=delta_params, - merge_predicate_from_metadata=merge_predicate_from_metadata, - partition_filters=partition_filters, - ) - - dt = DeltaTable(connection.table_uri, storage_options=connection.storage_options) - try: - stats = _get_partition_stats(dt=dt, partition_filters=partition_filters) - except Exception as e: - context.log.warning(f"error while computing table stats: {e}") - stats = {} - - output_metadata = { - "table_columns": MetadataValue.table_schema( - TableSchema( - columns=[ - TableColumn(name=name, type=str(dtype)) - for name, dtype in zip(data.schema.names, data.schema.types) - ], - ), - ), - "table_uri": MetadataValue.path(connection.table_uri), - "table_version": MetadataValue.int(dt.version()), - **stats, - **object_stats, - } - if merge_stats is not None: - output_metadata["num_output_rows"] = MetadataValue.int( - merge_stats.get("num_output_rows", 0), - ) - output_metadata["merge_stats"] = MetadataValue.json(merge_stats) - - context.add_output_metadata(output_metadata) - - def load_input( - self, - context: InputContext, - table_slice: TableSlice, - connection: TableConnection, - ) -> T: - """Loads the input as a pyarrow Table or RecordBatchReader.""" - parquet_read_options = None - if context.resource_config is not None: - parquet_read_options = context.resource_config.get("parquet_read_options", None) - parquet_read_options = ( - ds.ParquetReadOptions(**parquet_read_options) - if parquet_read_options is not None - else None - ) - - dataset = _table_reader(table_slice, connection, parquet_read_options=parquet_read_options) - - if context.dagster_type.typing_type == ds.Dataset: - if table_slice.columns is not None: - raise ValueError("Cannot select columns when loading as Dataset.") - return dataset - - scanner = dataset.scanner(columns=table_slice.columns) - return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type) - - -class DeltaLakePyArrowTypeHandler(DeltalakeBaseArrowTypeHandler[ArrowTypes]): # noqa: D101 - def from_arrow(self, obj: pa.RecordBatchReader, target_type: type[ArrowTypes]) -> ArrowTypes: # noqa: D102 - if target_type == pa.Table: - return obj.read_all() - return obj - - def to_arrow(self, obj: ArrowTypes) -> tuple[ArrowTypes, dict[str, Any]]: # noqa: D102 - if isinstance(obj, ds.Dataset): - return obj.scanner().to_reader(), {} - return obj, {} - - def get_output_stats(self, obj: ArrowTypes) -> dict[str, MetadataValue]: # noqa: ARG002 - """Returns output stats to be attached to the the context. 
- - Args: - obj (PolarsTypes): LazyFrame or DataFrame - - Returns: - Mapping[str, MetadataValue]: metadata stats - """ - return {} - - @property - def supported_types(self) -> Sequence[type[object]]: - """Returns the supported dtypes for this typeHandler""" - return [pa.Table, pa.RecordBatchReader, ds.Dataset] - - -def partition_dimensions_to_dnf( - partition_dimensions: Iterable[TablePartitionDimension], - table_schema: Schema, - str_values: bool = False, - input_dnf: bool = False, # during input we want to read a range when it's (un)-partitioned - date_format: Optional[dict[str, str]] = None, -) -> Optional[list[FilterLiteralType]]: - """Converts partition dimensions to dnf filters""" - parts = [] - for partition_dimension in partition_dimensions: - field = _field_from_schema(partition_dimension.partition_expr, table_schema) - if field is None: - raise ValueError( - f"Field {partition_dimension.partition_expr} is not part of table schema.", - "Currently only column names are supported as partition expressions", - ) - if isinstance(field.type, PrimitiveType): - if field.type.type in ["timestamp", "date"]: - filter_ = _time_window_partition_dnf( - partition_dimension, - field.type.type, - str_values, - input_dnf, - ) - if isinstance(filter_, list): - parts.extend(filter_) - else: - parts.append(filter_) - elif field.type.type in ["string", "integer"]: - field_format = date_format.get(field.name) if date_format is not None else None - filter_ = _value_dnf( - partition_dimension, - field_format, - field.type.type, - ) - if isinstance(filter_, list): - parts.extend(filter_) - else: - parts.append(filter_) - else: - raise ValueError(f"Unsupported partition type {field.type.type}") - else: - raise ValueError(f"Unsupported partition type {field.type}") - - return parts if len(parts) > 0 else None - - -def _value_dnf( - table_partition: TablePartitionDimension, - date_format: Optional[str] = None, - field_type: Optional[str] = None, -) -> Union[ - list[tuple[str, str, Union[int, str]]], - tuple[str, str, Sequence[str]], - tuple[str, str, str], -]: # noqa: ANN202 - # ", ".join(f"'{partition}'" for partition in table_partition.partitions) # noqa: ERA001 - if ( - isinstance(table_partition.partitions, list) - and all(isinstance(p, TimeWindow) for p in table_partition.partitions) - ) or isinstance(table_partition.partitions, TimeWindow): - if date_format is None: - raise Exception("Date format was not provided") - if isinstance(table_partition.partitions, list): - start_dts = [partition.start for partition in table_partition.partitions] # type: ignore - end_dts = [partition.end for partition in table_partition.partitions] # type: ignore - start_dt = min(start_dts) - end_dt = max(end_dts) - else: - start_dt = table_partition.partitions.start - end_dt = table_partition.partitions.end - - start_dt = start_dt.strftime(date_format) - end_dt = end_dt.strftime(date_format) - - if field_type == "integer": - start_dt = int(start_dt) - end_dt = int(end_dt) - return [ - (table_partition.partition_expr, ">=", start_dt), - (table_partition.partition_expr, "<", end_dt), - ] - - else: - partition = cast(Sequence[str], table_partition.partitions) - partition = list(set(partition)) - if len(partition) > 1: - return (table_partition.partition_expr, "in", partition) - else: - return (table_partition.partition_expr, "=", partition[0]) - - -def _time_window_partition_dnf( - table_partition: TablePartitionDimension, - data_type: str, - str_values: bool, - input_dnf: bool, -) -> Union[FilterLiteralType, 
list[FilterLiteralType]]: - if isinstance(table_partition.partitions, list): - start_dt = min( - [cast(TimeWindow, partition).start for partition in table_partition.partitions], - ).replace(tzinfo=None) - end_dt = max( - [cast(TimeWindow, partition).end for partition in table_partition.partitions], - ).replace(tzinfo=None) - else: - partition = cast(TimeWindow, table_partition.partitions) - start_dt, end_dt = partition - start_dt, end_dt = start_dt.replace(tzinfo=None), end_dt.replace(tzinfo=None) - - if str_values: - if data_type == "timestamp": - start_dt, end_dt = ( - start_dt.strftime(DELTA_DATETIME_FORMAT), - end_dt.strftime(DELTA_DATETIME_FORMAT), - ) - elif data_type == "date": - start_dt, end_dt = ( - start_dt.strftime(DELTA_DATE_FORMAT), - end_dt.strftime(DELTA_DATETIME_FORMAT), - ) - else: - raise ValueError(f"Unknown primitive type: {data_type}") - - if input_dnf: - return [ - (table_partition.partition_expr, ">=", start_dt), - (table_partition.partition_expr, "<", end_dt), - ] - else: - return (table_partition.partition_expr, "=", start_dt) - - -def _field_from_schema(field_name: str, schema: Schema) -> Optional[DeltaField]: - for field in schema.fields: - if field.name == field_name: - return field - return None - - -def _get_partition_stats( - dt: DeltaTable, - partition_filters: Optional[list[FilterLiteralType]] = None, -) -> dict[str, Any]: - """_summary_ - - Args: - dt (DeltaTable): DeltaTable object - partition_filters (list[FilterLiteralType] | None, optional): filters to grabs stats with. Defaults to None. - - Returns: - dict[str, MetadataValue]: _description_ - """ - files = pa.array(dt.files(partition_filters=partition_filters)) - files_table = pa.Table.from_arrays([files], names=["path"]) - actions_table = pa.Table.from_batches([dt.get_add_actions(flatten=True)]) - actions_table = actions_table.select(["path", "size_bytes", "num_records"]) - table = files_table.join(actions_table, keys="path") - - stats = { - "size_MB": MetadataValue.float( - pc.sum(table.column("size_bytes")).as_py() * 0.00000095367432, # type: ignore - ), - "dagster/row_count": MetadataValue.int(pc.sum(table.column("num_records")).as_py()), # type: ignore - } - - return stats - - -def _table_reader( - table_slice: TableSlice, - connection: TableConnection, - version: Optional[int] = None, - date_format: Optional[dict[str, str]] = None, - parquet_read_options: Optional[ds.ParquetReadOptions] = None, -) -> ds.Dataset: - table = DeltaTable( - table_uri=connection.table_uri, - version=version, - storage_options=connection.storage_options, - ) - logger = logging.getLogger() - logger.setLevel("DEBUG") - logger.debug("Connection timeout duration %s", connection.storage_options.get("timeout")) - - partition_expr = None - if table_slice.partition_dimensions is not None: - partition_filters = partition_dimensions_to_dnf( - partition_dimensions=table_slice.partition_dimensions, - table_schema=table.schema(), - input_dnf=True, - date_format=date_format, - ) - if partition_filters is not None: - partition_expr = filters_to_expression([partition_filters]) - - logger.debug("Dataset input predicate %s", partition_expr) - dataset = table.to_pyarrow_dataset(parquet_read_options=parquet_read_options) - if partition_expr is not None: - dataset = dataset.filter(expression=partition_expr) - - return dataset - - -def extract_date_format_from_partition_definition( - context: Union[OutputContext, InputContext], -) -> Optional[dict[str, str]]: - """Gets the date format from the partition definition if there is a 
TimeWindowPartitionsDefinition present (nested or not), to be used to properly compare with columns - in the delta table which are not a datetime object. Returns None if no TimeWindowPartitionsDefinition were present. - """ - if isinstance(context, InputContext): - if context.has_asset_partitions: - if context.upstream_output is not None: - partition_expr = context.upstream_output.definition_metadata["partition_expr"] # type: ignore - partitions_definition = context.asset_partitions_def - else: - raise ValueError( - "'partition_expr' should have been set in the metadata of the incoming asset since it has a partition definition.", - ) - else: - return None - elif isinstance(context, OutputContext): - if context.has_asset_partitions: - if ( - context.definition_metadata is not None - and "partition_expr" in context.definition_metadata - ): - partition_expr = context.definition_metadata["partition_expr"] - else: - raise ValueError( - "'partition_expr' should have been set in the metadata of the incoming asset since it has a partition definition.", - ) - partitions_definition = context.asset_partitions_def - else: - return None - if partition_expr is None or partitions_definition is None: - return None - date_format: dict[str, str] = {} - if isinstance(partitions_definition, TimeWindowPartitionsDefinition): - date_format[partition_expr] = partitions_definition.fmt # type: ignore - elif isinstance(partitions_definition, MultiPartitionsDefinition): - for partition_dims_definition in partitions_definition.partitions_defs: - if isinstance( - partition_dims_definition.partitions_def, - TimeWindowPartitionsDefinition, - ): - date_format[partition_expr[partition_dims_definition.name]] = ( - partition_dims_definition.partitions_def.fmt - ) - - return date_format if len(date_format) else None diff --git a/libraries/dagster-delta/dagster_delta/io_manager/__init__.py b/libraries/dagster-delta/dagster_delta/io_manager/__init__.py new file mode 100644 index 0000000..5025805 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/io_manager/__init__.py @@ -0,0 +1,27 @@ +from dagster_delta.io_manager.arrow import DeltaLakePyarrowIOManager +from dagster_delta.io_manager.base import ( + BaseDeltaLakeIOManager, + SchemaMode, + WriteMode, + WriterEngine, +) + +__all__ = [ + "WriteMode", + "WriterEngine", + "SchemaMode", + "BaseDeltaLakeIOManager", + "DeltaLakePyarrowIOManager", +] + + +try: + from dagster_delta.io_manager.polars import DeltaLakePolarsIOManager # noqa + + __all__.extend(["DeltaLakePolarsIOManager"]) + +except ImportError as e: + if "polars" in str(e): + pass + else: + raise e diff --git a/libraries/dagster-delta/dagster_delta/io_manager/arrow.py b/libraries/dagster-delta/dagster_delta/io_manager/arrow.py new file mode 100644 index 0000000..6c2a83c --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/io_manager/arrow.py @@ -0,0 +1,46 @@ +from collections.abc import Sequence +from typing import Any + +import dagster as dg +import pyarrow as pa +import pyarrow.dataset as ds +from dagster._core.storage.db_io_manager import DbTypeHandler + +from dagster_delta._handler.base import ArrowTypes, DeltalakeBaseArrowTypeHandler +from dagster_delta.io_manager.base import ( + BaseDeltaLakeIOManager as BaseDeltaLakeIOManager, +) + + +class DeltaLakePyarrowIOManager(BaseDeltaLakeIOManager): # noqa: D101 + @staticmethod + def type_handlers() -> Sequence[DbTypeHandler]: # noqa: D102 + return [_DeltaLakePyArrowTypeHandler()] + + +class _DeltaLakePyArrowTypeHandler(DeltalakeBaseArrowTypeHandler[ArrowTypes]): 
# noqa: D101 + def from_arrow(self, obj: pa.RecordBatchReader, target_type: type[ArrowTypes]) -> ArrowTypes: # noqa: D102 + if target_type == pa.Table: + return obj.read_all() + return obj + + def to_arrow(self, obj: ArrowTypes) -> tuple[ArrowTypes, dict[str, Any]]: # noqa: D102 + if isinstance(obj, ds.Dataset): + return obj.scanner().to_reader(), {} + return obj, {} + + def get_output_stats(self, obj: ArrowTypes) -> dict[str, dg.MetadataValue]: # noqa: ARG002 + """Returns output stats to be attached to the the context. + + Args: + obj (ArrowTypes): Union[pa.Table, pa.RecordBatchReader, ds.Dataset] + + Returns: + Mapping[str, MetadataValue]: metadata stats + """ + return {} + + @property + def supported_types(self) -> Sequence[type[object]]: + """Returns the supported dtypes for this typeHandler""" + return [pa.Table, pa.RecordBatchReader, ds.Dataset] diff --git a/libraries/dagster-delta/dagster_delta/io_manager.py b/libraries/dagster-delta/dagster_delta/io_manager/base.py similarity index 81% rename from libraries/dagster-delta/dagster_delta/io_manager.py rename to libraries/dagster-delta/dagster_delta/io_manager/base.py index ff754e9..91793ea 100644 --- a/libraries/dagster-delta/dagster_delta/io_manager.py +++ b/libraries/dagster-delta/dagster_delta/io_manager/base.py @@ -11,21 +11,27 @@ from dagster._core.definitions.time_window_partitions import TimeWindow from dagster._core.storage.db_io_manager import ( DbClient, - DbIOManager, DbTypeHandler, TablePartitionDimension, TableSlice, ) from pydantic import Field -from .dbiomanager_fixed import DbIOManagerFixed +from dagster_delta._db_io_manager import CustomDbIOManager if sys.version_info >= (3, 11): from typing import NotRequired else: from typing_extensions import NotRequired -from .config import AzureConfig, ClientConfig, GcsConfig, LocalConfig, MergeConfig, S3Config +from dagster_delta.config import ( + AzureConfig, + ClientConfig, + GcsConfig, + LocalConfig, + MergeConfig, + S3Config, +) DELTA_DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S" DELTA_DATE_FORMAT = "%Y-%m-%d" @@ -60,7 +66,7 @@ class SchemaMode(str, Enum): """Deltalake schema mode""" overwrite = "overwrite" - append = "append" + merge = "merge" class WriterEngine(str, Enum): @@ -85,56 +91,56 @@ class _DeltaTableIOManagerResourceConfig(TypedDict): parquet_read_options: NotRequired[dict[str, Any]] -class DeltaLakeIOManager(ConfigurableIOManagerFactory): +class BaseDeltaLakeIOManager(ConfigurableIOManagerFactory): """Base class for an IO manager definition that reads inputs from and writes outputs to Delta Lake. Examples: - .. code-block:: python - - from dagster_delta import DeltaLakeIOManager - from dagster_delta_polars import DeltaLakePolarsTypeHandler + ```python + from dagster_delta import BaseDeltaLakeIOManager + from dagster_delta.io_manager.arrow import _DeltaLakePyArrowTypeHandler - class MyDeltaLakeIOManager(DeltaLakeIOManager): - @staticmethod - def type_handlers() -> Sequence[DbTypeHandler]: - return [DeltaLakePolarsTypeHandler()] + class MyDeltaLakeIOManager(BaseDeltaLakeIOManager): + @staticmethod + def type_handlers() -> Sequence[DbTypeHandler]: + return [_DeltaLakePyArrowTypeHandler()] - @asset( - key_prefix=["my_schema"] # will be used as the schema (parent folder) in Delta Lake - ) - def my_table() -> pl.DataFrame: # the name of the asset will be the table name - ... + @asset( + key_prefix=["my_schema"] # will be used as the schema (parent folder) in Delta Lake + ) + def my_table() -> pa.Table: # the name of the asset will be the table name + ... 
- defs = Definitions( - assets=[my_table], - resources={"io_manager": MyDeltaLakeIOManager()} - ) + defs = Definitions( + assets=[my_table], + resources={"io_manager": MyDeltaLakeIOManager()} + ) + ``` If you do not provide a schema, Dagster will determine a schema based on the assets and ops using the I/O Manager. For assets, the schema will be determined from the asset key, as in the above example. For ops, the schema can be specified by including a "schema" entry in output metadata. If none of these is provided, the schema will default to "public". - .. code-block:: python + ```python - @op( - out={"my_table": Out(metadata={"schema": "my_schema"})} - ) - def make_my_table() -> pd.DataFrame: - ... + @op( + out={"my_table": Out(metadata={"schema": "my_schema"})} + ) + def make_my_table() -> pa.Table: + ... + ``` To only use specific columns of a table as input to a downstream op or asset, add the metadata "columns" to the In or AssetIn. - .. code-block:: python - + ```python @asset( ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})} ) - def my_table_a(my_table: pd.DataFrame): + def my_table_a(my_table: pa.Table: # my_table will just contain the data from column "a" ... - + ``` """ root_uri: str = Field(description="Storage location where Delta tables are stored.") @@ -203,9 +209,9 @@ def type_handlers() -> Sequence[DbTypeHandler]: # noqa: D102 def default_load_type() -> Optional[type]: # noqa: D102 return None - def create_io_manager(self, context) -> DbIOManager: # noqa: D102, ANN001, ARG002 + def create_io_manager(self, context) -> CustomDbIOManager: # noqa: D102, ANN001, ARG002 self.storage_options.model_dump() - return DbIOManagerFixed( + return CustomDbIOManager( db_client=DeltaLakeDbClient(), database="deltalake", schema=self.schema_, @@ -258,24 +264,25 @@ def connect( # noqa: D102 root_uri = resource_config["root_uri"].rstrip("/") storage_options = resource_config["storage_options"] + # Values of a config are unpacked into a dict[str,Any], we convert it back to the Config + # so that we can do str_dict() if "local" in storage_options: - storage_options = storage_options["local"] + storage_options = LocalConfig(**storage_options["local"]) # type: ignore elif "s3" in storage_options: - storage_options = storage_options["s3"] + storage_options = S3Config(**storage_options["s3"]) # type: ignore elif "azure" in storage_options: - storage_options = storage_options["azure"] + storage_options = AzureConfig(**storage_options["azure"]) # type: ignore elif "gcs" in storage_options: - storage_options = storage_options["gcs"] + storage_options = GcsConfig(storage_options["gcs"]) # type: ignore else: - storage_options = {} + raise NotImplementedError("No valid storage_options config found.") client_options = resource_config.get("client_options") - client_options = client_options or {} + client_options = ( + ClientConfig(**client_options).str_dict() if client_options is not None else {} # type: ignore + ) - storage_options = { - **{k: str(v) for k, v in storage_options.items() if v is not None}, - **{k: str(v) for k, v in client_options.items() if v is not None}, - } + options = {**storage_options.str_dict(), **client_options} table_config = resource_config.get("table_config") # Ignore schema if None or empty string, useful to set schema = "" which overrides the assetkey @@ -285,7 +292,7 @@ def connect( # noqa: D102 table_uri = f"{root_uri}/{table_slice.table}" conn = TableConnection( table_uri=table_uri, - storage_options=storage_options or {}, + storage_options=options or {}, 
table_config=table_config, ) diff --git a/libraries/dagster-delta/dagster_delta/io_manager/polars.py b/libraries/dagster-delta/dagster_delta/io_manager/polars.py new file mode 100644 index 0000000..25d4013 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/io_manager/polars.py @@ -0,0 +1,210 @@ +import logging +from collections.abc import Sequence +from typing import Any, Optional, Union + +try: + import polars as pl +except ImportError as e: + raise ImportError( + "Please install dagster-delta[polars]", + ) from e +import pyarrow as pa +import pyarrow.dataset as ds +from dagster import InputContext, MetadataValue, OutputContext +from dagster._core.storage.db_io_manager import ( + DbTypeHandler, + TableSlice, +) + +from dagster_delta._handler.base import ( + DeltalakeBaseArrowTypeHandler, +) +from dagster_delta._handler.utils import extract_date_format_from_partition_definition, read_table +from dagster_delta.io_manager.arrow import _DeltaLakePyArrowTypeHandler +from dagster_delta.io_manager.base import BaseDeltaLakeIOManager, TableConnection + +PolarsTypes = Union[pl.DataFrame, pl.LazyFrame] + + +class _DeltaLakePolarsTypeHandler(DeltalakeBaseArrowTypeHandler[PolarsTypes]): # noqa: D101 + def from_arrow( # noqa: D102 + self, + obj: Union[ds.Dataset, pa.RecordBatchReader], + target_type: type[PolarsTypes], + ) -> PolarsTypes: + if isinstance(obj, pa.RecordBatchReader): + return pl.DataFrame(obj.read_all()) + elif isinstance(obj, ds.Dataset): + df = pl.scan_pyarrow_dataset(obj) + if target_type == pl.DataFrame: + return df.collect() + else: + return df + else: + raise NotImplementedError("Unsupported objected passed of type: %s", type(obj)) + + def to_arrow(self, obj: PolarsTypes) -> tuple[pa.Table, dict[str, Any]]: # noqa: D102 + if isinstance(obj, pl.LazyFrame): + obj = obj.collect() + + logger = logging.getLogger() + logger.setLevel("DEBUG") + logger.debug("shape of dataframe: %s", obj.shape) + # TODO(ion): maybe move stats collection here + + return obj.to_arrow(), {"large_dtypes": True} + + def load_input( + self, + context: InputContext, + table_slice: TableSlice, + connection: TableConnection, + ) -> PolarsTypes: + """Loads the input as a Polars DataFrame or LazyFrame.""" + definition_metadata = ( + context.definition_metadata if context.definition_metadata is not None else {} + ) + date_format = extract_date_format_from_partition_definition(context) + + parquet_read_options = None + if context.resource_config is not None: + parquet_read_options = context.resource_config.get("parquet_read_options", None) + parquet_read_options = ( + ds.ParquetReadOptions(**parquet_read_options) + if parquet_read_options is not None + else None + ) + + dataset = read_table( + table_slice, + connection, + version=definition_metadata.get("table_version"), + date_format=date_format, + parquet_read_options=parquet_read_options, + ) + + if table_slice.columns is not None: + if context.dagster_type.typing_type == pl.LazyFrame: + return self.from_arrow(dataset, context.dagster_type.typing_type).select( + table_slice.columns, + ) + else: + scanner = dataset.scanner(columns=table_slice.columns) + return self.from_arrow(scanner.to_reader(), context.dagster_type.typing_type) + else: + return self.from_arrow(dataset, context.dagster_type.typing_type) + + def handle_output( + self, + context: OutputContext, + table_slice: TableSlice, + obj: Union[pl.DataFrame, pl.LazyFrame], + connection: TableConnection, + ): + """Writes polars frame as delta table""" + super().handle_output(context, table_slice, obj, 
connection) + metadata = {**context.consume_logged_metadata()} + + if connection.table_uri.startswith("lakefs://"): + # We grab the lakefs endpoint from the object storage options + for key, value in connection.storage_options.items(): + if key.lower() in ["endpoint", "aws_endpoint", "aws_endpoint_url", "endpoint_url"]: + metadata["lakefs_link"] = MetadataValue.url( + _convert_uri_to_lakefs_link(connection.table_uri, value), + ) + break + context.add_output_metadata(metadata) + + def get_output_stats(self, obj: PolarsTypes) -> dict[str, MetadataValue]: + """Returns output stats to be attached to the the context. + + Args: + obj (PolarsTypes): LazyFrame or DataFrame + + Returns: + Mapping[str, MetadataValue]: metadata stats + """ + stats = {} + # TODO(ion): think of more meaningful stats to add from a dataframe + if isinstance(obj, pl.DataFrame): + stats["num_rows_in_source"] = MetadataValue.int(obj.shape[0]) + + return stats + + @property + def supported_types(self) -> Sequence[type[object]]: + """Returns the supported dtypes for this typeHandler""" + return [pl.DataFrame, pl.LazyFrame] + + +class DeltaLakePolarsIOManager(BaseDeltaLakeIOManager): + """Base class for an IO manager definition that reads inputs from and writes outputs to Delta Lake. + + Examples: + .. code-block:: python + + from dagster_delta import DeltaLakePolarsIOManager + + @asset( + key_prefix=["my_schema"] # will be used as the schema (parent folder) in Delta Lake + ) + def my_table() -> pl.DataFrame: # the name of the asset will be the table name + ... + + defs = Definitions( + assets=[my_table], + resources={"io_manager": DeltaLakePolarsIOManager()} + ) + + If you do not provide a schema, Dagster will determine a schema based on the assets and ops using + the I/O Manager. For assets, the schema will be determined from the asset key, as in the above example. + For ops, the schema can be specified by including a "schema" entry in output metadata. If none + of these is provided, the schema will default to "public". + + .. code-block:: python + + @op( + out={"my_table": Out(metadata={"schema": "my_schema"})} + ) + def make_my_table() -> pl.DataFrame: + ... + + To only use specific columns of a table as input to a downstream op or asset, add the metadata "columns" to the + In or AssetIn. + + .. code-block:: python + + @asset( + ins={"my_table": AssetIn("my_table", metadata={"columns": ["a"]})} + ) + def my_table_a(my_table: pl.DataFrame): + # my_table will just contain the data from column "a" + ... 
+ + """ + + @staticmethod + def type_handlers() -> Sequence[DbTypeHandler]: + """Returns all available type handlers on this IO Manager.""" + return [_DeltaLakePolarsTypeHandler(), _DeltaLakePyArrowTypeHandler()] + + @staticmethod + def default_load_type() -> Optional[type]: + """Grabs the default load type if no type hint is passed.""" + return pl.DataFrame + + +def _convert_uri_to_lakefs_link(uri: str, lakefs_base_url: str) -> str: + """Convert an S3 uri to a link to lakefs""" + from urllib.parse import quote + + uri = uri[len("lakefs://") :] + parts = uri.split("/", 2) + if len(parts) < 3: + return "https://error-invalid-s3-uri-format" + repository = parts[0] + ref = parts[1] + path = parts[2] + encoded_path = quote(path + "/") + https_url = f"{lakefs_base_url.rstrip('/')}/repositories/{repository}/objects?ref={ref}&path={encoded_path}" + return https_url diff --git a/libraries/dagster-delta/dagster_delta/resource.py b/libraries/dagster-delta/dagster_delta/resource.py deleted file mode 100644 index 632a6d2..0000000 --- a/libraries/dagster-delta/dagster_delta/resource.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional, Union - -from dagster import ConfigurableResource -from deltalake import DeltaTable -from pydantic import Field - -from .config import AzureConfig, ClientConfig, GcsConfig, LocalConfig, S3Config - - -class DeltaTableResource(ConfigurableResource): - """Resource for interacting with a Delta table. - - Examples: - .. code-block:: python - - from dagster import Definitions, asset - from dagster_delta import DeltaTableResource, LocalConfig - - @asset - def my_table(delta_table: DeltaTableResource): - df = delta_table.load().to_pandas() - - defs = Definitions( - assets=[my_table], - resources={ - "delta_table": DeltaTableResource( - url="/path/to/table", - storage_options=LocalConfig() - ) - } - ) - - """ - - url: str - - storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field( # noqa: UP007 - discriminator="provider", - ) - - client_options: Optional[ClientConfig] = Field( - default=None, - description="Additional configuration passed to http client.", - ) - - version: Optional[int] = Field(default=None, description="Version to load delta table.") - - def load(self) -> DeltaTable: - """Loads the table with passed configuration. 
- - Returns: - DeltaTable: table - """ - storage_options = self.storage_options - if "local" in storage_options: - storage_options = storage_options["local"] - elif "s3" in storage_options: - storage_options = storage_options["s3"] - elif "azure" in storage_options: - storage_options = storage_options["azure"] - elif "gcs" in storage_options: - storage_options = storage_options["gcs"] - else: - storage_options = {} - - client_options = self.client_options.dict() if self.client_options is not None else {} - - storage_options = { - **{k: str(v) for k, v in storage_options.items() if v is not None}, - **{k: str(v) for k, v in client_options.items() if v is not None}, - } - table = DeltaTable( - table_uri=self.url, - storage_options=storage_options, - version=self.version, - ) - return table diff --git a/libraries/dagster-delta/dagster_delta/resources.py b/libraries/dagster-delta/dagster_delta/resources.py new file mode 100644 index 0000000..024bbd0 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta/resources.py @@ -0,0 +1,62 @@ +from typing import Optional, Union + +from dagster import ConfigurableResource +from deltalake import DeltaTable +from pydantic import Field + +from dagster_delta.config import AzureConfig, ClientConfig, GcsConfig, LocalConfig, S3Config + + +class DeltaTableResource(ConfigurableResource): + """Resource for interacting with a Delta table. + + Examples: + ```python + from dagster import Definitions, asset + from dagster_delta import DeltaTableResource, LocalConfig + + @asset + def my_table(delta_table: DeltaTableResource): + df = delta_table.load().to_pandas() + + defs = Definitions( + assets=[my_table], + resources={ + "delta_table": DeltaTableResource( + url="/path/to/table", + storage_options=LocalConfig() + ) + } + ) + ``` + """ + + url: str + + storage_options: Union[AzureConfig, S3Config, LocalConfig, GcsConfig] = Field( # noqa: UP007 + discriminator="provider", + ) + + client_options: Optional[ClientConfig] = Field( + default=None, + description="Additional configuration passed to http client.", + ) + + version: Optional[int] = Field(default=None, description="Version to load delta table.") + + def load(self) -> DeltaTable: + """Loads the table with passed configuration. 
+ + Returns: + DeltaTable: table + """ + storage_options = self.storage_options.str_dict() + client_options = self.client_options.str_dict() if self.client_options else {} + options = {**storage_options, **client_options} + + table = DeltaTable( + table_uri=self.url, + storage_options=options, + version=self.version, + ) + return table diff --git a/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/conftest.py b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/conftest.py new file mode 100644 index 0000000..8054737 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/conftest.py @@ -0,0 +1,60 @@ +import pytest +from dagster import ( + DailyPartitionsDefinition, + MultiPartitionsDefinition, + StaticPartitionsDefinition, +) + +from dagster_delta import DeltaLakePyarrowIOManager +from dagster_delta.config import LocalConfig + + +@pytest.fixture +def io_manager( + tmp_path, +) -> DeltaLakePyarrowIOManager: + return DeltaLakePyarrowIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + ) + + +@pytest.fixture +def daily_partitions_definition() -> DailyPartitionsDefinition: + return DailyPartitionsDefinition(start_date="2022-01-01", end_date="2022-01-10") + + +@pytest.fixture +def letter_partitions_definition() -> StaticPartitionsDefinition: + return StaticPartitionsDefinition(["a", "b", "c"]) + + +@pytest.fixture +def color_partitions_definition() -> StaticPartitionsDefinition: + return StaticPartitionsDefinition(["red", "blue", "yellow"]) + + +@pytest.fixture +def multi_partition_with_letter( + daily_partitions_definition: DailyPartitionsDefinition, + letter_partitions_definition: StaticPartitionsDefinition, +) -> MultiPartitionsDefinition: + return MultiPartitionsDefinition( + partitions_defs={ + "date": daily_partitions_definition, + "letter": letter_partitions_definition, + }, + ) + + +@pytest.fixture +def multi_partition_with_color( + daily_partitions_definition: DailyPartitionsDefinition, + color_partitions_definition: StaticPartitionsDefinition, +) -> MultiPartitionsDefinition: + return MultiPartitionsDefinition( + partitions_defs={ + "date": daily_partitions_definition, + "color": color_partitions_definition, + }, + ) diff --git a/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager.py b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager.py new file mode 100644 index 0000000..1c703dd --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager.py @@ -0,0 +1,400 @@ +import datetime as dt +import os +import warnings +from datetime import datetime + +import dagster as dg +import pyarrow as pa +from dagster import ( + AssetExecutionContext, + AssetIn, + DailyPartitionsDefinition, + DimensionPartitionMapping, + MultiPartitionMapping, + MultiPartitionsDefinition, + MultiToSingleDimensionPartitionMapping, + SpecificPartitionsPartitionMapping, + StaticPartitionMapping, + StaticPartitionsDefinition, + asset, + materialize, +) +from deltalake import DeltaTable + +from dagster_delta import DeltaLakePyarrowIOManager + +warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + + +daily_partitions_def = DailyPartitionsDefinition( + start_date="2022-01-01", + end_date="2022-01-10", +) + +letter_partitions_def = StaticPartitionsDefinition(["a", "b", "c"]) + +color_partitions_def = StaticPartitionsDefinition(["red", "blue", "yellow"]) + +multi_partition_with_letter = MultiPartitionsDefinition( + partitions_defs={ + "date": daily_partitions_def, + 
"letter": letter_partitions_def, + }, +) + +multi_partition_with_color = MultiPartitionsDefinition( + partitions_defs={ + "date": daily_partitions_def, + "color": color_partitions_def, + }, +) + + +@asset( + key_prefix=["my_schema"], +) +def asset_1() -> pa.Table: + return pa.Table.from_pydict( + { + "value": [1], + "b": [1], + }, + ) + + +@asset( + key_prefix=["my_schema"], +) +def asset_2(asset_1: pa.Table) -> pa.Table: + return asset_1 + + +# case: we have multiple partitions +@asset( + key_prefix=["my_schema"], + partitions_def=multi_partition_with_color, + metadata={ + "partition_expr": { + "date": "date_column", + "color": "color_column", + }, + }, +) +def multi_partitioned_asset_1(context: AssetExecutionContext) -> pa.Table: + color, date = context.partition_key.split("|") + date_parsed = dt.datetime.strptime(date, "%Y-%m-%d").date() + + return pa.Table.from_pydict( + { + "date_column": [date_parsed], + "value": [1], + "b": [1], + "color_column": [color], + }, + ) + + +# Multi-to-multi asset is supported +@asset( + key_prefix=["my_schema"], + partitions_def=multi_partition_with_color, + metadata={ + "partition_expr": { + "date": "date_column", + "color": "color_column", + }, + }, +) +def multi_partitioned_asset_2(multi_partitioned_asset_1: pa.Table) -> pa.Table: + return multi_partitioned_asset_1 + + +@asset( + key_prefix=["my_schema"], +) +def non_partitioned_asset(multi_partitioned_asset_1: pa.Table) -> pa.Table: + return multi_partitioned_asset_1 + + +# Multi-to-single asset is supported through MultiToSingleDimensionPartitionMapping +@asset( + key_prefix=["my_schema"], + partitions_def=daily_partitions_def, + ins={ + "multi_partitioned_asset": AssetIn( + ["my_schema", "multi_partitioned_asset_1"], + partition_mapping=MultiToSingleDimensionPartitionMapping( + partition_dimension_name="date", + ), + ), + }, + metadata={ + "partition_expr": "date_column", + }, +) +def single_partitioned_asset_date(multi_partitioned_asset: pa.Table) -> pa.Table: + return multi_partitioned_asset + + +@asset( + key_prefix=["my_schema"], + partitions_def=color_partitions_def, + ins={ + "multi_partitioned_asset": AssetIn( + ["my_schema", "multi_partitioned_asset_1"], + partition_mapping=MultiToSingleDimensionPartitionMapping( + partition_dimension_name="color", + ), + ), + }, + metadata={ + "partition_expr": "color_column", + }, +) +def single_partitioned_asset_color(multi_partitioned_asset: pa.Table) -> pa.Table: + return multi_partitioned_asset + + +@asset( + partitions_def=multi_partition_with_letter, + key_prefix=["my_schema"], + metadata={"partition_expr": {"date": "date_column", "letter": "letter"}}, + ins={ + "multi_partitioned_asset": AssetIn( + ["my_schema", "multi_partitioned_asset_1"], + partition_mapping=MultiPartitionMapping( + { + "color": DimensionPartitionMapping( + dimension_name="letter", + partition_mapping=StaticPartitionMapping( + {"blue": "a", "red": "b", "yellow": "c"}, + ), + ), + "date": DimensionPartitionMapping( + dimension_name="date", + partition_mapping=SpecificPartitionsPartitionMapping( + ["2022-01-01", "2024-01-01"], + ), + ), + }, + ), + ), + }, +) +def mapped_multi_partition( + context: AssetExecutionContext, + multi_partitioned_asset: pa.Table, +) -> pa.Table: + _, letter = context.partition_key.split("|") + + table_ = multi_partitioned_asset.append_column("letter", pa.array([letter])) + return table_ + + +def test_unpartitioned_asset_to_unpartitioned_asset( + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", 
category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + res = materialize([asset_1, asset_2], resources=resource_defs) + assert res.success + + asset_1_data = res.asset_value(asset_1.key) + asset_2_data = res.asset_value(asset_2.key) + assert asset_1_data == asset_2_data + + +def test_multi_partitioned_to_multi_partitioned_asset( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + multi_partitioned_asset_1_data_all = [] + multi_partitioned_asset_2_data_all = [] + + for partition_key in ["red|2022-01-01", "red|2022-01-02", "red|2022-01-03"]: + res = materialize( + [multi_partitioned_asset_1, multi_partitioned_asset_2], + partition_key=partition_key, + resources=resource_defs, + ) + assert res.success + + multi_partitioned_asset_1_data = res.asset_value(multi_partitioned_asset_1.key) + multi_partitioned_asset_2_data = res.asset_value(multi_partitioned_asset_2.key) + assert multi_partitioned_asset_1_data == multi_partitioned_asset_2_data + + multi_partitioned_asset_1_data_all.append(multi_partitioned_asset_1_data) + multi_partitioned_asset_2_data_all.append(multi_partitioned_asset_2_data) + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(multi_partitioned_asset_1.key.path))) + + assert dt.to_pyarrow_table().sort_by("date_column") == pa.concat_tables( + multi_partitioned_asset_1_data_all, + ) + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(multi_partitioned_asset_2.key.path))) + assert dt.metadata().partition_columns == ["color_column", "date_column"] + assert dt.to_pyarrow_table().sort_by("date_column") == pa.concat_tables( + multi_partitioned_asset_2_data_all, + ) + + +def test_multi_partitioned_to_single_partitioned_asset_colors( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + multi_partitioned_asset_1_data_all = [] + + for partition_key in ["red|2022-01-01", "blue|2022-01-01", "yellow|2022-01-01"]: + res = materialize( + [multi_partitioned_asset_1], + partition_key=partition_key, + resources=resource_defs, + ) + assert res.success + + multi_partitioned_asset_1_data = res.asset_value(multi_partitioned_asset_1.key) + multi_partitioned_asset_1_data_all.append(multi_partitioned_asset_1_data) + + res = materialize( + [multi_partitioned_asset_1, single_partitioned_asset_date], + partition_key="2022-01-01", + resources=resource_defs, + selection=[single_partitioned_asset_date], + ) + assert res.success + + single_partitioned_asset_date_data = res.asset_value(single_partitioned_asset_date.key) + + assert single_partitioned_asset_date_data.sort_by("color_column") == pa.concat_tables( + multi_partitioned_asset_1_data_all, + ).sort_by("color_column") + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(single_partitioned_asset_date.key.path))) + assert dt.metadata().partition_columns == ["date_column"] + assert single_partitioned_asset_date_data == dt.to_pyarrow_table() + + +def test_multi_partitioned_to_single_partitioned_asset_dates( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + multi_partitioned_asset_1_data_all = [] + + for partition_key in ["red|2022-01-01", "red|2022-01-02", "red|2022-01-03"]: + res = materialize( + [multi_partitioned_asset_1], + partition_key=partition_key, + 
resources=resource_defs, + ) + assert res.success + + multi_partitioned_asset_1_data = res.asset_value(multi_partitioned_asset_1.key) + multi_partitioned_asset_1_data_all.append(multi_partitioned_asset_1_data) + + res = materialize( + [multi_partitioned_asset_1, single_partitioned_asset_color], + partition_key="red", + resources=resource_defs, + selection=[single_partitioned_asset_color], + ) + assert res.success + + single_partitioned_asset_color_data = res.asset_value(single_partitioned_asset_color.key) + + assert single_partitioned_asset_color_data.sort_by("date_column") == pa.concat_tables( + multi_partitioned_asset_1_data_all, + ).sort_by("date_column") + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(single_partitioned_asset_color.key.path))) + assert dt.metadata().partition_columns == ["color_column"] + + assert single_partitioned_asset_color_data == dt.to_pyarrow_table() + + +def test_multi_partitioned_to_non_partitioned_asset( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + multi_partitioned_asset_1_data_all = [] + + for partition_key in ["red|2022-01-01", "red|2022-01-02", "red|2022-01-03"]: + res = materialize( + [multi_partitioned_asset_1], + partition_key=partition_key, + resources=resource_defs, + ) + assert res.success + + multi_partitioned_asset_1_data = res.asset_value(multi_partitioned_asset_1.key) + multi_partitioned_asset_1_data_all.append(multi_partitioned_asset_1_data) + + res = materialize( + [multi_partitioned_asset_1, non_partitioned_asset], + resources=resource_defs, + selection=[non_partitioned_asset], + ) + assert res.success + + non_partitioned_asset_data = res.asset_value(non_partitioned_asset.key) + + assert non_partitioned_asset_data.sort_by("date_column") == pa.concat_tables( + multi_partitioned_asset_1_data_all, + ).sort_by("date_column") + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(non_partitioned_asset.key.path))) + assert dt.metadata().partition_columns == [] + + assert non_partitioned_asset_data == dt.to_pyarrow_table() + + +def test_multi_partitioned_to_multi_partitioned_with_different_dimensions( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + for partition_key in ["red|2022-01-01", "blue|2022-01-01", "yellow|2022-01-01"]: + res = materialize( + [multi_partitioned_asset_1], + partition_key=partition_key, + resources=resource_defs, + ) + assert res.success + + res = materialize( + [multi_partitioned_asset_1, mapped_multi_partition], + partition_key="2022-01-01|a", + resources=resource_defs, + selection=[mapped_multi_partition], + ) + + date_parsed = datetime.strptime("2022-01-01", "%Y-%m-%d").date() + + expected = pa.Table.from_pydict( + { + "date_column": [date_parsed], + "value": [1], + "b": [1], + "color_column": ["blue"], + "letter": ["a"], + }, + ) + + dt = DeltaTable(os.path.join(str(tmp_path), "/".join(mapped_multi_partition.key.path))) + assert dt.metadata().partition_columns == ["date_column", "letter"] + assert expected == dt.to_pyarrow_table() diff --git a/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager_utils.py b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager_utils.py new file mode 100644 index 0000000..5aee54f --- /dev/null +++ 
b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_db_io_manager_utils.py @@ -0,0 +1,173 @@ +import datetime as dt + +import pytest +from dagster import AssetKey, MultiPartitionKey, MultiPartitionsDefinition, TimeWindow +from dagster._core.storage.db_io_manager import TablePartitionDimension + +from dagster_delta._db_io_manager import utils + + +# NB: dagster uses dt.datetime even for dates +@pytest.fixture +def daily_partitions_time_window_consecutive() -> list[TimeWindow]: + return [ + TimeWindow( + start=dt.datetime(2022, 1, 1, 0), + end=dt.datetime(2022, 1, 2, 0), + ), + TimeWindow( + start=dt.datetime(2022, 1, 2, 0), + end=dt.datetime(2022, 1, 3, 0), + ), + TimeWindow( + start=dt.datetime(2022, 1, 3, 0), + end=dt.datetime(2022, 1, 4, 0), + ), + ] + + +@pytest.fixture +def daily_partitions_time_window_not_consecutive() -> list[TimeWindow]: + return [ + TimeWindow( + start=dt.datetime(2022, 1, 1, 0), + end=dt.datetime(2022, 1, 2, 0), + ), + TimeWindow( + start=dt.datetime(2022, 1, 2, 0), + end=dt.datetime(2022, 1, 3, 0), + ), + TimeWindow( + start=dt.datetime(2022, 1, 4, 0), + end=dt.datetime(2022, 1, 5, 0), + ), + ] + + +def test_multi_time_partitions_checker_consecutive( + daily_partitions_time_window_consecutive: list[TimeWindow], +): + checker = utils.MultiTimePartitionsChecker( + partitions=daily_partitions_time_window_consecutive, + ) + assert checker.start == dt.datetime(2022, 1, 1, 0) + assert checker.end == dt.datetime(2022, 1, 4, 0) + assert checker.hourly_delta == 24 + assert checker.is_consecutive() + + +def test_multi_time_partitions_checker_non_consecutive( + daily_partitions_time_window_not_consecutive: list[TimeWindow], +): + checker = utils.MultiTimePartitionsChecker( + partitions=daily_partitions_time_window_not_consecutive, + ) + assert checker.hourly_delta == 24 + assert checker.start == dt.datetime(2022, 1, 1, 0) + assert checker.end == dt.datetime(2022, 1, 5, 0) + assert not checker.is_consecutive() + + +def test_generate_single_partition_dimension_static(): + partition_dimension = utils.generate_single_partition_dimension( + partition_expr="color_column", + asset_partition_keys=["red"], + asset_partitions_time_window=None, + ) + assert isinstance(partition_dimension, TablePartitionDimension) + assert partition_dimension.partition_expr == "color_column" + assert partition_dimension.partitions == ["red"] + + +def test_generate_single_partition_dimension_time_window(): + partition_dimension = utils.generate_single_partition_dimension( + partition_expr="date_column", + asset_partition_keys=["2022-01-01"], + asset_partitions_time_window=TimeWindow( + start=dt.datetime(2022, 1, 1, 0), + end=dt.datetime(2022, 1, 2, 0), + ), + ) + assert isinstance(partition_dimension, TablePartitionDimension) + assert isinstance(partition_dimension.partitions, TimeWindow) + assert partition_dimension.partition_expr == "date_column" + assert partition_dimension.partitions.start == dt.datetime(2022, 1, 1, 0) + assert partition_dimension.partitions.end == dt.datetime(2022, 1, 2, 0) + + +def test_generate_partition_dimensions_color_varying( + multi_partition_with_color: MultiPartitionsDefinition, +): + partition_dimensions = utils.generate_multi_partitions_dimension( + asset_key=AssetKey("my_asset"), + # NB: these must be multi partition keys + asset_partition_keys=[ + MultiPartitionKey(keys_by_dimension={"color": "red", "date": "2022-01-01"}), + MultiPartitionKey( + keys_by_dimension={"color": "blue", "date": "2022-01-01"}, + ), + MultiPartitionKey( + 
keys_by_dimension={"color": "yellow", "date": "2022-01-01"}, + ), + ], + asset_partitions_def=multi_partition_with_color, + partition_expr={ + "date": "date_column", + "color": "color_column", + }, + ) + assert len(partition_dimensions) == 2 + assert partition_dimensions[0].partition_expr == "color_column" + assert partition_dimensions[1].partition_expr == "date_column" + assert partition_dimensions[1].partitions.start == dt.datetime( # type: ignore + 2022, + 1, + 1, + 0, + tzinfo=dt.timezone.utc, + ) + assert partition_dimensions[1].partitions.end == dt.datetime( # type: ignore + 2022, + 1, + 2, + 0, + tzinfo=dt.timezone.utc, + ) + assert sorted(partition_dimensions[0].partitions) == ["blue", "red", "yellow"] + + +def test_generate_partition_dimensions_date_varying( + multi_partition_with_color: MultiPartitionsDefinition, +): + partition_dimensions = utils.generate_multi_partitions_dimension( + asset_key=AssetKey("my_asset"), + # NB: these must be multi partition keys + asset_partition_keys=[ + MultiPartitionKey(keys_by_dimension={"color": "red", "date": "2022-01-01"}), + MultiPartitionKey(keys_by_dimension={"color": "red", "date": "2022-01-02"}), + MultiPartitionKey(keys_by_dimension={"color": "red", "date": "2022-01-03"}), + ], + asset_partitions_def=multi_partition_with_color, + partition_expr={ + "date": "date_column", + "color": "color_column", + }, + ) + assert len(partition_dimensions) == 2 + assert partition_dimensions[0].partition_expr == "color_column" + assert partition_dimensions[1].partition_expr == "date_column" + assert partition_dimensions[1].partitions.start == dt.datetime( # type: ignore + 2022, + 1, + 1, + 0, + tzinfo=dt.timezone.utc, + ) + assert partition_dimensions[1].partitions.end == dt.datetime( # type: ignore + 2022, + 1, + 4, + 0, + tzinfo=dt.timezone.utc, + ) + assert partition_dimensions[0].partitions == ["red"] diff --git a/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_root_name.py b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_root_name.py new file mode 100644 index 0000000..43322b3 --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/_db_io_manager/test_root_name.py @@ -0,0 +1,42 @@ +import os +import warnings + +import dagster as dg +import pyarrow as pa +from dagster import ( + asset, + materialize, +) +from deltalake import DeltaTable + +from dagster_delta import DeltaLakePyarrowIOManager + + +@asset( + key_prefix=["my_schema"], + metadata={"root_name": "custom_asset"}, +) +def asset_1() -> pa.Table: + return pa.Table.from_pydict( + { + "value": [1], + "b": [1], + }, + ) + + +def test_asset_with_root_name( + tmp_path, + io_manager: DeltaLakePyarrowIOManager, +): + warnings.filterwarnings("ignore", category=dg.ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + res = materialize([asset_1], resources=resource_defs) + assert res.success + + data = res.asset_value(asset_1.key) + + dt = DeltaTable(os.path.join(str(tmp_path), "my_schema", "custom_asset")) + + assert data == dt.to_pyarrow_table() diff --git a/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler.py b/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler.py new file mode 100644 index 0000000..7c654ed --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler.py @@ -0,0 +1,552 @@ +import os +import warnings +from datetime import datetime + +import polars as pl +import pytest +from dagster import ( + AssetExecutionContext, + AssetIn, + DailyPartitionsDefinition, + 
DynamicPartitionsDefinition, + ExperimentalWarning, + MultiPartitionKey, + MultiPartitionsDefinition, + Out, + StaticPartitionsDefinition, + asset, + graph, + instance_for_test, + materialize, + op, +) +from dagster._check import CheckError +from deltalake import DeltaTable + +from dagster_delta import DeltaLakePolarsIOManager, LocalConfig, WriteMode +from dagster_delta.io_manager.base import DELTA_DATE_FORMAT + +warnings.filterwarnings("ignore", category=ExperimentalWarning) + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + mode=WriteMode.overwrite, + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@op(out=Out(metadata={"schema": "add_one"})) +def add_one(df: pl.DataFrame) -> pl.DataFrame: + return df + 1 + + +@graph +def add_one_to_dataframe(): + add_one(a_df()) + + +def test_deltalake_io_manager_with_ops(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = add_one_to_dataframe.to_job(resource_defs=resource_defs) + + # run the job twice to ensure that tables get properly deleted + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "add_one/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] + + +@asset(key_prefix=["my_schema"]) +def b_df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@asset(key_prefix=["my_schema"]) +def b_plus_one(b_df: pl.DataFrame) -> pl.DataFrame: + return b_df + 1 + + +@asset(key_prefix=["my_schema"]) +def b_df_lazy() -> pl.LazyFrame: + return pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@asset(key_prefix=["my_schema"]) +def b_plus_one_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame: + return b_df_lazy.select(pl.all() + 1) + + +@pytest.mark.parametrize( + ("asset1", "asset2", "asset1_path", "asset2_path"), + [ + (b_df, b_plus_one, "b_df", "b_plus_one"), + (b_df_lazy, b_plus_one_lazy, "b_df_lazy", "b_plus_one_lazy"), + ], +) +def test_deltalake_io_manager_with_assets( + tmp_path, + io_manager, + asset1, + asset2, + asset1_path, + asset2_path, +): + warnings.filterwarnings("ignore", category=ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + # materialize asset twice to ensure that tables get properly deleted + for _ in range(2): + res = materialize([asset1, asset2], resources=resource_defs) + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path)) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] + + +def test_deltalake_io_manager_with_schema(tmp_path): + @asset + def my_df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + @asset + def my_df_plus_one(my_df: pl.DataFrame) -> pl.DataFrame: + return my_df + 1 + + io_manager = DeltaLakePolarsIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + schema="custom_schema", + ) + + resource_defs = {"io_manager": io_manager} + + # materialize asset twice to ensure that tables get properly deleted + for _ in range(2): + res = materialize([my_df, 
my_df_plus_one], resources=resource_defs) + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "custom_schema/my_df")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "custom_schema/my_df_plus_one")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] + + +@asset(key_prefix=["my_schema"], ins={"b_df": AssetIn("b_df", metadata={"columns": ["a"]})}) +def b_plus_one_columns(b_df: pl.DataFrame) -> pl.DataFrame: + return b_df + 1 + + +@asset( + key_prefix=["my_schema"], + ins={"b_df_lazy": AssetIn("b_df_lazy", metadata={"columns": ["a"]})}, +) +def b_plus_one_columns_lazy(b_df_lazy: pl.LazyFrame) -> pl.LazyFrame: + return b_df_lazy.select(pl.all() + 1) + + +@pytest.mark.parametrize( + ("asset1", "asset2", "asset1_path", "asset2_path"), + [ + (b_df, b_plus_one_columns, "b_df", "b_plus_one_columns"), + (b_df_lazy, b_plus_one_columns_lazy, "b_df_lazy", "b_plus_one_columns_lazy"), + ], +) +def test_loading_columns(tmp_path, io_manager, asset1, asset2, asset1_path, asset2_path): + warnings.filterwarnings("ignore", category=ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + # materialize asset twice to ensure that tables get properly deleted + for _ in range(2): + res = materialize([asset1, asset2], resources=resource_defs) + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset2_path)) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] + + assert out_df.shape[1] == 1 + + +@op +def non_supported_type() -> int: + return 1 + + +@graph +def not_supported(): + non_supported_type() + + +def test_not_supported_type(io_manager): + resource_defs = {"io_manager": io_manager} + + job = not_supported.to_job(resource_defs=resource_defs) + + with pytest.raises( + CheckError, + match="DeltaLakeIOManager does not have a handler for type ''", + ): + job.execute_in_process() + + +@asset( + partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"), + key_prefix=["my_schema"], + metadata={"partition_expr": "time"}, + config_schema={"value": str}, +) +def daily_partitioned(context: AssetExecutionContext) -> pl.DataFrame: + partition = datetime.strptime(context.partition_key, DELTA_DATE_FORMAT).date() + value = context.op_execution_context.op_config["value"] + + return pl.DataFrame( + { + "time": [partition, partition, partition], + "a": [value, value, value], + "b": [4, 5, 6], + }, + ) + + +@asset( + partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"), + key_prefix=["my_schema"], + metadata={"partition_expr": "time"}, + config_schema={"value": str}, +) +def daily_partitioned_lazy(context: AssetExecutionContext) -> pl.LazyFrame: + partition = datetime.strptime( + context.partition_key, + DELTA_DATE_FORMAT, + ).date() + value = context.op_execution_context.op_config["value"] + + return pl.LazyFrame( + { + "time": [partition, partition, partition], + "a": [value, value, value], + "b": [4, 5, 6], + }, + ) + + +@pytest.mark.parametrize( + ("asset1", "asset1_path"), + [ + (daily_partitioned, "daily_partitioned"), + (daily_partitioned_lazy, "daily_partitioned_lazy"), + ], +) +def test_time_window_partitioned_asset(tmp_path, io_manager, asset1, asset1_path): + resource_defs = {"io_manager": io_manager} + + materialize( + [asset1], + partition_key="2022-01-01", + 
resources=resource_defs, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "1"}}}}, + ) + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/" + asset1_path)) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == ["1", "1", "1"] + + materialize( + [asset1], + partition_key="2022-01-02", + resources=resource_defs, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "2"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"] + + materialize( + [asset1], + partition_key="2022-01-01", + resources=resource_defs, + run_config={"ops": {"my_schema__" + asset1_path: {"config": {"value": "3"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["2", "2", "2", "3", "3", "3"] + + +@asset( + partitions_def=DailyPartitionsDefinition(start_date="2022-01-01"), + key_prefix=["my_schema"], + metadata={"partition_expr": "time"}, +) +def load_partitioned(daily_partitioned: pl.DataFrame) -> pl.DataFrame: + return daily_partitioned + + +def test_load_partitioned_asset(io_manager): + warnings.filterwarnings("ignore", category=ExperimentalWarning) + resource_defs = {"io_manager": io_manager} + + res = materialize( + [daily_partitioned, load_partitioned], + partition_key="2022-01-01", + resources=resource_defs, + run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "1"}}}}, + ) + + assert res.success + table = res.asset_value(["my_schema", "load_partitioned"]) + assert table.shape[0] == 3 + + res = materialize( + [daily_partitioned, load_partitioned], + partition_key="2022-01-02", + resources=resource_defs, + run_config={"ops": {"my_schema__daily_partitioned": {"config": {"value": "2"}}}}, + ) + + assert res.success + table = res.asset_value(["my_schema", "load_partitioned"]) + assert table.shape[0] == 3 + + +@asset( + partitions_def=StaticPartitionsDefinition(["red", "yellow", "blue"]), + key_prefix=["my_schema"], + metadata={"partition_expr": "color"}, + config_schema={"value": str}, +) +def static_partitioned(context) -> pl.DataFrame: + partition = context.partition_key + value = context.op_execution_context.op_config["value"] + + return pl.DataFrame( + { + "color": [partition, partition, partition], + "a": [value, value, value], + "b": [4, 5, 6], + }, + ) + + +def test_static_partitioned_asset(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + materialize( + [static_partitioned], + partition_key="red", + resources=resource_defs, + run_config={"ops": {"my_schema__static_partitioned": {"config": {"value": "1"}}}}, + ) + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/static_partitioned")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == ["1", "1", "1"] + + materialize( + [static_partitioned], + partition_key="blue", + resources=resource_defs, + run_config={"ops": {"my_schema__static_partitioned": {"config": {"value": "2"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"] + + materialize( + [static_partitioned], + partition_key="red", + resources=resource_defs, + run_config={"ops": {"my_schema__static_partitioned": {"config": {"value": "3"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["2", "2", "2", "3", "3", "3"] + + +@asset( + partitions_def=MultiPartitionsDefinition( + { + "time": 
DailyPartitionsDefinition(start_date="2022-01-01"), + "color": StaticPartitionsDefinition(["red", "yellow", "blue"]), + }, + ), + key_prefix=["my_schema"], + metadata={"partition_expr": {"time": "time", "color": "color"}}, + config_schema={"value": str}, +) +def multi_partitioned(context) -> pl.DataFrame: + partition = context.partition_key.keys_by_dimension + time_partition = datetime.strptime(partition["time"], DELTA_DATE_FORMAT).date() + value = context.op_execution_context.op_config["value"] + return pl.DataFrame( + { + "color": [partition["color"], partition["color"], partition["color"]], + "time": [time_partition, time_partition, time_partition], + "a": [value, value, value], + }, + ) + + +def test_multi_partitioned_asset(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + materialize( + [multi_partitioned], + partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "red"}), + resources=resource_defs, + run_config={"ops": {"my_schema__multi_partitioned": {"config": {"value": "1"}}}}, + ) + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/multi_partitioned")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == ["1", "1", "1"] + + materialize( + [multi_partitioned], + partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "blue"}), + resources=resource_defs, + run_config={"ops": {"my_schema__multi_partitioned": {"config": {"value": "2"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"] + + materialize( + [multi_partitioned], + partition_key=MultiPartitionKey({"time": "2022-01-02", "color": "red"}), + resources=resource_defs, + run_config={"ops": {"my_schema__multi_partitioned": {"config": {"value": "3"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == [ + "1", + "1", + "1", + "2", + "2", + "2", + "3", + "3", + "3", + ] + + materialize( + [multi_partitioned], + partition_key=MultiPartitionKey({"time": "2022-01-01", "color": "red"}), + resources=resource_defs, + run_config={"ops": {"my_schema__multi_partitioned": {"config": {"value": "4"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == [ + "2", + "2", + "2", + "3", + "3", + "3", + "4", + "4", + "4", + ] + + +dynamic_fruits = DynamicPartitionsDefinition(name="dynamic_fruits") + + +@asset( + partitions_def=dynamic_fruits, + key_prefix=["my_schema"], + metadata={"partition_expr": "fruit"}, + config_schema={"value": str}, +) +def dynamic_partitioned(context: AssetExecutionContext) -> pl.DataFrame: + partition = context.partition_key + value = context.op_execution_context.op_config["value"] + return pl.DataFrame( + { + "fruit": [partition, partition, partition], + "a": [value, value, value], + }, + ) + + +def test_dynamic_partition(tmp_path, io_manager): + with instance_for_test() as instance: + resource_defs = {"io_manager": io_manager} + + instance.add_dynamic_partitions(dynamic_fruits.name, ["apple"]) # type: ignore + + materialize( + [dynamic_partitioned], + partition_key="apple", + resources=resource_defs, + instance=instance, + run_config={"ops": {"my_schema__dynamic_partitioned": {"config": {"value": "1"}}}}, + ) + + dt = DeltaTable(os.path.join(tmp_path, "my_schema/dynamic_partitioned")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == ["1", "1", "1"] + + instance.add_dynamic_partitions(dynamic_fruits.name, ["orange"]) # type: ignore + + 
materialize( + [dynamic_partitioned], + partition_key="orange", + resources=resource_defs, + instance=instance, + run_config={"ops": {"my_schema__dynamic_partitioned": {"config": {"value": "2"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["1", "1", "1", "2", "2", "2"] + + materialize( + [dynamic_partitioned], + partition_key="apple", + resources=resource_defs, + instance=instance, + run_config={"ops": {"my_schema__dynamic_partitioned": {"config": {"value": "3"}}}}, + ) + + dt.update_incremental() + out_df = dt.to_pyarrow_table() + assert sorted(out_df["a"].to_pylist()) == ["2", "2", "2", "3", "3", "3"] diff --git a/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler_save_modes.py b/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler_save_modes.py new file mode 100644 index 0000000..3b5f88b --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/polars/test_type_handler_save_modes.py @@ -0,0 +1,135 @@ +import os + +import polars as pl +import pytest +from dagster import ( + Out, + graph, + op, +) +from deltalake import DeltaTable + +from dagster_delta import DeltaLakePolarsIOManager, LocalConfig, WriteMode + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + mode=WriteMode.overwrite, + ) + + +@pytest.fixture +def io_manager_append(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + mode=WriteMode.append, + ) + + +@pytest.fixture +def io_manager_ignore(tmp_path) -> DeltaLakePolarsIOManager: + return DeltaLakePolarsIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + mode=WriteMode.ignore, + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@op(out=Out(metadata={"schema": "add_one"})) +def add_one(df: pl.DataFrame): # noqa: ANN201 + return df + 1 + + +@graph +def add_one_to_dataframe(): + add_one(a_df()) + + +@graph +def just_a_df(): + a_df() + + +def test_deltalake_io_manager_with_ops_appended(tmp_path, io_manager_append): + resource_defs = {"io_manager": io_manager_append} + + job = just_a_df.to_job(resource_defs=resource_defs) + + # run the job twice to ensure tables get appended + expected_result1 = [1, 2, 3] + + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == expected_result1 + + expected_result1.extend(expected_result1) + + +def test_deltalake_io_manager_with_ops_ignored(tmp_path, io_manager_ignore): + resource_defs = {"io_manager": io_manager_ignore} + + job = just_a_df.to_job(resource_defs=resource_defs) + + # run the job 5 times to ensure tables gets ignored on each write + for _ in range(5): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + assert dt.version() == 0 + + +@op(out=Out(metadata={"schema": "a_df", "mode": "append"})) +def a_df_custom() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@graph +def add_one_to_dataframe_custom(): + add_one(a_df_custom()) + + +def 
test_deltalake_io_manager_with_ops_mode_overridden(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = add_one_to_dataframe_custom.to_job(resource_defs=resource_defs) + + # run the job twice to ensure that tables get properly deleted + + a_df_result = [1, 2, 3] + add_one_result = [2, 3, 4] + + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == a_df_result + + dt = DeltaTable(os.path.join(tmp_path, "add_one/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == add_one_result + + a_df_result.extend(a_df_result) + add_one_result.extend(add_one_result) diff --git a/libraries/dagster-delta/dagster_delta_tests/test_config.py b/libraries/dagster-delta/dagster_delta_tests/test_config.py new file mode 100644 index 0000000..b5ca52d --- /dev/null +++ b/libraries/dagster-delta/dagster_delta_tests/test_config.py @@ -0,0 +1,64 @@ +import os +import warnings + +import pyarrow as pa +import pytest +from dagster import ( + ExperimentalWarning, + Out, + graph, + op, +) +from deltalake import DeltaTable + +from dagster_delta import BackoffConfig, ClientConfig, DeltaLakePyarrowIOManager, LocalConfig + +warnings.filterwarnings("ignore", category=ExperimentalWarning) + + +@pytest.fixture +def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: + return DeltaLakePyarrowIOManager( + root_uri=str(tmp_path), + storage_options=LocalConfig(), + client_options=ClientConfig( + max_retries=10, + retry_timeout="10s", + backoff_config=BackoffConfig(init_backoff="10s", base=1.2), + ), + ) + + +@op(out=Out(metadata={"schema": "a_df"})) +def a_df() -> pa.Table: + return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + + +@op(out=Out(metadata={"schema": "add_one"})) +def add_one(df: pa.Table): # noqa: ANN201 + return df.set_column(0, "a", pa.array([2, 3, 4])) + + +@graph +def add_one_to_dataframe(): + add_one(a_df()) + + +def test_deltalake_io_manager_with_client_config(tmp_path, io_manager): + resource_defs = {"io_manager": io_manager} + + job = add_one_to_dataframe.to_job(resource_defs=resource_defs) + + # run the job twice to ensure that tables get properly deleted + for _ in range(2): + res = job.execute_in_process() + + assert res.success + + dt = DeltaTable(os.path.join(tmp_path, "a_df/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [1, 2, 3] + + dt = DeltaTable(os.path.join(tmp_path, "add_one/result")) + out_df = dt.to_pyarrow_table() + assert out_df["a"].to_pylist() == [2, 3, 4] diff --git a/libraries/dagster-delta/dagster_delta_tests/test_delta_table_resource.py b/libraries/dagster-delta/dagster_delta_tests/test_delta_table_resource.py index 9d654ad..8a3080b 100644 --- a/libraries/dagster-delta/dagster_delta_tests/test_delta_table_resource.py +++ b/libraries/dagster-delta/dagster_delta_tests/test_delta_table_resource.py @@ -4,8 +4,7 @@ from dagster import asset, materialize from deltalake import write_deltalake -from dagster_delta import DeltaTableResource -from dagster_delta.config import LocalConfig +from dagster_delta import BackoffConfig, ClientConfig, DeltaTableResource, LocalConfig def test_resource(tmp_path): @@ -35,6 +34,11 @@ def read_table(delta_table: DeltaTableResource): "delta_table": DeltaTableResource( url=os.path.join(tmp_path, "table"), storage_options=LocalConfig(), + client_options=ClientConfig( + max_retries=10, + retry_timeout="10s", + 
backoff_config=BackoffConfig(init_backoff="10s", base=1.2), + ), ), }, ) @@ -73,6 +77,11 @@ def read_table(delta_table: DeltaTableResource): "delta_table": DeltaTableResource( url=os.path.join(tmp_path, "table"), storage_options=LocalConfig(), + client_options=ClientConfig( + max_retries=10, + retry_timeout="10s", + backoff_config=BackoffConfig(init_backoff="10s", base=1.2), + ), version=0, ), }, diff --git a/libraries/dagster-delta/dagster_delta_tests/test_io_manager.py b/libraries/dagster-delta/dagster_delta_tests/test_io_manager.py index d3ee846..9e44390 100644 --- a/libraries/dagster-delta/dagster_delta_tests/test_io_manager.py +++ b/libraries/dagster-delta/dagster_delta_tests/test_io_manager.py @@ -14,9 +14,8 @@ from deltalake import DeltaTable from deltalake.schema import Field, PrimitiveType, Schema -from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig -from dagster_delta.handler import partition_dimensions_to_dnf -from dagster_delta.io_manager import WriteMode +from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig, WriteMode +from dagster_delta._handler.utils import partition_dimensions_to_dnf TablePartitionDimension( partitions=TimeWindow(datetime(2020, 1, 2), datetime(2020, 2, 3)), @@ -24,7 +23,7 @@ ) -@pytest.fixture() +@pytest.fixture def test_schema() -> Schema: fields = [ Field(name="string_col", type=PrimitiveType("string")), @@ -72,7 +71,7 @@ def add_one_to_dataset(): add_one(a_pa_table()) -@pytest.fixture() +@pytest.fixture def io_manager_with_parquet_read_options(tmp_path) -> DeltaLakePyarrowIOManager: return DeltaLakePyarrowIOManager( root_uri=str(tmp_path), diff --git a/libraries/dagster-delta/dagster_delta_tests/test_metadata_inputs.py b/libraries/dagster-delta/dagster_delta_tests/test_metadata_inputs.py index 83be8c5..6d597b9 100644 --- a/libraries/dagster-delta/dagster_delta_tests/test_metadata_inputs.py +++ b/libraries/dagster-delta/dagster_delta_tests/test_metadata_inputs.py @@ -12,7 +12,7 @@ from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig, WriterEngine -@pytest.fixture() +@pytest.fixture def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: return DeltaLakePyarrowIOManager( root_uri=str(tmp_path), @@ -59,7 +59,7 @@ def test_deltalake_io_manager_with_ops_rust_writer(tmp_path, io_manager): result.extend([1, 2, 3]) -@pytest.fixture() +@pytest.fixture def io_manager_with_writer_metadata(tmp_path) -> DeltaLakePyarrowIOManager: return DeltaLakePyarrowIOManager( root_uri=str(tmp_path), diff --git a/libraries/dagster-delta/dagster_delta_tests/test_type_handler.py b/libraries/dagster-delta/dagster_delta_tests/test_type_handler.py index 3c94fe1..a717c67 100644 --- a/libraries/dagster-delta/dagster_delta_tests/test_type_handler.py +++ b/libraries/dagster-delta/dagster_delta_tests/test_type_handler.py @@ -25,12 +25,13 @@ from dagster._check import CheckError from deltalake import DeltaTable -from dagster_delta import DELTA_DATE_FORMAT, DeltaLakePyarrowIOManager, LocalConfig +from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig +from dagster_delta.io_manager.base import DELTA_DATE_FORMAT warnings.filterwarnings("ignore", category=ExperimentalWarning) -@pytest.fixture() +@pytest.fixture def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: return DeltaLakePyarrowIOManager(root_uri=str(tmp_path), storage_options=LocalConfig()) diff --git a/libraries/dagster-delta/dagster_delta_tests/test_type_handler_extra_params.py b/libraries/dagster-delta/dagster_delta_tests/test_type_handler_extra_params.py index a031f7f..7286f77 100644 
--- a/libraries/dagster-delta/dagster_delta_tests/test_type_handler_extra_params.py +++ b/libraries/dagster-delta/dagster_delta_tests/test_type_handler_extra_params.py @@ -12,7 +12,7 @@ from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig, WriterEngine -@pytest.fixture() +@pytest.fixture def io_manager(tmp_path) -> DeltaLakePyarrowIOManager: return DeltaLakePyarrowIOManager( root_uri=str(tmp_path), diff --git a/libraries/dagster-delta/pyproject.toml b/libraries/dagster-delta/pyproject.toml index ea03481..c399a4d 100644 --- a/libraries/dagster-delta/pyproject.toml +++ b/libraries/dagster-delta/pyproject.toml @@ -1,12 +1,13 @@ [project] name = "dagster-delta" -version = "0.2.1" -description = "base deltalake IO Managers for Dagster" +version = "0.3.0" +description = "Deltalake IO Managers for Dagster with pyarrow and Polars support." readme = "README.md" requires-python = ">=3.9" dependencies = [ "dagster>=1.8,<2.0", "deltalake>=0.24", + "pendulum>=3.0.0", ] authors = [{name = "Ion Koutsouris"}] license = { file = "licenses/LICENSE" } @@ -17,9 +18,12 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12" ] -keywords = ["dagster", "deltalake", "delta","datalake", "io manager"] - +keywords = ["dagster", "deltalake", "delta","datalake", "io manager", "polars", "pyarrow"] +[project.optional-dependencies] +polars = [ + "polars" +] [build-system] requires = ["setuptools", "wheel"] @@ -109,6 +113,9 @@ lint.ignore = [ # Allow autofix for all enabled rules (when `--fix`) is provided. lint.fixable = ["ALL"] +# Allow unused variables when underscore-prefixed. +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + # Exclude a variety of commonly ignored directories. exclude = [ ".bzr", @@ -137,9 +144,6 @@ exclude = [ # Same as Black. line-length = 100 -# Allow unused variables when underscore-prefixed. -lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - # Assume Python 3.9 target-version = "py39" @@ -153,3 +157,12 @@ suppress-none-returning = true [tool.ruff.lint.pydocstyle] # Use Google-style docstrings. convention = "google" + +[dependency-groups] +dev = [ + "polars>=1.22.0", + "pyarrow<=18.0.0", + "pyright>=1.1.393", + "pytest>=8.3.4", + "ruff>=0.9.5", +] diff --git a/libraries/dagster-unity-catalog-polars/pyproject.toml b/libraries/dagster-unity-catalog-polars/pyproject.toml index eb712b7..7179b96 100644 --- a/libraries/dagster-unity-catalog-polars/pyproject.toml +++ b/libraries/dagster-unity-catalog-polars/pyproject.toml @@ -156,3 +156,10 @@ suppress-none-returning = true [tool.ruff.lint.pydocstyle] # Use Google-style docstrings. convention = "google" + +[dependency-groups] +dev = [ + "pyright>=1.1.393", + "pytest>=8.3.4", + "ruff>=0.9.5", +] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 404ff35..0000000 --- a/requirements.txt +++ /dev/null @@ -1,18 +0,0 @@ -## build pckgs -polars -dagster -lakefs -pydantic>=2 -pyarrow==18.0.0 -databricks-sdk -databricks-sql-connector -deltalake - -## dev pckgs -ruff==0.5.7 -pyright==1.1.375 -pytest==8.3.2 - -## misc -pip -build
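As a minimal usage sketch, the new pieces exercised in the tests above (the Polars IO manager with its write modes, and the `ClientConfig`/`BackoffConfig` retry options accepted by `DeltaTableResource`) can be wired together roughly as follows. The paths, asset names, and the `my_table_copy` asset are hypothetical; the constructor arguments mirror the test fixtures in this patch.

```python
import polars as pl
from dagster import Definitions, asset

from dagster_delta import (
    BackoffConfig,
    ClientConfig,
    DeltaLakePolarsIOManager,
    DeltaTableResource,
    LocalConfig,
    WriteMode,
)


@asset(key_prefix=["my_schema"])
def my_table() -> pl.DataFrame:
    # Written by the IO manager to <root_uri>/my_schema/my_table as a Delta table.
    return pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})


@asset(key_prefix=["my_schema"])
def my_table_copy(delta_table: DeltaTableResource) -> pl.DataFrame:
    # Reads the table back through the resource (using the retry options
    # configured below) and re-materializes it via the IO manager.
    return pl.from_arrow(delta_table.load().to_pyarrow_table())


defs = Definitions(
    assets=[my_table, my_table_copy],
    resources={
        "io_manager": DeltaLakePolarsIOManager(
            root_uri="/path/to/warehouse",  # hypothetical location
            storage_options=LocalConfig(),
            mode=WriteMode.overwrite,
        ),
        "delta_table": DeltaTableResource(
            url="/path/to/warehouse/my_schema/my_table",  # hypothetical location
            storage_options=LocalConfig(),
            client_options=ClientConfig(
                max_retries=10,
                retry_timeout="10s",
                backoff_config=BackoffConfig(init_backoff="10s", base=1.2),
            ),
        ),
    },
)
```

The Polars type handler is shipped behind the new `polars` extra added to `pyproject.toml`, so a sketch like this assumes `dagster-delta[polars]` is installed.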