Skip to content

Commit

Permalink
feat: ensure we always topologically sort operations in _iter, some s…
Browse files Browse the repository at this point in the history
…mall synthesis feature polish
  • Loading branch information
z3z1ma committed Jan 3, 2025
1 parent b222cec commit db59bd0
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "dbt-osmosis"
version = "1.1.2"
version = "1.1.3"
description = "A dbt utility for managing YAML to make developing with dbt more delightful."
readme = "README.md"
license = { text = "Apache-2.0" }
Expand Down
72 changes: 72 additions & 0 deletions src/dbt_osmosis/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
__all__ = [
"generate_model_spec_as_json",
"generate_column_doc",
"generate_table_doc",
]


Expand Down Expand Up @@ -123,6 +124,45 @@ def _create_llm_prompt_for_column(
]


def _create_llm_prompt_for_table(
sql_content: str, table_name: str, upstream_docs: list[str] | None = None
) -> list[dict[str, t.Any]]:
"""Builds a system + user prompt instructing the model to produce a string description describing a single model."""
if upstream_docs is None:
upstream_docs = []

system_prompt = dedent(f"""
You are a helpful SQL Developer and an Expert in dbt.
Your job is to produce a concise documentation string
for a table named {table_name}.
IMPORTANT RULES:
1. DO NOT output extra commentary or Markdown fences.
2. Provide only the column description text, nothing else.
3. If upstream docs exist, you may incorporate them. If none exist,
a short placeholder is acceptable.
4. Avoid speculation. Keep it short and relevant.
""")

user_message = dedent(f"""
The SQL for the model is:
>>> SQL CODE START
{sql_content}
>>> SQL CODE END
The upstream documentation is:
{os.linesep.join(upstream_docs)}
Please return only the text suitable for the "description" field.
""")

return [
{"role": "system", "content": system_prompt.strip()},
{"role": "user", "content": user_message.strip()},
]


def generate_model_spec_as_json(
sql_content: str,
upstream_docs: list[str] | None = None,
Expand Down Expand Up @@ -207,6 +247,38 @@ def generate_column_doc(
return content.strip()


def generate_table_doc(
sql_content: str,
table_name: str,
upstream_docs: list[str] | None = None,
model_engine: str = "gpt-4o",
temperature: float = 0.7,
) -> str:
"""Calls OpenAI to generate documentation for a single column in a table.
Args:
sql_content (str): The SQL code for the table
table_name (str | None): Name of the table/model (optional)
upstream_docs (list[str] | None): Optional docs or references you might have
model_engine (str): The OpenAI model to use (e.g., 'gpt-3.5-turbo')
temperature (float): OpenAI completion temperature
Returns:
str: A short docstring suitable for a "description" field
"""
messages = _create_llm_prompt_for_table(sql_content, table_name, upstream_docs)
response = openai.chat.completions.create(
model=model_engine,
messages=messages, # pyright: ignore[reportArgumentType]
temperature=temperature,
)

content = response.choices[0].message.content
if not content:
raise ValueError("OpenAI returned an empty response")
return content.strip()


if __name__ == "__main__":
# Kitchen sink
sample_sql = """
Expand Down
122 changes: 106 additions & 16 deletions src/dbt_osmosis/core/osmosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
import typing as t
import uuid
from collections import OrderedDict
from collections import ChainMap, OrderedDict, defaultdict, deque
from collections.abc import Iterable, Iterator
from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait
from dataclasses import dataclass, field
Expand Down Expand Up @@ -663,6 +663,59 @@ def _get_node_path(node: ResultNode) -> Path | None:
return None


def _topological_sort(
candidate_nodes: list[tuple[str, ResultNode]],
) -> list[tuple[str, ResultNode]]:
"""
Perform a topological sort on the given candidate_nodes (uid, node) pairs
based on their dependencies. If a cycle is detected, raise a ValueError.
Kahn’s Algorithm:
1) Build adjacency list: parent -> {child, child, ...}
(Because if node 'child' depends on 'parent', we have an edge parent->child).
2) Compute in-degrees for all nodes.
3) Collect all nodes with in-degree == 0 into a queue.
4) Repeatedly pop from queue and 'visit' that node,
then decrement the in-degree of its children.
If any child's in-degree becomes 0, push it into the queue.
5) If we visited all nodes, we have a valid topological order.
Otherwise, a cycle exists.
"""
adjacency: defaultdict[str, set[str]] = defaultdict(set)
in_degree: defaultdict[str, int] = defaultdict(int)

all_uids = set(uid for uid, _ in candidate_nodes)

for uid, _ in candidate_nodes:
in_degree[uid] = 0

for uid, node in candidate_nodes:
for dep_uid in node.depends_on_nodes:
if dep_uid in all_uids:
adjacency[dep_uid].add(uid)
in_degree[uid] += 1

queue: deque[str] = deque([uid for uid, deg in in_degree.items() if deg == 0])
sorted_uids: list[str] = []

while queue:
parent_uid = queue.popleft()
sorted_uids.append(parent_uid)

for child_uid in adjacency[parent_uid]:
in_degree[child_uid] -= 1
if in_degree[child_uid] == 0:
queue.append(child_uid)

if len(sorted_uids) < len(candidate_nodes):
raise ValueError(
"Cycle detected in node dependencies. Cannot produce a valid topological order."
)

uid_to_node = dict(candidate_nodes)
return [(uid, uid_to_node[uid]) for uid in sorted_uids]


def _iter_candidate_nodes(
context: YamlRefactorContext,
) -> Iterator[tuple[str, ResultNode]]:
Expand All @@ -689,10 +742,14 @@ def f(node: ResultNode) -> bool:
logger.debug(":white_check_mark: Node => %s passed filtering logic.", node.unique_id)
return True

candidate_nodes: list[t.Any] = []
items = chain(context.project.manifest.nodes.items(), context.project.manifest.sources.items())
for uid, dbt_node in items:
if f(dbt_node):
yield uid, dbt_node
candidate_nodes.append((uid, dbt_node))

for uid, node in _topological_sort(candidate_nodes):
yield uid, node


# Introspection
Expand Down Expand Up @@ -1875,7 +1932,11 @@ def synthesize_missing_documentation_with_openai(
) -> None:
"""Synthesize missing documentation for a dbt node using OpenAI's GPT-4o API."""
try:
from dbt_osmosis.core.llm import generate_column_doc, generate_model_spec_as_json
from dbt_osmosis.core.llm import (
generate_column_doc,
generate_model_spec_as_json,
generate_table_doc,
)
except ImportError:
raise ImportError("Please install the 'dbt-osmosis[openai]' extra to use this feature.")
if node is None:
Expand All @@ -1892,8 +1953,23 @@ def synthesize_missing_documentation_with_openai(
":no_entry_sign: No columns to synthesize documentation for => %s", node.unique_id
)
return
documented = len([n for n in node.columns.values() if n.description])
if total - documented > 10:
documented = len([
column
for column in node.columns.values()
if column.description and column.description not in context.placeholders
])
node_map = ChainMap(
t.cast(dict[str, ResultNode], context.project.manifest.nodes),
t.cast(dict[str, ResultNode], context.project.manifest.sources),
)
upstream_docs: list[str] = []
for uid in node.depends_on_nodes:
dep = node_map.get(t.cast(str, uid))
if dep is not None:
upstream_docs.append(f"{uid}: {dep.description}")
if ( # NOTE a semi-arbitrary limit by which its probably better to one shot the table versus many smaller requests
total - documented > 10
):
logger.info(
":robot: Synthesizing bulk documentation for => %s columns in node => %s",
total - documented,
Expand All @@ -1903,29 +1979,43 @@ def synthesize_missing_documentation_with_openai(
getattr(
node, "raw_code", f"SELECT {', '.join(node.columns)} FROM {node.schema}.{node.name}"
),
[context.project.manifest.nodes[n].description for n in node.depends_on_nodes],
f"{node.name} ({node.resource_type}) -- {node.description}",
upstream_docs=upstream_docs,
existing_context=f"{node.unique_id} -- {node.description}",
temperature=0.4,
)
if not node.description or node.description in context.placeholders:
node.description = spec.get("description", node.description)
for synth_col in spec.get("columns", []):
cur_col = node.columns.get(synth_col["name"])
if cur_col and (not cur_col.description or cur_col.description in context.placeholders):
cur_col.description = synth_col.get("description", cur_col.description)
usr_col = node.columns.get(synth_col["name"])
if usr_col and (not usr_col.description or usr_col.description in context.placeholders):
usr_col.description = synth_col.get("description", usr_col.description)
else:
for col_name, col in node.columns.items():
if not node.description or node.description in context.placeholders:
logger.info(
":robot: Synthesizing documentation for node => %s",
node.unique_id,
)
node.description = generate_table_doc(
getattr(
node,
"raw_code",
f"SELECT {', '.join(node.columns)} FROM {node.schema}.{node.name}",
),
table_name=node.relation_name or node.name,
upstream_docs=upstream_docs,
)
for column_name, col in node.columns.items():
if not col.description or col.description in context.placeholders:
logger.info(
":robot: Synthesizing documentation for column => %s in node => %s",
col_name,
column_name,
node.unique_id,
)
col.description = generate_column_doc(
col_name,
f"{node.name} ({node.resource_type}) -- {node.description}",
node.relation_name or node.name,
[context.project.manifest.nodes[n].description for n in node.depends_on_nodes],
column_name,
existing_context=f"{node.unique_id} -- {node.description}",
table_name=node.relation_name or node.name,
upstream_docs=upstream_docs,
temperature=0.7,
)

Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit db59bd0

Please sign in to comment.