-
Notifications
You must be signed in to change notification settings - Fork 128
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Snowflake integration (#1064)
* initial commit * add unit tests * add pyproject.toml * add pydoc config * add CHANGELOG file * update pyproject.toml * lint file * add example and fix lint * update comments * add header and trailing line * update based on review
- Loading branch information
Showing
9 changed files
with
1,279 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
## [integrations/snowflake-v0.0.1] - 2024-09-06 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# snowflake-haystack | ||
|
||
[![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | ||
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack) | ||
|
||
----- | ||
|
||
**Table of Contents** | ||
|
||
- [Installation](#installation) | ||
- [Examples](#examples) | ||
- [License](#license) | ||
|
||
## Installation | ||
|
||
```console | ||
pip install snowflake-haystack | ||
``` | ||
## Examples | ||
You can find a code example showing how to use the Retriever under the `example/` folder of this repo. | ||
|
||
## License | ||
|
||
`snowflake-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from dotenv import load_dotenv | ||
from haystack import Pipeline | ||
from haystack.components.builders import PromptBuilder | ||
from haystack.components.converters import OutputAdapter | ||
from haystack.components.generators import OpenAIGenerator | ||
from haystack.utils import Secret | ||
|
||
from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever | ||
|
||
load_dotenv() | ||
|
||
sql_template = """ | ||
You are a SQL expert working with Snowflake. | ||
Your task is to create a Snowflake SQL query for the given question. | ||
Refrain from explaining your answer. Your answer must be the SQL query | ||
in plain text format without using Markdown. | ||
Here are some relevant tables, a description about it, and their | ||
columns: | ||
Database name: DEMO_DB | ||
Schema name: ADVENTURE_WORKS | ||
Table names: | ||
- ADDRESS: Employees Address Table | ||
- EMPLOYEE: Employees directory | ||
- SALESTERRITORY: Sales territory lookup table. | ||
- SALESORDERHEADER: General sales order information. | ||
User's question: {{ question }} | ||
Generated SQL query: | ||
""" | ||
|
||
sql_builder = PromptBuilder(template=sql_template) | ||
|
||
analyst_template = """ | ||
You are an expert data analyst. | ||
Your role is to answer the user's question {{ question }} using the information | ||
in the table. | ||
You will base your response solely on the information provided in the | ||
table. | ||
Do not rely on your knowledge base; only the data that is in the table. | ||
Refrain from using the term "table" in your response, but instead, use | ||
the word "data" | ||
If the table is blank say: | ||
"The specific answer can't be found in the database. Try rephrasing your | ||
question." | ||
Additionally, you will present the table in a tabular format and provide | ||
the SQL query used to extract the relevant rows from the database in | ||
Markdown. | ||
If the table is larger than 10 rows, display the most important rows up | ||
to 10 rows. Your answer must be detailed and provide insights based on | ||
the question and the available data. | ||
SQL query: | ||
{{ sql_query }} | ||
Table: | ||
{{ table }} | ||
Answer: | ||
""" | ||
|
||
analyst_builder = PromptBuilder(template=analyst_template) | ||
|
||
# LLM responsible for generating the SQL query | ||
sql_llm = OpenAIGenerator( | ||
model="gpt-4o", | ||
api_key=Secret.from_env_var("OPENAI_API_KEY"), | ||
generation_kwargs={"temperature": 0.0, "max_tokens": 1000}, | ||
) | ||
|
||
# LLM responsible for analyzing the table | ||
analyst_llm = OpenAIGenerator( | ||
model="gpt-4o", | ||
api_key=Secret.from_env_var("OPENAI_API_KEY"), | ||
generation_kwargs={"temperature": 0.0, "max_tokens": 2000}, | ||
) | ||
|
||
snowflake = SnowflakeTableRetriever( | ||
user="<ACCOUNT-USER>", | ||
account="<ACCOUNT-IDENTIFIER>", | ||
api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"), | ||
warehouse="<WAREHOUSE-NAME>", | ||
) | ||
|
||
adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str) | ||
|
||
pipeline = Pipeline() | ||
|
||
pipeline.add_component(name="sql_builder", instance=sql_builder) | ||
pipeline.add_component(name="sql_llm", instance=sql_llm) | ||
pipeline.add_component(name="adapter", instance=adapter) | ||
pipeline.add_component(name="snowflake", instance=snowflake) | ||
pipeline.add_component(name="analyst_builder", instance=analyst_builder) | ||
pipeline.add_component(name="analyst_llm", instance=analyst_llm) | ||
|
||
|
||
pipeline.connect("sql_builder.prompt", "sql_llm.prompt") | ||
pipeline.connect("sql_llm.replies", "adapter.replies") | ||
pipeline.connect("adapter.output", "snowflake.query") | ||
pipeline.connect("snowflake.table", "analyst_builder.table") | ||
pipeline.connect("adapter.output", "analyst_builder.sql_query") | ||
pipeline.connect("analyst_builder.prompt", "analyst_llm.prompt") | ||
|
||
question = "What are my top territories by number of orders and by sales value?" | ||
|
||
response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
loaders: | ||
- type: haystack_pydoc_tools.loaders.CustomPythonLoader | ||
search_path: [../src] | ||
modules: | ||
[ | ||
"haystack_integrations.components.retrievers.snowflake.snowflake_table_retriever" | ||
] | ||
ignore_when_discovered: ["__init__"] | ||
processors: | ||
- type: filter | ||
expression: | ||
documented_only: true | ||
do_not_filter_modules: false | ||
skip_empty_modules: true | ||
- type: smart | ||
- type: crossref | ||
renderer: | ||
type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer | ||
excerpt: Snowflake integration for Haystack | ||
category_slug: integrations-api | ||
title: Snowflake | ||
slug: integrations-Snowflake | ||
order: 130 | ||
markdown: | ||
descriptive_class_title: false | ||
classdef_code_block: false | ||
descriptive_module_title: true | ||
add_method_class_prefix: true | ||
add_member_class_prefix: false | ||
filename: _readme_snowflake.md |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
[build-system] | ||
requires = ["hatchling", "hatch-vcs"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "snowflake-haystack" | ||
dynamic = ["version"] | ||
description = 'A Snowflake integration for the Haystack framework.' | ||
readme = "README.md" | ||
requires-python = ">=3.8" | ||
license = "Apache-2.0" | ||
keywords = [] | ||
authors = [{ name = "deepset GmbH", email = "[email protected]" }, | ||
{ name = "Mohamed Sriha", email = "[email protected]" }] | ||
classifiers = [ | ||
"License :: OSI Approved :: Apache Software License", | ||
"Development Status :: 4 - Beta", | ||
"Programming Language :: Python", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Programming Language :: Python :: 3.12", | ||
"Programming Language :: Python :: Implementation :: CPython", | ||
"Programming Language :: Python :: Implementation :: PyPy", | ||
] | ||
dependencies = ["haystack-ai", "snowflake-connector-python>=3.10.1", "tabulate>=0.9.0"] | ||
|
||
[project.urls] | ||
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme" | ||
Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" | ||
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake" | ||
|
||
[tool.hatch.build.targets.wheel] | ||
packages = ["src/haystack_integrations"] | ||
|
||
[tool.hatch.version] | ||
source = "vcs" | ||
tag-pattern = 'integrations\/snowflake-v(?P<version>.*)' | ||
|
||
[tool.hatch.version.raw-options] | ||
root = "../.." | ||
git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' | ||
|
||
[tool.hatch.envs.default] | ||
dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] | ||
[tool.hatch.envs.default.scripts] | ||
test = "pytest {args:tests}" | ||
test-cov = "coverage run -m pytest {args:tests}" | ||
test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" | ||
cov-report = ["- coverage combine", "coverage report"] | ||
cov = ["test-cov", "cov-report"] | ||
cov-retry = ["test-cov-retry", "cov-report"] | ||
docs = ["pydoc-markdown pydoc/config.yml"] | ||
|
||
|
||
[[tool.hatch.envs.all.matrix]] | ||
python = ["3.8", "3.9", "3.10", "3.11"] | ||
|
||
[tool.hatch.envs.lint] | ||
detached = true | ||
dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] | ||
[tool.hatch.envs.lint.scripts] | ||
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" | ||
style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"] | ||
fmt = ["black {args:.}", "ruff check --fix {args:. --exclude tests/}", "style"] | ||
all = ["style", "typing"] | ||
|
||
[tool.black] | ||
target-version = ["py38"] | ||
line-length = 120 | ||
skip-string-normalization = true | ||
|
||
[tool.ruff] | ||
target-version = "py38" | ||
line-length = 120 | ||
select = [ | ||
"A", | ||
"ARG", | ||
"B", | ||
"C", | ||
"DTZ", | ||
"E", | ||
"EM", | ||
"F", | ||
"I", | ||
"ICN", | ||
"ISC", | ||
"N", | ||
"PLC", | ||
"PLE", | ||
"PLR", | ||
"PLW", | ||
"Q", | ||
"RUF", | ||
"S", | ||
"T", | ||
"TID", | ||
"UP", | ||
"W", | ||
"YTT", | ||
] | ||
ignore = [ | ||
# Allow non-abstract empty methods in abstract base classes | ||
"B027", | ||
# Ignore checks for possible passwords | ||
"S105", | ||
"S106", | ||
"S107", | ||
# Ignore complexity | ||
"C901", | ||
"PLR0911", | ||
"PLR0912", | ||
"PLR0913", | ||
"PLR0915", | ||
# Ignore SQL injection | ||
"S608", | ||
# Unused method argument | ||
"ARG002" | ||
] | ||
unfixable = [ | ||
# Don't touch unused imports | ||
"F401", | ||
] | ||
|
||
[tool.ruff.isort] | ||
known-first-party = ["snowflake_haystack"] | ||
|
||
[tool.ruff.flake8-tidy-imports] | ||
ban-relative-imports = "parents" | ||
|
||
[tool.ruff.per-file-ignores] | ||
# Tests can use magic values, assertions, and relative imports | ||
"tests/**/*" = ["PLR2004", "S101", "TID252"] | ||
|
||
[tool.coverage.run] | ||
source = ["haystack_integrations"] | ||
branch = true | ||
parallel = false | ||
|
||
|
||
[tool.coverage.report] | ||
omit = ["*/tests/*", "*/__init__.py"] | ||
show_missing = true | ||
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] | ||
|
||
[[tool.mypy.overrides]] | ||
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "openai.*", "snowflake.*"] | ||
ignore_missing_imports = true |
7 changes: 7 additions & 0 deletions
7
integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .snowflake_table_retriever import SnowflakeTableRetriever | ||
|
||
__all__ = ["SnowflakeTableRetriever"] |
Oops, something went wrong.