From 284555d7c2135080ef5d099e163fcdf7b9b6f9d7 Mon Sep 17 00:00:00 2001 From: askatasuna Date: Thu, 24 Oct 2024 00:05:00 +0300 Subject: [PATCH] Fixed dependency namings --- .../conditions.py => conditions/ml.py} | 10 +- .../llm_conditions => ml}/__init__.py | 0 .../llm_conditions => ml}/dataset.py | 0 chatsky/ml/models/__init__.py | 3 + .../models/base_model.py | 4 +- .../models/remote_api/__init__.py | 0 .../models/remote_api/async_mixin.py | 4 +- .../remote_api/google_dialogflow_model.py | 4 +- .../models/remote_api/hf_api_model.py | 4 +- .../models/remote_api/rasa_model.py | 6 +- .../conditions/llm_conditions => ml}/utils.py | 0 .../llm_conditions/models/__init__.py | 9 - chatsky/script/core/context.py | 287 ------------------ tests/script/extras/conditions/conftest.py | 4 +- .../extras/conditions/test_conditions.py | 19 +- .../extras/conditions/test_dialogflow.py | 2 +- tests/script/extras/conditions/test_gensim.py | 37 --- tests/script/extras/conditions/test_hf.py | 63 ---- tests/script/extras/conditions/test_hf_api.py | 2 +- tests/script/extras/conditions/test_rasa.py | 2 +- .../script/extras/conditions/test_sklearn.py | 53 ---- .../extras/conditions/test_tutorials.py | 13 +- tests/script/extras/conditions/test_utils.py | 8 +- .../script/extras/conditions/1_hf_api.py | 10 +- .../script/extras/conditions/2_dialogflow.py | 10 +- tutorials/script/extras/conditions/3_rasa.py | 10 +- 26 files changed, 52 insertions(+), 512 deletions(-) rename chatsky/{script/conditions/llm_conditions/conditions.py => conditions/ml.py} (92%) rename chatsky/{script/conditions/llm_conditions => ml}/__init__.py (100%) rename chatsky/{script/conditions/llm_conditions => ml}/dataset.py (100%) create mode 100644 chatsky/ml/models/__init__.py rename chatsky/{script/conditions/llm_conditions => ml}/models/base_model.py (97%) rename chatsky/{script/conditions/llm_conditions => ml}/models/remote_api/__init__.py (100%) rename chatsky/{script/conditions/llm_conditions => ml}/models/remote_api/async_mixin.py (85%) rename chatsky/{script/conditions/llm_conditions => ml}/models/remote_api/google_dialogflow_model.py (96%) rename chatsky/{script/conditions/llm_conditions => ml}/models/remote_api/hf_api_model.py (94%) rename chatsky/{script/conditions/llm_conditions => ml}/models/remote_api/rasa_model.py (93%) rename chatsky/{script/conditions/llm_conditions => ml}/utils.py (100%) delete mode 100644 chatsky/script/conditions/llm_conditions/models/__init__.py delete mode 100644 chatsky/script/core/context.py delete mode 100644 tests/script/extras/conditions/test_gensim.py delete mode 100644 tests/script/extras/conditions/test_hf.py delete mode 100644 tests/script/extras/conditions/test_sklearn.py diff --git a/chatsky/script/conditions/llm_conditions/conditions.py b/chatsky/conditions/ml.py similarity index 92% rename from chatsky/script/conditions/llm_conditions/conditions.py rename to chatsky/conditions/ml.py index f028829e6..93dcc6b23 100644 --- a/chatsky/script/conditions/llm_conditions/conditions.py +++ b/chatsky/conditions/ml.py @@ -8,15 +8,15 @@ from functools import singledispatch try: + #!!! remove sklearn, use something else instead from sklearn.metrics.pairwise import cosine_similarity sklearn_available = True except ImportError: sklearn_available = False -from chatsky.script import Context -from chatsky.pipeline import Pipeline -from chatsky.script.conditions.llm_conditions.dataset import DatasetItem -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel +from chatsky import Context, Pipeline +from chatsky.ml.dataset import DatasetItem +from chatsky.ml.models.base_model import ExtrasBaseModel @singledispatch @@ -92,7 +92,7 @@ def has_match( any of the pre-defined intent utterances. The model passed to this function should be in the fit state. - :param model: Any model from the :py:mod:`~chatsky.script.conditions.llm_conditions.models.local.cosine_matchers` module. + :param model: Any model from the :py:mod:`~chatsky.ml.models.local.cosine_matchers` module. :param positive_examples: Utterances that the request should match. :param negative_examples: Utterances that the request should not match. :param threshold: Similarity threshold that triggers a positive response from the function. diff --git a/chatsky/script/conditions/llm_conditions/__init__.py b/chatsky/ml/__init__.py similarity index 100% rename from chatsky/script/conditions/llm_conditions/__init__.py rename to chatsky/ml/__init__.py diff --git a/chatsky/script/conditions/llm_conditions/dataset.py b/chatsky/ml/dataset.py similarity index 100% rename from chatsky/script/conditions/llm_conditions/dataset.py rename to chatsky/ml/dataset.py diff --git a/chatsky/ml/models/__init__.py b/chatsky/ml/models/__init__.py new file mode 100644 index 000000000..c5b4d2cf7 --- /dev/null +++ b/chatsky/ml/models/__init__.py @@ -0,0 +1,3 @@ +from .remote_api.google_dialogflow_model import GoogleDialogFlowModel, AsyncGoogleDialogFlowModel # noqa: F401 +from .remote_api.rasa_model import AsyncRasaModel, RasaModel # noqa: F401 +from .remote_api.hf_api_model import AsyncHFAPIModel, HFAPIModel # noqa: F401 diff --git a/chatsky/script/conditions/llm_conditions/models/base_model.py b/chatsky/ml/models/base_model.py similarity index 97% rename from chatsky/script/conditions/llm_conditions/models/base_model.py rename to chatsky/ml/models/base_model.py index a76cc3e1b..281d91dcd 100644 --- a/chatsky/script/conditions/llm_conditions/models/base_model.py +++ b/chatsky/ml/models/base_model.py @@ -7,9 +7,9 @@ from copy import copy from abc import ABC, abstractmethod -from chatsky.script import Context +from chatsky import Context -from chatsky.script.conditions.llm_conditions.dataset import Dataset +from chatsky.ml.dataset import Dataset import asyncio from async_lru import alru_cache diff --git a/chatsky/script/conditions/llm_conditions/models/remote_api/__init__.py b/chatsky/ml/models/remote_api/__init__.py similarity index 100% rename from chatsky/script/conditions/llm_conditions/models/remote_api/__init__.py rename to chatsky/ml/models/remote_api/__init__.py diff --git a/chatsky/script/conditions/llm_conditions/models/remote_api/async_mixin.py b/chatsky/ml/models/remote_api/async_mixin.py similarity index 85% rename from chatsky/script/conditions/llm_conditions/models/remote_api/async_mixin.py rename to chatsky/ml/models/remote_api/async_mixin.py index 74957b56d..1211c938d 100644 --- a/chatsky/script/conditions/llm_conditions/models/remote_api/async_mixin.py +++ b/chatsky/ml/models/remote_api/async_mixin.py @@ -5,8 +5,8 @@ This module provides the mixin that overrides the :py:meth:`__call__` method in all the descendants making them asynchronous. """ -from chatsky.script import Context -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel +from chatsky import Context +from chatsky.ml.models.base_model import ExtrasBaseModel class AsyncMixin(ExtrasBaseModel): diff --git a/chatsky/script/conditions/llm_conditions/models/remote_api/google_dialogflow_model.py b/chatsky/ml/models/remote_api/google_dialogflow_model.py similarity index 96% rename from chatsky/script/conditions/llm_conditions/models/remote_api/google_dialogflow_model.py rename to chatsky/ml/models/remote_api/google_dialogflow_model.py index 5b01e74d8..2d6d5c2f2 100644 --- a/chatsky/script/conditions/llm_conditions/models/remote_api/google_dialogflow_model.py +++ b/chatsky/ml/models/remote_api/google_dialogflow_model.py @@ -11,8 +11,8 @@ from pathlib import Path from async_lru import alru_cache -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel -from chatsky.script.conditions.llm_conditions.models.remote_api.async_mixin import AsyncMixin +from chatsky.ml.models.base_model import ExtrasBaseModel +from chatsky.ml.models.remote_api.async_mixin import AsyncMixin try: from google.cloud import dialogflow_v2 diff --git a/chatsky/script/conditions/llm_conditions/models/remote_api/hf_api_model.py b/chatsky/ml/models/remote_api/hf_api_model.py similarity index 94% rename from chatsky/script/conditions/llm_conditions/models/remote_api/hf_api_model.py rename to chatsky/ml/models/remote_api/hf_api_model.py index ad22f415b..367662886 100644 --- a/chatsky/script/conditions/llm_conditions/models/remote_api/hf_api_model.py +++ b/chatsky/ml/models/remote_api/hf_api_model.py @@ -21,8 +21,8 @@ except ImportError: hf_api_available = False -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel -from chatsky.script.conditions.llm_conditions.models.remote_api.async_mixin import AsyncMixin +from chatsky.ml.models.base_model import ExtrasBaseModel +from chatsky.ml.models.remote_api.async_mixin import AsyncMixin class AbstractHFAPIModel(ExtrasBaseModel): diff --git a/chatsky/script/conditions/llm_conditions/models/remote_api/rasa_model.py b/chatsky/ml/models/remote_api/rasa_model.py similarity index 93% rename from chatsky/script/conditions/llm_conditions/models/remote_api/rasa_model.py rename to chatsky/ml/models/remote_api/rasa_model.py index fb60c6402..915911e54 100644 --- a/chatsky/script/conditions/llm_conditions/models/remote_api/rasa_model.py +++ b/chatsky/ml/models/remote_api/rasa_model.py @@ -21,9 +21,9 @@ rasa_available = False from http import HTTPStatus -from chatsky.script.conditions.llm_conditions.utils import RasaResponse -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel -from chatsky.script.conditions.llm_conditions.models.remote_api.async_mixin import AsyncMixin +from chatsky.ml.utils import RasaResponse +from chatsky.ml.models.base_model import ExtrasBaseModel +from chatsky.ml.models.remote_api.async_mixin import AsyncMixin class AbstractRasaModel(ExtrasBaseModel): diff --git a/chatsky/script/conditions/llm_conditions/utils.py b/chatsky/ml/utils.py similarity index 100% rename from chatsky/script/conditions/llm_conditions/utils.py rename to chatsky/ml/utils.py diff --git a/chatsky/script/conditions/llm_conditions/models/__init__.py b/chatsky/script/conditions/llm_conditions/models/__init__.py deleted file mode 100644 index 34cff459e..000000000 --- a/chatsky/script/conditions/llm_conditions/models/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .local.classifiers.huggingface import HFClassifier # noqa: F401 -from .local.classifiers.regex import RegexClassifier, RegexModel # noqa: F401 -from .local.classifiers.sklearn import SklearnClassifier # noqa: F401 -from .local.cosine_matchers.gensim import GensimMatcher # noqa: F401 -from .local.cosine_matchers.huggingface import HFMatcher # noqa: F401 -from .local.cosine_matchers.sklearn import SklearnMatcher # noqa: F401 -from .remote_api.google_dialogflow_model import GoogleDialogFlowModel, AsyncGoogleDialogFlowModel # noqa: F401 -from .remote_api.rasa_model import AsyncRasaModel, RasaModel # noqa: F401 -from .remote_api.hf_api_model import AsyncHFAPIModel, HFAPIModel # noqa: F401 diff --git a/chatsky/script/core/context.py b/chatsky/script/core/context.py deleted file mode 100644 index 51175e5fc..000000000 --- a/chatsky/script/core/context.py +++ /dev/null @@ -1,287 +0,0 @@ -""" -Context -------- -A Context is a data structure that is used to store information about the current state of a conversation. -It is used to keep track of the user's input, the current stage of the conversation, and any other -information that is relevant to the current context of a dialog. -The Context provides a convenient interface for working with data, allowing developers to easily add, -retrieve, and manipulate data as the conversation progresses. - -The Context data structure provides several key features to make working with data easier. -Developers can use the context to store any information that is relevant to the current conversation, -such as user data, session data, conversation history, or etc. -This allows developers to easily access and use this data throughout the conversation flow. - -Another important feature of the context is data serialization. -The context can be easily serialized to a format that can be stored or transmitted, such as JSON. -This allows developers to save the context data and resume the conversation later. -""" - -from __future__ import annotations -import logging -from uuid import UUID, uuid4 -from typing import Any, Optional, Union, Dict, List, Set, TYPE_CHECKING - -from pydantic import BaseModel, Field, field_validator - -from chatsky.script.core.message import Message -from chatsky.script.core.types import NodeLabel2Type -from chatsky.pipeline.types import ComponentExecutionState -from chatsky.slots.slots import SlotManager - -if TYPE_CHECKING: - from chatsky.script.core.script import Node - -logger = logging.getLogger(__name__) - - -def get_last_index(dictionary: dict) -> int: - """ - Obtain the last index from the `dictionary`. Return `-1` if the `dict` is empty. - - :param dictionary: Dictionary with unsorted keys. - :return: Last index from the `dictionary`. - """ - indices = list(dictionary) - return indices[-1] if indices else -1 - - -class FrameworkData(BaseModel): - """ - Framework uses this to store data related to any of its modules. - """ - - service_states: Dict[str, ComponentExecutionState] = Field(default_factory=dict, exclude=True) - "Statuses of all the pipeline services. Cleared at the end of every turn." - actor_data: Dict[str, Any] = Field(default_factory=dict, exclude=True) - "Actor service data. Cleared at the end of every turn." - stats: Dict[str, Any] = Field(default_factory=dict) - "Enables complex stats collection across multiple turns." - slot_manager: SlotManager = Field(default_factory=SlotManager) - "Stores extracted slots." - llm_labels: Dict[str, Dict[str, Any]] = Field(default_factory=dict) - "Contains predicted labels for the models on this step." - llm_labels_cache: Dict[str, Dict[str, Dict[str, float]]] = Field(default_factory=dict) - "Stores predicted labels for the messages for the whole context." - - -class Context(BaseModel): - """ - A structure that is used to store data about the context of a dialog. - - Avoid storing unserializable data in the fields of this class in order for - context storages to work. - """ - - id: Union[UUID, int, str] = Field(default_factory=uuid4) - """ - `id` is the unique context identifier. By default, randomly generated using `uuid4` `id` is used. - `id` can be used to trace the user behavior, e.g while collecting the statistical data. - """ - labels: Dict[int, NodeLabel2Type] = Field(default_factory=dict) - """ - `labels` stores the history of all passed `labels` - - - key - `id` of the turn. - - value - `label` on this turn. - """ - requests: Dict[int, Message] = Field(default_factory=dict) - """ - `requests` stores the history of all `requests` received by the agent - - - key - `id` of the turn. - - value - `request` on this turn. - """ - responses: Dict[int, Message] = Field(default_factory=dict) - """ - `responses` stores the history of all agent `responses` - - - key - `id` of the turn. - - value - `response` on this turn. - """ - misc: Dict[str, Any] = Field(default_factory=dict) - """ - `misc` stores any custom data. The scripting doesn't use this dictionary by default, - so storage of any data won't reflect on the work on the internal Chatsky Scripting functions. - - Avoid storing unserializable data in order for context storages to work. - - - key - Arbitrary data name. - - value - Arbitrary data. - """ - framework_data: FrameworkData = Field(default_factory=FrameworkData) - """ - This attribute is used for storing custom data required for pipeline execution. - It is meant to be used by the framework only. Accessing it may result in pipeline breakage. - """ - - @field_validator("labels", "requests", "responses") - @classmethod - def sort_dict_keys(cls, dictionary: dict) -> dict: - """ - Sort the keys in the `dictionary`. This needs to be done after deserialization, - since the keys are deserialized in a random order. - - :param dictionary: Dictionary with unsorted keys. - :return: Dictionary with sorted keys. - """ - return {key: dictionary[key] for key in sorted(dictionary)} - - @classmethod - def cast(cls, ctx: Optional[Union[Context, dict, str]] = None, *args, **kwargs) -> Context: - """ - Transform different data types to the objects of the - :py:class:`~.Context` class. - Return an object of the :py:class:`~.Context` - type that is initialized by the input data. - - :param ctx: Data that is used to initialize an object of the - :py:class:`~.Context` type. - An empty :py:class:`~.Context` object is returned if no data is given. - :return: Object of the :py:class:`~.Context` - type that is initialized by the input data. - """ - if not ctx: - ctx = Context(*args, **kwargs) - elif isinstance(ctx, dict): - ctx = Context.model_validate(ctx) - elif isinstance(ctx, str): - ctx = Context.model_validate_json(ctx) - elif not isinstance(ctx, Context): - raise ValueError( - f"Context expected to be an instance of the Context class " - f"or an instance of the dict/str(json) type. Got: {type(ctx)}" - ) - return ctx - - def add_request(self, request: Message): - """ - Add a new `request` to the context. - The new `request` is added with the index of `last_index + 1`. - - :param request: `request` to be added to the context. - """ - request_message = Message.model_validate(request) - last_index = get_last_index(self.requests) - self.requests[last_index + 1] = request_message - - def add_response(self, response: Message): - """ - Add a new `response` to the context. - The new `response` is added with the index of `last_index + 1`. - - :param response: `response` to be added to the context. - """ - response_message = Message.model_validate(response) - last_index = get_last_index(self.responses) - self.responses[last_index + 1] = response_message - - def add_label(self, label: NodeLabel2Type): - """ - Add a new :py:data:`~.NodeLabel2Type` to the context. - The new `label` is added with the index of `last_index + 1`. - - :param label: `label` that we need to add to the context. - """ - last_index = get_last_index(self.labels) - self.labels[last_index + 1] = label - - def clear( - self, - hold_last_n_indices: int, - field_names: Union[Set[str], List[str]] = {"requests", "responses", "labels"}, - ): - """ - Delete all records from the `requests`/`responses`/`labels` except for - the last `hold_last_n_indices` turns. - If `field_names` contains `misc` field, `misc` field is fully cleared. - - :param hold_last_n_indices: Number of last turns to keep. - :param field_names: Properties of :py:class:`~.Context` to clear. - Defaults to {"requests", "responses", "labels"} - """ - field_names = field_names if isinstance(field_names, set) else set(field_names) - if "requests" in field_names: - for index in list(self.requests)[:-hold_last_n_indices]: - del self.requests[index] - if "responses" in field_names: - for index in list(self.responses)[:-hold_last_n_indices]: - del self.responses[index] - if "misc" in field_names: - self.misc.clear() - if "labels" in field_names: - for index in list(self.labels)[:-hold_last_n_indices]: - del self.labels[index] - if "framework_data" in field_names: - self.framework_data = FrameworkData() - - @property - def last_label(self) -> Optional[NodeLabel2Type]: - """ - Return the last :py:data:`~.NodeLabel2Type` of - the :py:class:`~.Context`. - Return `None` if `labels` is empty. - - Since `start_label` is not added to the `labels` field, - empty `labels` usually indicates that the current node is the `start_node`. - """ - last_index = get_last_index(self.labels) - return self.labels.get(last_index) - - @property - def last_response(self) -> Optional[Message]: - """ - Return the last `response` of the current :py:class:`~.Context`. - Return `None` if `responses` is empty. - """ - last_index = get_last_index(self.responses) - return self.responses.get(last_index) - - @last_response.setter - def last_response(self, response: Optional[Message]): - """ - Set the last `response` of the current :py:class:`~.Context`. - Required for use with various response wrappers. - """ - last_index = get_last_index(self.responses) - self.responses[last_index] = Message() if response is None else Message.model_validate(response) - - @property - def last_request(self) -> Optional[Message]: - """ - Return the last `request` of the current :py:class:`~.Context`. - Return `None` if `requests` is empty. - """ - last_index = get_last_index(self.requests) - return self.requests.get(last_index) - - @last_request.setter - def last_request(self, request: Optional[Message]): - """ - Set the last `request` of the current :py:class:`~.Context`. - Required for use with various request wrappers. - """ - last_index = get_last_index(self.requests) - self.requests[last_index] = Message() if request is None else Message.model_validate(request) - - @property - def current_node(self) -> Optional[Node]: - """ - Return current :py:class:`~chatsky.script.core.script.Node`. - """ - actor_data = self.framework_data.actor_data - node = ( - actor_data.get("processed_node") - or actor_data.get("pre_response_processed_node") - or actor_data.get("next_node") - or actor_data.get("pre_transitions_processed_node") - or actor_data.get("previous_node") - ) - if node is None: - logger.warning( - "The `current_node` method should be called " - "when an actor is running between the " - "`ActorStage.GET_PREVIOUS_NODE` and `ActorStage.FINISH_TURN` stages." - ) - - return node diff --git a/tests/script/extras/conditions/conftest.py b/tests/script/extras/conditions/conftest.py index 12e9e754b..2bb5fb8ca 100644 --- a/tests/script/extras/conditions/conftest.py +++ b/tests/script/extras/conditions/conftest.py @@ -1,7 +1,7 @@ import pytest -from chatsky.pipeline import Pipeline -from chatsky.script.conditions.llm_conditions.dataset import Dataset +from chatsky import Pipeline +from chatsky.ml.dataset import Dataset from chatsky.utils.testing.toy_script import TOY_SCRIPT from tests.test_utils import get_path_from_tests_to_current_dir diff --git a/tests/script/extras/conditions/test_conditions.py b/tests/script/extras/conditions/test_conditions.py index 375ba3d6e..a7c1520ab 100644 --- a/tests/script/extras/conditions/test_conditions.py +++ b/tests/script/extras/conditions/test_conditions.py @@ -1,10 +1,9 @@ import pytest -from chatsky.script import Context, Message -from chatsky.script.conditions.llm_conditions.utils import LABEL_KEY -from chatsky.script.conditions.llm_conditions.dataset import DatasetItem, Dataset -from chatsky.script.conditions.llm_conditions.conditions import has_cls_label, has_match -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.sklearn import SklearnMatcher, sklearn_available -from chatsky.script.conditions.llm_conditions.models.base_model import ExtrasBaseModel +from chatsky import Context, Message +from chatsky.ml.utils import LABEL_KEY +from chatsky.ml.dataset import DatasetItem, Dataset +from chatsky.conditions.ml import has_cls_label, has_match +from chatsky.ml.models.base_model import ExtrasBaseModel class DummyModel(ExtrasBaseModel): @@ -18,13 +17,6 @@ def __call__(self, text): pass -@pytest.fixture(scope="session") -def standard_model(testing_dataset): - from sklearn.feature_extraction.text import TfidfVectorizer - - yield SklearnMatcher(tokenizer=TfidfVectorizer(stop_words=None), dataset=testing_dataset) - - @pytest.mark.parametrize( ["input"], [ @@ -51,7 +43,6 @@ def test_conds_invalid(input, testing_pipeline): _ = has_cls_label(model, input)(Context(), testing_pipeline) -@pytest.mark.skipif(not sklearn_available, reason="Sklearn package missing.") @pytest.mark.parametrize( ["_input", "last_request", "thresh"], [ diff --git a/tests/script/extras/conditions/test_dialogflow.py b/tests/script/extras/conditions/test_dialogflow.py index c6bfe51a9..e64379d53 100644 --- a/tests/script/extras/conditions/test_dialogflow.py +++ b/tests/script/extras/conditions/test_dialogflow.py @@ -1,7 +1,7 @@ import os import pytest -from chatsky.script.conditions.llm_conditions.models.remote_api.google_dialogflow_model import ( +from chatsky.ml.models.remote_api.google_dialogflow_model import ( GoogleDialogFlowModel, AsyncGoogleDialogFlowModel, dialogflow_available, diff --git a/tests/script/extras/conditions/test_gensim.py b/tests/script/extras/conditions/test_gensim.py deleted file mode 100644 index 308ac9b4a..000000000 --- a/tests/script/extras/conditions/test_gensim.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.gensim import GensimMatcher, gensim_available -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.cosine_matcher_mixin import numpy_available -from chatsky.script.conditions.llm_conditions.dataset import Dataset - -if not gensim_available or not numpy_available: - pytest.skip(allow_module_level=True, reason="`Gensim` package missing.") - -import numpy as np -import gensim -import gensim.downloader as api - - -@pytest.fixture(scope="session") -def testing_model(testing_dataset): - wv = api.load("glove-wiki-gigaword-50") - model = gensim.models.word2vec.Word2Vec() - model.wv = wv - model = GensimMatcher(model=model, dataset=testing_dataset, namespace_key="gensim", min_count=1) - yield model - - -def test_transform(testing_model: GensimMatcher): - result = testing_model.transform("one two three") - assert isinstance(result, np.ndarray) - - -def test_fit(testing_model: GensimMatcher, testing_dataset: Dataset): - testing_model.fit(testing_dataset, min_count=1) - assert testing_model - - -def test_saving(save_file: str, testing_model: GensimMatcher): - testing_model.save(save_file) - new_testing_model = GensimMatcher.load(save_file, "gensim") - assert new_testing_model diff --git a/tests/script/extras/conditions/test_hf.py b/tests/script/extras/conditions/test_hf.py deleted file mode 100644 index f30288179..000000000 --- a/tests/script/extras/conditions/test_hf.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest - -try: - from transformers import AutoModelForSequenceClassification, AutoTokenizer - import torch - import numpy as np -except ImportError: - pytest.skip(allow_module_level=True) - -from chatsky.script.conditions.llm_conditions.models.local.classifiers.huggingface import HFClassifier -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.huggingface import HFMatcher - - -@pytest.fixture(scope="session") -def testing_model(hf_model_name): - model = AutoModelForSequenceClassification.from_pretrained(hf_model_name) - yield model - - -@pytest.fixture(scope="session") -def testing_tokenizer(hf_model_name): - tokenizer = AutoTokenizer.from_pretrained(hf_model_name) - yield tokenizer - - -@pytest.fixture(scope="session") -def testing_classifier(testing_model, testing_tokenizer): - yield HFClassifier( - model=testing_model, tokenizer=testing_tokenizer, device=torch.device("cpu"), namespace_key="HFclassifier" - ) - - -@pytest.fixture(scope="session") -def hf_matcher(testing_model, testing_tokenizer, testing_dataset): - yield HFMatcher( - model=testing_model, - tokenizer=testing_tokenizer, - dataset=testing_dataset, - device=torch.device("cpu"), - namespace_key="HFmodel", - ) - - -def test_saving(save_dir, testing_classifier: HFClassifier, hf_matcher: HFMatcher): - testing_classifier.save(path=save_dir) - testing_classifier = HFClassifier.load(save_dir, namespace_key="HFclassifier") - assert testing_classifier - hf_matcher.save(path=save_dir) - hf_matcher = HFMatcher.load(save_dir, namespace_key="HFmodel") - assert hf_matcher - - -def test_predict(testing_classifier: HFClassifier): - result = testing_classifier.predict("We are looking for x.") - assert result - assert isinstance(result, dict) - - -def test_transform(hf_matcher: HFMatcher, testing_classifier: HFClassifier): - result_1 = hf_matcher.transform("one two three") - assert isinstance(result_1, np.ndarray) - result_2 = testing_classifier.transform("one two three") - assert isinstance(result_2, np.ndarray) diff --git a/tests/script/extras/conditions/test_hf_api.py b/tests/script/extras/conditions/test_hf_api.py index b20da19dd..20fcdf05f 100644 --- a/tests/script/extras/conditions/test_hf_api.py +++ b/tests/script/extras/conditions/test_hf_api.py @@ -1,7 +1,7 @@ import os import pytest -from chatsky.script.conditions.llm_conditions.models.remote_api.hf_api_model import ( +from chatsky.ml.models.remote_api.hf_api_model import ( HFAPIModel, AsyncHFAPIModel, hf_api_available, diff --git a/tests/script/extras/conditions/test_rasa.py b/tests/script/extras/conditions/test_rasa.py index aecb9fa27..a40550d42 100644 --- a/tests/script/extras/conditions/test_rasa.py +++ b/tests/script/extras/conditions/test_rasa.py @@ -1,7 +1,7 @@ import os import pytest -from chatsky.script.conditions.llm_conditions.models.remote_api.rasa_model import RasaModel, AsyncRasaModel, rasa_available +from chatsky.ml.models.remote_api.rasa_model import RasaModel, AsyncRasaModel, rasa_available from tests.context_storages.test_dbs import ping_localhost RASA_ACTIVE = ping_localhost(5005) diff --git a/tests/script/extras/conditions/test_sklearn.py b/tests/script/extras/conditions/test_sklearn.py deleted file mode 100644 index 8a8203799..000000000 --- a/tests/script/extras/conditions/test_sklearn.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest - -from chatsky.script.conditions.llm_conditions.models.local.classifiers.sklearn import SklearnClassifier, sklearn_available -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.sklearn import SklearnMatcher -from chatsky.script.conditions.llm_conditions.dataset import Dataset - -if not sklearn_available: - pytest.skip(allow_module_level=True, reason="`Sklearn` package missing.") - -import numpy as np -from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.linear_model import LogisticRegression - - -@pytest.fixture(scope="session") -def testing_classifier(): - classifier = SklearnClassifier(model=LogisticRegression(), tokenizer=TfidfVectorizer(), namespace_key="classifier") - yield classifier - - -@pytest.fixture(scope="session") -def testing_model(testing_dataset): - model = SklearnMatcher( - model=None, - tokenizer=TfidfVectorizer(), - dataset=testing_dataset, - namespace_key="model", - ) - yield model - - -def test_saving(save_file: str, testing_classifier: SklearnClassifier, testing_model: SklearnMatcher): - testing_classifier.save(save_file) - new_classifier = SklearnClassifier.load(save_file, namespace_key="classifier") - assert isinstance(new_classifier.model, type(testing_classifier.model)) - assert isinstance(new_classifier.tokenizer, type(testing_classifier.tokenizer)) - assert new_classifier.namespace_key == testing_classifier.namespace_key - testing_model.save(save_file) - _ = SklearnMatcher.load(path=save_file, namespace_key="model") - assert isinstance(new_classifier.model, type(testing_classifier.model)) - assert isinstance(new_classifier.tokenizer, type(testing_classifier.tokenizer)) - assert new_classifier.namespace_key == testing_classifier.namespace_key - - -def test_fit(testing_classifier: SklearnClassifier, testing_model: SklearnMatcher, testing_dataset: Dataset): - tc_result = testing_classifier.fit(testing_dataset) - ts_result = testing_model.fit(testing_dataset) - assert tc_result is None and ts_result is None - - -def test_transform(testing_model: SklearnMatcher): - result = testing_model.transform("one two three") - assert isinstance(result, np.ndarray) diff --git a/tests/script/extras/conditions/test_tutorials.py b/tests/script/extras/conditions/test_tutorials.py index 44115ecdf..c4f771649 100644 --- a/tests/script/extras/conditions/test_tutorials.py +++ b/tests/script/extras/conditions/test_tutorials.py @@ -4,11 +4,9 @@ import pytest from tests.test_utils import get_path_from_tests_to_current_dir -from chatsky.script.conditions.llm_conditions.models.remote_api.google_dialogflow_model import dialogflow_available -from chatsky.script.conditions.llm_conditions.models.remote_api.rasa_model import rasa_available -from chatsky.script.conditions.llm_conditions.models.remote_api.hf_api_model import hf_api_available -from chatsky.script.conditions.llm_conditions.models.local.cosine_matchers.gensim import gensim_available -from chatsky.script.conditions.llm_conditions.models.local.classifiers.sklearn import sklearn_available +from chatsky.ml.models.remote_api.google_dialogflow_model import dialogflow_available +from chatsky.ml.models.remote_api.rasa_model import rasa_available +from chatsky.ml.models.remote_api.hf_api_model import hf_api_available from chatsky.utils.testing.common import check_happy_path from tests.context_storages.test_dbs import ping_localhost @@ -21,16 +19,13 @@ @pytest.mark.parametrize( ["example_module_name", "skip_condition"], [ - ("1_base_tutorial", not sklearn_available), ("7_rasa", os.getenv("RASA_API_KEY") is None or not rasa_available or not RASA_ACTIVE), ( "5_dialogflow", not (os.getenv("GDF_ACCOUNT_JSON") and os.path.exists(os.getenv("GDF_ACCOUNT_JSON"))) or not dialogflow_available, ), - ("6_hf_api", os.getenv("HF_API_KEY") is None or not hf_api_available), - ("2_gensim_tutorial", not gensim_available), - ("4_sklearn_tutorial", not sklearn_available), + ("6_hf_api", os.getenv("HF_API_KEY") is None or not hf_api_available) ], ) @pytest.mark.rasa diff --git a/tests/script/extras/conditions/test_utils.py b/tests/script/extras/conditions/test_utils.py index d77794fd9..c1cc561ff 100644 --- a/tests/script/extras/conditions/test_utils.py +++ b/tests/script/extras/conditions/test_utils.py @@ -1,9 +1,9 @@ import pytest -from chatsky.script import Context, Message -from chatsky.script.conditions.llm_conditions.dataset import Dataset, pyyaml_available -from chatsky.script.conditions.llm_conditions.utils import LABEL_KEY -from chatsky.script.conditions.llm_conditions.models.remote_api.async_mixin import AsyncMixin +from chatsky import Context, Message +from chatsky.ml.dataset import Dataset, pyyaml_available +from chatsky.ml.utils import LABEL_KEY +from chatsky.ml.models.remote_api.async_mixin import AsyncMixin from tests.test_utils import get_path_from_tests_to_current_dir path = get_path_from_tests_to_current_dir(__file__) diff --git a/tutorials/script/extras/conditions/1_hf_api.py b/tutorials/script/extras/conditions/1_hf_api.py index e0488863f..38193876a 100644 --- a/tutorials/script/extras/conditions/1_hf_api.py +++ b/tutorials/script/extras/conditions/1_hf_api.py @@ -9,20 +9,20 @@ # %% import os -from chatsky.script import ( +from chatsky import ( Message, RESPONSE, GLOBAL, TRANSITIONS, LOCAL, ) -from chatsky.script import conditions as cnd +from chatsky import conditions as cnd -from chatsky.script.conditions.llm_conditions.models.remote_api.hf_api_model import ( +from chatsky.ml.models.remote_api.hf_api_model import ( HFAPIModel, ) -from chatsky.script.conditions.llm_conditions import conditions as i_cnd -from chatsky.pipeline import Pipeline +from chatsky.ml import conditions as i_cnd +from chatsky import Pipeline from chatsky.messengers.console import CLIMessengerInterface from chatsky.utils.testing.common import ( is_interactive_mode, diff --git a/tutorials/script/extras/conditions/2_dialogflow.py b/tutorials/script/extras/conditions/2_dialogflow.py index 5698ce393..ba35c3e1f 100644 --- a/tutorials/script/extras/conditions/2_dialogflow.py +++ b/tutorials/script/extras/conditions/2_dialogflow.py @@ -12,20 +12,20 @@ # %% import os -from chatsky.script import ( +from chatsky import ( Message, RESPONSE, GLOBAL, TRANSITIONS, LOCAL, ) -from chatsky.script import conditions as cnd +from chatsky import conditions as cnd -from chatsky.script.conditions.llm_conditions.models.remote_api.google_dialogflow_model import ( +from chatsky.ml.models.remote_api.google_dialogflow_model import ( GoogleDialogFlowModel, ) -from chatsky.script.conditions.llm_conditions import conditions as i_cnd -from chatsky.pipeline import Pipeline +from chatsky.ml import conditions as i_cnd +from chatsky import Pipeline from chatsky.messengers.console import CLIMessengerInterface from chatsky.utils.testing.common import ( is_interactive_mode, diff --git a/tutorials/script/extras/conditions/3_rasa.py b/tutorials/script/extras/conditions/3_rasa.py index 139064f65..03361c8ef 100644 --- a/tutorials/script/extras/conditions/3_rasa.py +++ b/tutorials/script/extras/conditions/3_rasa.py @@ -10,18 +10,18 @@ # %% import os -from chatsky.script import ( +from chatsky import ( Message, RESPONSE, GLOBAL, TRANSITIONS, LOCAL, ) -from chatsky.script import conditions as cnd +from chatsky import conditions as cnd -from chatsky.script.conditions.llm_conditions.models.remote_api.rasa_model import RasaModel -from chatsky.script.conditions.llm_conditions import conditions as i_cnd -from chatsky.pipeline import Pipeline +from chatsky.ml.models.remote_api.rasa_model import RasaModel +from chatsky.ml import conditions as i_cnd +from chatsky import Pipeline from chatsky.messengers.console import CLIMessengerInterface from chatsky.utils.testing.common import ( is_interactive_mode,