Merge extended_conditions into main repository #35
Open
RLKRo wants to merge 185 commits into dev from merge/extended_conditions
Commits (185 total; changes shown from 90 commits):
4343dff (kudep) dev: commit all
cd218c0 (RLKRo) Add source, example, test files
9b706b0 (RLKRo) Move df_extended_conditions to dff/script/logic/extended_conditions
46cbc61 (RLKRo) Move examples to examples/extended_conditions
33a1dbe (RLKRo) Move tests to tests/extended_conditions
9a57271 (RLKRo) Replace old addon names
ea80beb (RLKRo) refactor: remove info from __init__.py
de4c2e3 (RLKRo) fix: references to files in examples
b55e5b8 (RLKRo) Add rasa docker
e4a3853 (RLKRo) Update setup.py
8ed8b1c (RLKRo) Add env variables
4d70a8a (RLKRo) test: partially fix huggingface tests
2795e73 (ruthenian8) merge moved examples
7230f62 (ruthenian8) remove utils, migrate examples to pipeline, add readme, alter…
e13e2dc (ruthenian8) reformat docs for RST, debug hf models: save dataset for hf matcher
c7a1f44 (ruthenian8) debug examples #1: introduce skip conditions
f33d657 (ruthenian8) update documentation and examples: add docstrings, module-level docs
f66d67e (ruthenian8) Merge branch 'rdev' into merge/extended_conditions
3f1563c (ruthenian8) Apply formatting:
15cc5aa (ruthenian8) update references in tests
b448d9b (ruthenian8) fix tests for remote execution
57407d5 (ruthenian8) Merge branch 'dev' into merge/extended_conditions
e3ad218 (ruthenian8) format test_dialogflow.py
fa8feee (ruthenian8) Fix CI problems:
f92b95d (ruthenian8) Alter testing options:
31512d9 (ruthenian8) Change deployment options:
b9372f3 (ruthenian8) fix tests for rasa & dialogflow
cefddaa (ruthenian8) improve coverage by removing untested lines and adding new tests
6a31168 (ruthenian8) revert rasa example
df76f85 (ruthenian8) reformat rasa example
ce5af99 (ruthenian8) debug Dataset class: allow instantiation from list
aa5197a (ruthenian8) adjust examples for doc building
4381faa (ruthenian8) rewrite examples
ef641ac (ruthenian8) merge remote dev
3c64aae (ruthenian8) adapt for Message class
2e5fa91 (ruthenian8) format file headers; alter coverage.yml
6bf99f3 (ruthenian8) add device to hf example; remove hf from .env_file
9fc1eca (ruthenian8) add parameters to BaseModel abstract class; change build_docs.yml
ff16b82 (ruthenian8) fix rasa random_seed for training uniformity
3abaa3e (ruthenian8) rasa add random_seed
abd7767 (ruthenian8) Update dockerfile_extended_conditions
7e65385 (ruthenian8) use ast.literal_eval to circumvent file creation
e0002b5 (ruthenian8) redefine skip conditions for tests; update docs
8bec534 (ruthenian8) remove skip conditions for test_dialogflow
c86939e (avsakharov) docs: fix warnings
19d0bc8 (ruthenian8) change workflow for build docs
f66229e (ruthenian8) correct typo
a7a094a (ruthenian8) fix typo
e496c37 (ruthenian8) add python hash seed to .env_file
e10ead8 (ruthenian8) rework gensim example
03de8b5 (ruthenian8) change thresholds for gensim example && remove variables from test_full
2a3ffa1 (ruthenian8) use correct url && remove unused imports
8c4e6f5 (ruthenian8) remove old code from test_dialogflow
63fe863 (ruthenian8) use word2vec format to avoid problems with pickle.load
77f4d94 (ruthenian8) employ additional skip conditions for examples; change threshold in g…
5d72701 (ruthenian8) remove sklearn dependency from conftest; check spelling; import sklea…
64f9def (ruthenian8) circumvent import errors from pyyaml; remove torch.device from type a…
7cd57de (ruthenian8) Merge branch 'dev' into merge/extended_conditions
3361577 (ruthenian8) add empty line to test_full
73b0d5c (ruthenian8) circumvent 'import joblib' error in 'test_no_deps'
677ff33 (ruthenian8) import numpy in test_sklearn after skip_conditions
64d643b (ruthenian8) change docs for modules
ef8659d (ruthenian8) apply lint
29af70a (ruthenian8) Update documentation for extra_conditions
388e5a5 (ruthenian8) document utils; change header for hf_api_model
77f6971 (ruthenian8) apply lint: invalid docs in utils
50ad934 (ruthenian8) merge dev into extended conditions
1713194 (ruthenian8) partial fix of tests
752b438 (ruthenian8) Update workflows
9a2dc1b (ruthenian8) Update setup.py
75017cd (ruthenian8) Update setup.py
b4d80dd (ruthenian8) Update setup.py
eb96c42 (ruthenian8) correct setup.py
c9e0a34 (ruthenian8) Update conftest.py
3ced545 (ruthenian8) update tutorials; use categorical_code as normal attribute
90e8f25 (ruthenian8) Update pytest markers
a58e5bd (ruthenian8) Update tutors
8794344 (ruthenian8) update test_full
25b6d2d (ruthenian8) Update docs & code comments
bd7355a (ruthenian8) require requests for extended conditions; update requirements in tuto…
4614c93 (ruthenian8) rename BaseModel to ExtrasBaseModel
f89605d (ruthenian8) set up GDF in test_full
8ef14f6 (ruthenian8) add debug print to test_tutors
881826e (ruthenian8) add realpath directives to workflows; alter transformers version
fa2d619 (ruthenian8) Update env variables
e533d5b (ruthenian8) Update imports
b9d90b5 (ruthenian8) Update hf example
98fbda2 (ruthenian8) configure softmax from dim=0 to dim=1
cdf8ec0 (ruthenian8) Update happy path
5fbc16f (RLKRo) Merge branch 'dev' into merge/extended_conditions
9488b6c (NotBioWaste) Updated extra dependencies
564601f (NotBioWaste) Added ext profile to CONTRIBUTING.md
ca02a08 (NotBioWaste) Fix typo
a0edbc6 (NotBioWaste) Merge remote-tracking branch 'origin/dev' into merge/extended_conditions
cd6c025 (NotBioWaste) Reworking namespaces and label caching
3809d57 (NotBioWaste) Moved llm_conditions to
d520d0c (NotBioWaste) Fixed models call
fd77a11 (NotBioWaste) Fixed dependecies and references to modules
d2d3680 (NotBioWaste) Fixed tests, rewriting tutorials
8ace188 (NotBioWaste905) Added caching for async API calls, working on async ExtrasBaseAPIModel
f639141 (NotBioWaste905) Removed local models, updated tutorials
99ced4d (NotBioWaste905) Fixed namespace reference
f15be68 (NotBioWaste905) Started working on llm_responses
7dd03a1 (NotBioWaste) Fixed typos in tutorials
56b7789 (NotBioWaste) Created class, created 1st tutorial
af60115 (NotBioWaste) Added dependecies for langchain
b3b79a5 (NotBioWaste) Fixed adding custom prompt for each node
6eb910d (NotBioWaste) Added image processing, updated tutorial
1f8cddc (NotBioWaste) Added typehint
74cd954 (NotBioWaste) Added llm_response, LLM_API, history management
1fd31a2 (NotBioWaste) Fixed image reading
2c48490 (NotBioWaste) Started llm condition
a1884e5 (NotBioWaste) Added message_to_langchain
61f302e (NotBioWaste) Implementing deepeval integration
38a8f8f (NotBioWaste905) Figured out how to implement DeepEval functions
592267f (NotBioWaste) Adding conditions
baccc47 (NotBioWaste) Implemented simple conditions call, added BaseMethod class, renaming,…
8e84ba1 (NotBioWaste) Fixed history extraction
2b2847b (NotBioWaste905) Delete test_bot.py
7e336ac (NotBioWaste) Fixed prompt handling, switched to AIMessage in LLM response
71babbf (NotBioWaste) Merge branch 'feat/llm_responses' of https://github.com/deeppavlov/di…
351ae06 (NotBioWaste) Fixed conditions call
e3d0d15 (NotBioWaste) Working on autotesting
0405998 (NotBioWaste) Added tests
3dbfd0c (NotBioWaste) Removed unused method
5c876ba (NotBioWaste) Added annotations
8f1932c (NotBioWaste) Added structured output support, tweaked tests
aedf47e (NotBioWaste) Reworking tutorials
adadb05 (NotBioWaste) Reworked prompt usage and hierarchy, reworked filters and methods
0288896 (NotBioWaste) No idea how to make script smaller in tutorials
67e2758 (NotBioWaste) Small fixes in tutorials and structured generation
428a9f0 (NotBioWaste) Working on user guide
5e26b4b (NotBioWaste) Fixed some tutorials, finished user guide
5dbb6cd (NotBioWaste) Bugfixes in docs
db63d1a (NotBioWaste) Lint
2b9080f (NotBioWaste) Removed type annotation that broke docs building
2bcda71 (NotBioWaste) Tests and bugfixes
d2f28ed (NotBioWaste) Deleted DeepEval references
7318c91 (NotBioWaste) Numpy versions trouble
27eae27 (NotBioWaste) Fixed dependecies
3fed1fc (NotBioWaste) Made everything asynchronous
30862ca (NotBioWaste) Added and unified docstring
06ab5bc (NotBioWaste) Added 4th tutorial, fixed message_schema parameter passing
798a77b (NotBioWaste) Bugfix, added max_size to the message_to_langchain function
3343159 (NotBioWaste) Made even more everything asynchronous
014ff7e (NotBioWaste) Remade condition, added logprob check
761bd81 (NotBioWaste) Async bugfix, added model_result_to_text, working on message_schema f…
90a811e (NotBioWaste) Minor fixes, tinkering tests
5bff191 (RLKRo) Merge branch 'refs/heads/dev' into feat/llm_responses
8b88ba6 (RLKRo) update lock file
20c4afd (RLKRo) Merge remote-tracking branch 'origin/feat/llm_responses' into feat/ll…
0139421 (NotBioWaste905) Merge remote-tracking branch 'origin/master' into feat/llm_responses
9bb0cba (NotBioWaste905) Updating to v1.0
f2d6b68 (NotBioWaste905) Finished tests, finished update
6fddaea (NotBioWaste905) lint
e06bc2b (NotBioWaste905) Started working on llm slots
22d8efc (NotBioWaste905) Resolving pydantic errors
aa735b5 (NotBioWaste905) Delete llmslot_test.py
cc91133 (NotBioWaste905) Finished LLMSlot, working on LLMGroupSlot
8756838 (NotBioWaste905) Merge remote-tracking branch 'origin/feat/llm_responses' into feat/ll…
f1857f6 (NotBioWaste905) Added flag to
c334ff5 (NotBioWaste905) First test attempts
8306bbb (NotBioWaste905) linting
f842776 (NotBioWaste905) Merge branch 'feat/slots_extraction_update' into feat/llm_responses
ada17ca (NotBioWaste905) Merge remote-tracking branch 'origin/feat/llm_responses' into feat/ll…
a45f653 (NotBioWaste905) File structure fixed
3838d30 (NotBioWaste905) Fixed naming
0e650f8 (NotBioWaste905) Create LLMCondition and LLMResponse classes
ca79f94 (NotBioWaste905) Merge branch 'dev' into merge/extended_conditions
015cb4f (NotBioWaste905) Debugging flattening
b6e5eeb (NotBioWaste905) Bugfix
b20137e (NotBioWaste905) Added return_type property for LLMSlot
25f5b04 (NotBioWaste905) Changed return_type from Any to type
b651087 (NotBioWaste905) lint
284555d (NotBioWaste905) Fixed dependency namings
354b51d (NotBioWaste905) Fixed singledispatch
640aeb3 (NotBioWaste905) Removed Dataset and ExtrasBaseModel, created HasLabel and HasMatch co…
ee7d5e2 (NotBioWaste905) Removed deprecated files
492239d (NotBioWaste905) Deleted synchronous variants, removed property models_labels from Con…
474cd7f (NotBioWaste905) Deleted unused modules, merged classes with their abstract variants
1b5a77b (NotBioWaste905) removed deprecated from_script from tutorials
c18d375 (NotBioWaste905) Fixed LLMCondition class
e884494 (NotBioWaste905) Removed inner functions, fixed signatures in conditions
459f7fc (NotBioWaste905) Fixed missing 'models' field in Pipeline, updated tutorials
57a2d9d (NotBioWaste905) Merge branch 'feat/llm_responses' into merge/extended_conditions
@@ -35,6 +35,16 @@ jobs:
         with:
           pandoc-version: '3.1.6'
 
+      - name: Create gdf_account.json
+        uses: jsdaniell/create-json@v1.2.2  # action reference was e-mail-obfuscated in the capture; the exact version tag is an assumption
+        with:
+          name: "gdf_account.json"
+          json: ${{ secrets.GDF_ACCOUNT_JSON }}
+
+      - name: write realpath to env
+        run: |
+          echo "GDF_ACCOUNT_JSON=$(realpath ./gdf_account.json)" >> $GITHUB_ENV
+
       - name: install dependencies
         run: |
           make venv
@@ -45,6 +55,8 @@ jobs:
           TG_API_ID: ${{ secrets.TG_API_ID }}
           TG_API_HASH: ${{ secrets.TG_API_HASH }}
           TG_BOT_USERNAME: ${{ secrets.TG_BOT_USERNAME }}
+          GDF_ACCOUNT_JSON: ${{ env.GDF_ACCOUNT_JSON }}
+          HF_API_KEY: ${{ secrets.HF_API_KEY }}
         run: |
           make doc
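The two added steps use a standard GitHub Actions pattern: materialize a secret into a file, then append a KEY=value line to the file named by $GITHUB_ENV so that every later step in the job sees the value as an ordinary environment variable. A minimal sketch of the mechanism (step and variable names are illustrative, not from this PR):

      - name: export a path
        run: echo "MY_JSON=$(realpath ./my_file.json)" >> $GITHUB_ENV

      - name: use it in a later step
        run: cat "$MY_JSON"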
@@ -50,11 +50,23 @@ jobs:
           rm -rf /tmp/backup
           touch venv # disable venv target
 
+      - name: Create gdf_account.json
+        uses: jsdaniell/create-json@v1.2.2
+        with:
+          name: "gdf_account.json"
+          json: ${{ secrets.GDF_ACCOUNT_JSON }}
+
+      - name: write realpath to env
+        run: |
+          echo "GDF_ACCOUNT_JSON=$(realpath ./gdf_account.json)" >> $GITHUB_ENV
+
       - name: run tests
         env:
           TG_BOT_TOKEN: ${{ secrets.TG_BOT_TOKEN }}
           TG_API_ID: ${{ secrets.TG_API_ID }}
           TG_API_HASH: ${{ secrets.TG_API_HASH }}
           TG_BOT_USERNAME: ${{ secrets.TG_BOT_USERNAME }}
+          HF_API_KEY: ${{ secrets.HF_API_KEY }}
+          GDF_ACCOUNT_JSON: ${{ env.GDF_ACCOUNT_JSON }}
         run: |
           make test TEST_ALLOW_SKIP=telegram
@@ -43,18 +43,30 @@ jobs:
           python -m pip install -e .[test_full]
         shell: bash
 
+      - name: Create gdf_account.json
+        uses: jsdaniell/create-json@v1.2.2
+        with:
+          name: "gdf_account.json"
+          json: ${{ secrets.GDF_ACCOUNT_JSON }}
+
+      - name: write realpath to env
+        run: |
+          echo "GDF_ACCOUNT_JSON=$(realpath ./gdf_account.json)" >> $GITHUB_ENV
+
       - name: run pytest
         env:
           TG_BOT_TOKEN: ${{ secrets.TG_BOT_TOKEN }}
           TG_API_ID: ${{ secrets.TG_API_ID }}
           TG_API_HASH: ${{ secrets.TG_API_HASH }}
           TG_BOT_USERNAME: ${{ secrets.TG_BOT_USERNAME }}
+          HF_API_KEY: ${{ secrets.HF_API_KEY }}
+          GDF_ACCOUNT_JSON: ${{ env.GDF_ACCOUNT_JSON }}
         run: |
           if [ "$RUNNER_OS" == "Linux" ]; then
             source <(cat .env_file | sed 's/=/=/' | sed 's/^/export /')
             pytest --tb=long -vv --cache-clear --no-cov --allow-skip=telegram tests/
           else
-            pytest -m "not docker" --tb=long -vv --cache-clear --no-cov --allow-skip=telegram,docker tests/
+            pytest -m "not docker" --tb=long -vv --cache-clear --no-cov --allow-skip=telegram,docker,huggingface,rasa,dialogflow tests/
           fi
         shell: bash
   test_no_deps:
@@ -77,12 +89,24 @@ jobs:
           python -m pip install -e .[tests]
         shell: bash
 
+      - name: Create gdf_account.json
+        uses: jsdaniell/create-json@v1.2.2
+        with:
+          name: "gdf_account.json"
+          json: ${{ secrets.GDF_ACCOUNT_JSON }}
+
+      - name: write realpath to env
+        run: |
+          echo "GDF_ACCOUNT_JSON=$(realpath ./gdf_account.json)" >> $GITHUB_ENV
+
       - name: run pytest
         env:
           TG_BOT_TOKEN: ${{ secrets.TG_BOT_TOKEN }}
           TG_API_ID: ${{ secrets.TG_API_ID }}
           TG_API_HASH: ${{ secrets.TG_API_HASH }}
           TG_BOT_USERNAME: ${{ secrets.TG_BOT_USERNAME }}
+          HF_API_KEY: ${{ secrets.HF_API_KEY }}
+          GDF_ACCOUNT_JSON: ${{ env.GDF_ACCOUNT_JSON }}
         run: |
           source <(cat .env_file | sed 's/=/=/' | sed 's/^/export /')
           pytest --tb=long -vv --cache-clear --no-cov --allow-skip=all tests/
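The test workflows also source .env_file before invoking pytest: each KEY=value line is prefixed with "export" and fed back to the shell through process substitution, so the variables become visible to pytest and its child processes. A stripped-down bash illustration (file contents invented; the no-op sed from the diff is omitted):

    echo "PYTHONHASHSEED=42" > .env_file
    source <(sed 's/^/export /' .env_file)   # rewrite each line as an export statement
    echo "$PYTHONHASHSEED"                   # prints 42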
@@ -0,0 +1,111 @@

"""
Conditions
------------

This module provides condition functions for annotation processing.
"""
from typing import Callable, Optional, List
from functools import singledispatch

try:
    from sklearn.metrics.pairwise import cosine_similarity

    sklearn_available = True
except ImportError:
    sklearn_available = False
from dff.script import Context
from dff.pipeline import Pipeline
from dff.script.extras.conditions.dataset import DatasetItem
from dff.script.extras.conditions.utils import LABEL_KEY
from dff.script.extras.conditions.models.base_model import ExtrasBaseModel


@singledispatch
def has_cls_label(label, namespace: Optional[str] = None, threshold: float = 0.9):
    """
    Use this condition when you need to check whether the probability
    of a particular label for the last annotated user utterance surpasses the threshold.

    :param label: String name or a reference to a DatasetItem object, or a collection thereof.
    :param namespace: Namespace key of a particular model that should detect the dataset_item.
        If not set, all namespaces will be searched for the required dataset_item.
    :param threshold: The minimal label probability that triggers a positive response
        from the function.
    """
    raise NotImplementedError


@has_cls_label.register(str)
def _(label, namespace: Optional[str] = None, threshold: float = 0.9):
    def has_cls_label_inner(ctx: Context, _) -> bool:
        if LABEL_KEY not in ctx.framework_states:
            return False
        if namespace is not None:
            return ctx.framework_states[LABEL_KEY].get(namespace, {}).get(label, 0) >= threshold
        # No namespace given: accept if any model's score for the label passes the threshold.
        scores = [item.get(label, 0) for item in ctx.framework_states[LABEL_KEY].values()]
        return any(score >= threshold for score in scores)

    return has_cls_label_inner


@has_cls_label.register(DatasetItem)
def _(label, namespace: Optional[str] = None, threshold: float = 0.9) -> Callable[[Context, Pipeline], bool]:
    def has_cls_label_inner(ctx: Context, _) -> bool:
        if LABEL_KEY not in ctx.framework_states:
            return False
        if namespace is not None:
            return ctx.framework_states[LABEL_KEY].get(namespace, {}).get(label.label, 0) >= threshold
        scores = [item.get(label.label, 0) for item in ctx.framework_states[LABEL_KEY].values()]
        return any(score >= threshold for score in scores)

    return has_cls_label_inner


@has_cls_label.register(list)
def _(label, namespace: Optional[str] = None, threshold: float = 0.9):
    def has_cls_label_inner(ctx: Context, pipeline: Pipeline) -> bool:
        if LABEL_KEY not in ctx.framework_states:
            return False
        # Delegate to the scalar variants: the condition holds if any list member matches.
        return any(has_cls_label(item, namespace, threshold)(ctx, pipeline) for item in label)

    return has_cls_label_inner


def has_match(
    model: ExtrasBaseModel,
    positive_examples: List[str],
    negative_examples: Optional[List[str]] = None,
    threshold: float = 0.9,
):
    """
    Use this condition if you need to check whether the last request matches
    any of the predefined intent utterances.
    The model passed to this function should be in the fit state.

    :param model: Any model from the :py:mod:`~dff.script.extras.conditions.models.local.cosine_matchers` module.
    :param positive_examples: Utterances that the request should match. At least one is required.
    :param negative_examples: Utterances that the request should not match.
    :param threshold: Similarity threshold that triggers a positive response from the function.
    """
    if negative_examples is None:
        negative_examples = []

    def has_match_inner(ctx: Context, _) -> bool:
        if not (ctx.last_request and ctx.last_request.text):
            return False
        input_vector = model.transform(ctx.last_request.text)
        positive_vectors = [model.transform(item) for item in positive_examples]
        negative_vectors = [model.transform(item) for item in negative_examples]
        positive_sims = [cosine_similarity(input_vector, item)[0][0] for item in positive_vectors]
        negative_sims = [cosine_similarity(input_vector, item)[0][0] for item in negative_vectors]
        max_pos_sim = max(positive_sims)
        max_neg_sim = 0 if len(negative_sims) == 0 else max(negative_sims)
        # Succeed only when the best positive match clears the threshold
        # and the best negative match stays below it.
        return bool(max_pos_sim > threshold > max_neg_sim)

    return has_match_inner
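For orientation, a minimal sketch of how these conditions are meant to be wired into a script. The flow layout, the "greeting" label, and the assumption that an annotator has already populated LABEL_KEY in the context are illustrative, not part of this diff:

from dff.script import RESPONSE, TRANSITIONS, Message
from dff.script.extras.conditions.conditions import has_cls_label

script = {
    "flow": {
        "start": {
            RESPONSE: Message(text="How can I help?"),
            TRANSITIONS: {
                # Taken when any annotator scored "greeting" at or above 0.95.
                ("flow", "greet"): has_cls_label("greeting", threshold=0.95),
            },
        },
        "greet": {RESPONSE: Message(text="Hello!")},
    }
}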
@@ -0,0 +1,109 @@

"""
Dataset
--------

This module contains data structures that are required to parse items from files
and to parse requests and responses to and from various APIs.
"""
from pathlib import Path
import json
from typing import List, Dict, Union

from pydantic import BaseModel, Field, field_validator, model_validator

try:
    from yaml import load, SafeLoader

    pyyaml_available = True
except ImportError:
    pyyaml_available = False


class DatasetItem(BaseModel, arbitrary_types_allowed=True):
    """
    Data structure for storing labeled utterances.

    :param label: Raw classification label.
    :param samples: Utterance examples. At least one sentence is required.
    """

    label: str
    samples: List[Union[List[str], Dict[str, str], str]] = Field(default_factory=list, min_length=1)
    categorical_code: int = Field(default=0)


class Dataset(BaseModel, arbitrary_types_allowed=True):
    """
    Data structure for storing multiple :py:class:`~DatasetItem` objects.

    :param items: Can be initialized either with a list or with a dict
        of :py:class:`~DatasetItem` objects.
        Makes each item accessible by its label.
    """

    items: Dict[str, DatasetItem] = Field(default_factory=dict)
    flat_items: list = Field(default_factory=list)
    """`flat_items` field is populated automatically using objects from the `items` field."""

    def __getitem__(self, idx: int):
        # `flat_items` is a list, so indexing is positional.
        return self.flat_items[idx]

    def __len__(self):
        return len(self.flat_items)

    @classmethod
    def _get_path(cls, file: Union[str, Path]) -> Path:
        file_path = file if isinstance(file, Path) else Path(file)
        if not file_path.exists() or not file_path.is_file():
            raise OSError(f"File does not exist: {file}")
        return file_path

    @classmethod
    def parse_json(cls, file: Union[str, Path]):
        file_path = cls._get_path(file)
        items = json.load(file_path.open("r", encoding="utf-8"))
        return cls(items=[DatasetItem.model_validate(item) for item in items])

    @classmethod
    def parse_jsonl(cls, file: Union[str, Path]):
        file_path = cls._get_path(file)
        lines = file_path.open("r", encoding="utf-8").readlines()
        items = [DatasetItem.model_validate_json(line) for line in lines]
        return cls(items=items)

    @classmethod
    def parse_yaml(cls, file: Union[str, Path]):
        if not pyyaml_available:
            raise ImportError("`pyyaml` package missing. Try `pip install dff[ext]`.")
        file_path = cls._get_path(file)
        raw_items = load(file_path.open("r", encoding="utf-8").read(), SafeLoader)["items"]
        items = [DatasetItem.model_validate(item) for item in raw_items]
        return cls(items=items)

    @field_validator("items", mode="before")
    @classmethod
    def pre_validate_items(cls, value: Union[Dict[str, DatasetItem], List[DatasetItem]]):
        if isinstance(value, list):  # if items were passed as a list, cast them to a dict
            new_value = [DatasetItem.model_validate(item) for item in value]
            value = {item.label: item for item in new_value}
        return value

    @model_validator(mode="after")
    def post_validation(self):
        # Assign a stable categorical code to each item in insertion order.
        items: Dict[str, DatasetItem] = self.items
        for idx, key in enumerate(items.keys()):
            items[key].categorical_code = idx

        # Flatten the dataset into (sample, label) pairs for model training.
        sentences = [sentence for dataset_item in items.values() for sentence in dataset_item.samples]
        pred_labels = [
            label for dataset_item in items.values() for label in [dataset_item.label] * len(dataset_item.samples)
        ]
        self.flat_items = list(zip(sentences, pred_labels))
        return self
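A small usage sketch of the validators above (labels and samples invented): passing a list triggers pre_validate_items, which keys the dict by label, and post_validation then assigns categorical codes and builds the flattened (sample, label) pairs:

from dff.script.extras.conditions.dataset import Dataset, DatasetItem

dataset = Dataset(
    items=[
        DatasetItem(label="greeting", samples=["hi", "hello there"]),
        DatasetItem(label="goodbye", samples=["bye", "see you"]),
    ]
)

assert dataset.items["greeting"].categorical_code == 0  # codes follow insertion order
assert len(dataset) == 4                 # one (sample, label) pair per sample
assert dataset[0] == ("hi", "greeting")  # flat_items is indexed positionally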
@@ -0,0 +1,9 @@

from .local.classifiers.huggingface import HFClassifier  # noqa: F401
from .local.classifiers.regex import RegexClassifier, RegexModel  # noqa: F401
from .local.classifiers.sklearn import SklearnClassifier  # noqa: F401
from .local.cosine_matchers.gensim import GensimMatcher  # noqa: F401
from .local.cosine_matchers.huggingface import HFMatcher  # noqa: F401
from .local.cosine_matchers.sklearn import SklearnMatcher  # noqa: F401
from .remote_api.google_dialogflow_model import GoogleDialogFlowModel, AsyncGoogleDialogFlowModel  # noqa: F401
from .remote_api.rasa_model import AsyncRasaModel, RasaModel  # noqa: F401
from .remote_api.hf_api_model import AsyncHFAPIModel, HFAPIModel  # noqa: F401
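The noqa: F401 pragmas mark these as intentional re-exports: the package gathers every classifier and matcher under a single namespace, so user code can import them directly from the models package, for instance (import path inferred from this diff):

from dff.script.extras.conditions.models import RegexModel, SklearnMatcher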
I think the Path file object should support direct readlines() calls.
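For reference, the line under discussion and a close pathlib equivalent (file name hypothetical):

from pathlib import Path

p = Path("data.jsonl")
lines = p.open("r", encoding="utf-8").readlines()   # what the diff does; keeps trailing newlines
lines = p.read_text(encoding="utf-8").splitlines()  # also reads all lines, and closes the file promptly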