Merge pull request #159 from stanford-oval/dev-code-formatter
Use `black` as the python code formatter.
Yucheng-Jiang authored Sep 4, 2024
2 parents 6936210 + b5ce593 commit f78a073
Showing 16 changed files with 1,386 additions and 828 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,10 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: 24.8.0
+    hooks:
+      - id: black
+        name: Format Python code with black
+        entry: black
+        args: ["knowledge_storm/"]
+        language: python
+        pass_filenames: true
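With this configuration, the hook normally runs only on the files staged for a commit. A minimal usage sketch, not part of this commit and assuming `pre-commit` is already installed, for running the configured hooks across the whole repository:

```
# Usage sketch, not part of this diff; assumes pre-commit is installed.
# Run all configured hooks against every tracked file, not just staged ones:
pre-commit run --all-files

# Run only the black hook declared above, selected by its hook id:
pre-commit run black --all-files
```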
7 changes: 6 additions & 1 deletion CONTRIBUTING.md
@@ -40,4 +40,9 @@ Following the suggested format can lead to a faster review process.
 **Code Format:**
-We adopt [PEP8 rules](https://peps.python.org/pep-0008/) for arranging and formatting Python code. Please use a code formatter tool in your IDE to reformat the code before submitting the PR.
+We adopt [`black`](https://github.com/psf/black) for arranging and formatting Python code. To streamline the contribution process, we set up a [pre-commit hook](https://pre-commit.com/) to format the code under `knowledge_storm/` before committing. To install the pre-commit hook, run:
+```
+pip install pre-commit
+pre-commit install
+```
+The hook will automatically format the code before each commit.
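The commit applies the formatter through the hook; a one-off formatting pass can also be run with `black` directly. A minimal sketch, not part of this commit, assuming `black` is installed locally and mirroring the hook's `knowledge_storm/` target:

```
# Sketch only; assumes black is installed, e.g. pip install black==24.8.0
# Format the package in place, matching the hook's target directory:
black knowledge_storm/

# Report files that would be reformatted without modifying them:
black --check knowledge_storm/
```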
2 changes: 2 additions & 0 deletions README.md
@@ -17,6 +17,8 @@
 - [2024/05] We add Bing Search support in [rm.py](knowledge_storm/rm.py). Test STORM with `GPT-4o` - we now configure the article generation part in our demo using `GPT-4o` model.
 - [2024/04] We release refactored version of STORM codebase! We define [interface](knowledge_storm/interface.py) for STORM pipeline and reimplement STORM-wiki (check out [`src/storm_wiki`](knowledge_storm/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide API to support customization of different language models and retrieval/search integration.

+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
+
 ## Overview [(Try STORM now!)](https://storm.genie.stanford.edu/)

 <p align="center">
2 changes: 1 addition & 1 deletion knowledge_storm/__init__.py
@@ -1,7 +1,7 @@
 from .storm_wiki.engine import (
     STORMWikiLMConfigs,
     STORMWikiRunnerArguments,
-    STORMWikiRunner
+    STORMWikiRunner,
 )

 __version__ = "0.2.5"
69 changes: 47 additions & 22 deletions knowledge_storm/interface.py
@@ -5,7 +5,9 @@
 from collections import OrderedDict
 from typing import Dict, List, Optional, Union

-logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s')
+logging.basicConfig(
+    level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
+)
 logger = logging.getLogger(__name__)


@@ -70,7 +72,9 @@ class Article(ABC):
     def __init__(self, topic_name):
         self.root = ArticleSectionNode(topic_name)

-    def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]:
+    def find_section(
+        self, node: ArticleSectionNode, name: str
+    ) -> Optional[ArticleSectionNode]:
         """
         Return the node of the section given the section name.
@@ -152,7 +156,9 @@ def prune_empty_nodes(self, node=None):
         if node is None:
             node = self.root

-        node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)]
+        node.children[:] = [
+            child for child in node.children if self.prune_empty_nodes(child)
+        ]

         if (node.content is None or node.content == "") and not node.children:
             return None
@@ -178,7 +184,9 @@ def update_search_top_k(self, k):
     def collect_and_reset_rm_usage(self):
         combined_usage = []
         for attr_name in self.__dict__:
-            if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
+            if "_rm" in attr_name and hasattr(
+                getattr(self, attr_name), "get_usage_and_reset"
+            ):
                 combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

         name_to_usage = {}
@@ -240,7 +248,9 @@ class OutlineGenerationModule(ABC):
"""

@abstractmethod
def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article:
def generate_outline(
self, topic: str, information_table: InformationTable, **kwargs
) -> Article:
"""
Generate outline for the article. Required arguments include:
topic: the topic of interest
@@ -263,11 +273,13 @@ class ArticleGenerationModule(ABC):
"""

@abstractmethod
def generate_article(self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs) -> Article:
def generate_article(
self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs,
) -> Article:
"""
Generate article. Required arguments include:
topic: the topic of interest
@@ -312,22 +324,23 @@ def wrapper(self, *args, **kwargs):
 class LMConfigs(ABC):
     """Abstract base class for language model configurations of the knowledge curation engine.
-    The language model used for each part should be declared with a suffix '_lm' in the attribute name."""
+    The language model used for each part should be declared with a suffix '_lm' in the attribute name.
+    """

     def __init__(self):
         pass

     def init_check(self):
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and getattr(self, attr_name) is None:
+            if "_lm" in attr_name and getattr(self, attr_name) is None:
                 logging.warning(
                     f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()"
                 )

     def collect_and_reset_lm_history(self):
         history = []
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
+            if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"):
                 history.extend(getattr(self, attr_name).history)
                 getattr(self, attr_name).history = []

@@ -336,7 +349,9 @@ def collect_and_reset_lm_history(self):
     def collect_and_reset_lm_usage(self):
         combined_usage = []
         for attr_name in self.__dict__:
-            if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
+            if "_lm" in attr_name and hasattr(
+                getattr(self, attr_name), "get_usage_and_reset"
+            ):
                 combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

         model_name_to_usage = {}
@@ -345,17 +360,22 @@ def collect_and_reset_lm_usage(self):
                 if model_name not in model_name_to_usage:
                     model_name_to_usage[model_name] = tokens
                 else:
-                    model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
-                    model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']
+                    model_name_to_usage[model_name]["prompt_tokens"] += tokens[
+                        "prompt_tokens"
+                    ]
+                    model_name_to_usage[model_name]["completion_tokens"] += tokens[
+                        "completion_tokens"
+                    ]

         return model_name_to_usage

     def log(self):

         return OrderedDict(
             {
-                attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if
-                '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
+                attr_name: getattr(self, attr_name).kwargs
+                for attr_name in self.__dict__
+                if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs")
             }
         )

@@ -379,16 +399,21 @@ def wrapper(*args, **kwargs):
             self.time[func.__name__] = execution_time
             logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds")
             self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
-            if hasattr(self, 'retriever'):
-                self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
+            if hasattr(self, "retriever"):
+                self.rm_cost[func.__name__] = (
+                    self.retriever.collect_and_reset_rm_usage()
+                )
             return result

         return wrapper

     def apply_decorators(self):
         """Apply decorators to methods that need them."""
-        methods_to_decorate = [method_name for method_name in dir(self)
-                               if callable(getattr(self, method_name)) and method_name.startswith('run_')]
+        methods_to_decorate = [
+            method_name
+            for method_name in dir(self)
+            if callable(getattr(self, method_name)) and method_name.startswith("run_")
+        ]
         for method_name in methods_to_decorate:
             original_method = getattr(self, method_name)
             decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)