Skip to content

Commit

Permalink
feat: Add Snowflake integration (#1064)
Browse files Browse the repository at this point in the history
* initial commit

* add unit tests

* add pyproject.toml

* add pydoc config

* add CHANGELOG file

* update pyproject.toml

* lint file

* add example and fix lint

* update comments

* add header and trailing line

* update based on review
  • Loading branch information
medsriha authored Sep 16, 2024
1 parent b72d857 commit b47583f
Show file tree
Hide file tree
Showing 9 changed files with 1,279 additions and 0 deletions.
1 change: 1 addition & 0 deletions integrations/snowflake/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
## [integrations/snowflake-v0.0.1] - 2024-09-06
23 changes: 23 additions & 0 deletions integrations/snowflake/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# snowflake-haystack

[![PyPI - Version](https://img.shields.io/pypi/v/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/snowflake-haystack.svg)](https://pypi.org/project/snowflake-haystack)

-----

**Table of Contents**

- [Installation](#installation)
- [License](#license)

## Installation

```console
pip install snowflake-haystack
```
## Examples
You can find a code example showing how to use the Retriever under the `example/` folder of this repo.

## License

`snowflake-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
120 changes: 120 additions & 0 deletions integrations/snowflake/example/text2sql_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from dotenv import load_dotenv
from haystack import Pipeline
from haystack.components.builders import PromptBuilder
from haystack.components.converters import OutputAdapter
from haystack.components.generators import OpenAIGenerator
from haystack.utils import Secret

from haystack_integrations.components.retrievers.snowflake import SnowflakeTableRetriever

load_dotenv()

sql_template = """
You are a SQL expert working with Snowflake.
Your task is to create a Snowflake SQL query for the given question.
Refrain from explaining your answer. Your answer must be the SQL query
in plain text format without using Markdown.
Here are some relevant tables, a description about it, and their
columns:
Database name: DEMO_DB
Schema name: ADVENTURE_WORKS
Table names:
- ADDRESS: Employees Address Table
- EMPLOYEE: Employees directory
- SALESTERRITORY: Sales territory lookup table.
- SALESORDERHEADER: General sales order information.
User's question: {{ question }}
Generated SQL query:
"""

sql_builder = PromptBuilder(template=sql_template)

analyst_template = """
You are an expert data analyst.
Your role is to answer the user's question {{ question }} using the information
in the table.
You will base your response solely on the information provided in the
table.
Do not rely on your knowledge base; only the data that is in the table.
Refrain from using the term "table" in your response, but instead, use
the word "data"
If the table is blank say:
"The specific answer can't be found in the database. Try rephrasing your
question."
Additionally, you will present the table in a tabular format and provide
the SQL query used to extract the relevant rows from the database in
Markdown.
If the table is larger than 10 rows, display the most important rows up
to 10 rows. Your answer must be detailed and provide insights based on
the question and the available data.
SQL query:
{{ sql_query }}
Table:
{{ table }}
Answer:
"""

analyst_builder = PromptBuilder(template=analyst_template)

# LLM responsible for generating the SQL query
sql_llm = OpenAIGenerator(
model="gpt-4o",
api_key=Secret.from_env_var("OPENAI_API_KEY"),
generation_kwargs={"temperature": 0.0, "max_tokens": 1000},
)

# LLM responsible for analyzing the table
analyst_llm = OpenAIGenerator(
model="gpt-4o",
api_key=Secret.from_env_var("OPENAI_API_KEY"),
generation_kwargs={"temperature": 0.0, "max_tokens": 2000},
)

snowflake = SnowflakeTableRetriever(
user="<ACCOUNT-USER>",
account="<ACCOUNT-IDENTIFIER>",
api_key=Secret.from_env_var("SNOWFLAKE_API_KEY"),
warehouse="<WAREHOUSE-NAME>",
)

adapter = OutputAdapter(template="{{ replies[0] }}", output_type=str)

pipeline = Pipeline()

pipeline.add_component(name="sql_builder", instance=sql_builder)
pipeline.add_component(name="sql_llm", instance=sql_llm)
pipeline.add_component(name="adapter", instance=adapter)
pipeline.add_component(name="snowflake", instance=snowflake)
pipeline.add_component(name="analyst_builder", instance=analyst_builder)
pipeline.add_component(name="analyst_llm", instance=analyst_llm)


pipeline.connect("sql_builder.prompt", "sql_llm.prompt")
pipeline.connect("sql_llm.replies", "adapter.replies")
pipeline.connect("adapter.output", "snowflake.query")
pipeline.connect("snowflake.table", "analyst_builder.table")
pipeline.connect("adapter.output", "analyst_builder.sql_query")
pipeline.connect("analyst_builder.prompt", "analyst_llm.prompt")

question = "What are my top territories by number of orders and by sales value?"

response = pipeline.run(data={"sql_builder": {"question": question}, "analyst_builder": {"question": question}})
30 changes: 30 additions & 0 deletions integrations/snowflake/pydoc/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../src]
modules:
[
"haystack_integrations.components.retrievers.snowflake.snowflake_retriever"
]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
expression:
documented_only: true
do_not_filter_modules: false
skip_empty_modules: true
- type: smart
- type: crossref
renderer:
type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer
excerpt: Snowflake integration for Haystack
category_slug: integrations-api
title: Snowflake
slug: integrations-Snowflake
order: 130
markdown:
descriptive_class_title: false
classdef_code_block: false
descriptive_module_title: true
add_method_class_prefix: true
add_member_class_prefix: false
filename: _readme_snowflake.md
149 changes: 149 additions & 0 deletions integrations/snowflake/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"

[project]
name = "snowflake-haystack"
dynamic = ["version"]
description = 'A Snowflake integration for the Haystack framework.'
readme = "README.md"
requires-python = ">=3.8"
license = "Apache-2.0"
keywords = []
authors = [{ name = "deepset GmbH", email = "[email protected]" },
{ name = "Mohamed Sriha", email = "[email protected]" }]
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "snowflake-connector-python>=3.10.1", "tabulate>=0.9.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake#readme"
Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/snowflake"

[tool.hatch.build.targets.wheel]
packages = ["src/haystack_integrations"]

[tool.hatch.version]
source = "vcs"
tag-pattern = 'integrations\/snowflake-v(?P<version>.*)'

[tool.hatch.version.raw-options]
root = "../.."
git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"'

[tool.hatch.envs.default]
dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x"
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
cov-retry = ["test-cov-retry", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]


[[tool.hatch.envs.all.matrix]]
python = ["3.8", "3.9", "3.10", "3.11"]

[tool.hatch.envs.lint]
detached = true
dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = ["ruff check {args:. --exclude tests/}", "black --check --diff {args:.}"]
fmt = ["black {args:.}", "ruff --fix {args:. --exclude tests/}", "style"]
all = ["style", "typing"]

[tool.black]
target-version = ["py38"]
line-length = 120
skip-string-normalization = true

[tool.ruff]
target-version = "py38"
line-length = 120
select = [
"A",
"ARG",
"B",
"C",
"DTZ",
"E",
"EM",
"F",
"I",
"ICN",
"ISC",
"N",
"PLC",
"PLE",
"PLR",
"PLW",
"Q",
"RUF",
"S",
"T",
"TID",
"UP",
"W",
"YTT",
]
ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Ignore checks for possible passwords
"S105",
"S106",
"S107",
# Ignore complexity
"C901",
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
# Ignore SQL injection
"S608",
# Unused method argument
"ARG002"
]
unfixable = [
# Don't touch unused imports
"F401",
]

[tool.ruff.isort]
known-first-party = ["snowflake_haystack"]

[tool.ruff.flake8-tidy-imports]
ban-relative-imports = "parents"

[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]

[tool.coverage.run]
source = ["haystack_integrations"]
branch = true
parallel = false


[tool.coverage.report]
omit = ["*/tests/*", "*/__init__.py"]
show_missing = true
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[[tool.mypy.overrides]]
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "openai.*", "snowflake.*"]
ignore_missing_imports = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .snowflake_table_retriever import SnowflakeTableRetriever

__all__ = ["SnowflakeTableRetriever"]
Loading

0 comments on commit b47583f

Please sign in to comment.