Skip to content

Commit

Permalink
community: llms imports fixes (#18943)
Browse files Browse the repository at this point in the history
Classes are missed in  __all__  and in different places of __init__.py
- BaichuanLLM 
- ChatDatabricks
- ChatMlflow
- Llamafile
- Mlflow
- Together
Added classes to __all__. I also sorted __all__ list.

---------

Co-authored-by: Erick Friis <[email protected]>
  • Loading branch information
leo-gan and efriis authored Mar 18, 2024
1 parent aee5138 commit 7de1d9a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
40 changes: 25 additions & 15 deletions libs/community/langchain_community/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def _import_databricks() -> Type[BaseLLM]:
return Databricks


# deprecated / only for back compat - do not add to __all__
def _import_databricks_chat() -> Any:
warn_deprecated(
since="0.0.22",
Expand Down Expand Up @@ -325,6 +326,7 @@ def _import_mlflow() -> Type[BaseLLM]:
return Mlflow


# deprecated / only for back compat - do not add to __all__
def _import_mlflow_chat() -> Any:
warn_deprecated(
since="0.0.22",
Expand Down Expand Up @@ -631,7 +633,7 @@ def __getattr__(name: str) -> Any:
return _import_aviary()
elif name == "AzureMLOnlineEndpoint":
return _import_azureml_endpoint()
elif name == "Baichuan":
elif name == "BaichuanLLM" or name == "Baichuan":
return _import_baichuan()
elif name == "QianfanLLMEndpoint":
return _import_baidu_qianfan_endpoint()
Expand Down Expand Up @@ -701,6 +703,8 @@ def __getattr__(name: str) -> Any:
return _import_konko()
elif name == "LlamaCpp":
return _import_llamacpp()
elif name == "Llamafile":
return _import_llamafile()
elif name == "ManifestWrapper":
return _import_manifest()
elif name == "Minimax":
Expand Down Expand Up @@ -818,6 +822,7 @@ def __getattr__(name: str) -> Any:
"Aviary",
"AzureMLOnlineEndpoint",
"AzureOpenAI",
"BaichuanLLM",
"Banana",
"Baseten",
"Beam",
Expand All @@ -836,8 +841,8 @@ def __getattr__(name: str) -> Any:
"Fireworks",
"ForefrontAI",
"Friendli",
"GigaChat",
"GPT4All",
"GigaChat",
"GooglePalm",
"GooseAI",
"GradientLLM",
Expand All @@ -846,22 +851,26 @@ def __getattr__(name: str) -> Any:
"HuggingFacePipeline",
"HuggingFaceTextGenInference",
"HumanInputLLM",
"JavelinAIGateway",
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"TextGen",
"Llamafile",
"ManifestWrapper",
"Minimax",
"Mlflow",
"MlflowAIGateway",
"Modal",
"MosaicML",
"Nebula",
"NIBittensorLLM",
"NLPCloud",
"Nebula",
"OCIGenAI",
"OCIModelDeploymentTGI",
"OCIModelDeploymentVLLM",
"OCIGenAI",
"OctoAIEndpoint",
"Ollama",
"OpaquePrompts",
"OpenAI",
"OpenAIChat",
"OpenLLM",
Expand All @@ -873,30 +882,29 @@ def __getattr__(name: str) -> Any:
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
"OpaquePrompts",
"QianfanLLMEndpoint",
"RWKV",
"Replicate",
"SagemakerEndpoint",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"SparkLLM",
"StochasticAI",
"TextGen",
"TitanTakeoff",
"TitanTakeoffPro",
"Together",
"Tongyi",
"VertexAI",
"VertexAIModelGarden",
"VLLM",
"VLLMOpenAI",
"VertexAI",
"VertexAIModelGarden",
"VolcEngineMaasLLM",
"WatsonxLLM",
"Writer",
"OctoAIEndpoint",
"Xinference",
"JavelinAIGateway",
"QianfanLLMEndpoint",
"YandexGPT",
"Yuan2",
"VolcEngineMaasLLM",
"SparkLLM",
]


Expand All @@ -912,6 +920,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"aviary": _import_aviary,
"azure": _import_azure_openai,
"azureml_endpoint": _import_azureml_endpoint,
"baichuan": _import_baichuan,
"bananadev": _import_bananadev,
"baseten": _import_baseten,
"beam": _import_beam,
Expand All @@ -922,7 +931,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"ctransformers": _import_ctransformers,
"ctranslate2": _import_ctranslate2,
"databricks": _import_databricks,
"databricks-chat": _import_databricks_chat,
"databricks-chat": _import_databricks_chat, # deprecated / only for back compat
"deepinfra": _import_deepinfra,
"deepsparse": _import_deepsparse,
"edenai": _import_edenai,
Expand All @@ -942,10 +951,11 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"koboldai": _import_koboldai,
"konko": _import_konko,
"llamacpp": _import_llamacpp,
"llamafile": _import_llamafile,
"textgen": _import_textgen,
"minimax": _import_minimax,
"mlflow": _import_mlflow,
"mlflow-chat": _import_mlflow_chat,
"mlflow-chat": _import_mlflow_chat, # deprecated / only for back compat
"mlflow-ai-gateway": _import_mlflow_ai_gateway,
"modal": _import_modal,
"mosaic": _import_mosaicml,
Expand Down
4 changes: 4 additions & 0 deletions libs/community/tests/unit_tests/llms/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Aviary",
"AzureMLOnlineEndpoint",
"AzureOpenAI",
"BaichuanLLM",
"Banana",
"Baseten",
"Beam",
Expand Down Expand Up @@ -44,9 +45,11 @@
"KoboldApiLLM",
"Konko",
"LlamaCpp",
"Llamafile",
"TextGen",
"ManifestWrapper",
"Minimax",
"Mlflow",
"MlflowAIGateway",
"Modal",
"MosaicML",
Expand Down Expand Up @@ -77,6 +80,7 @@
"StochasticAI",
"TitanTakeoff",
"TitanTakeoffPro",
"Together",
"Tongyi",
"VertexAI",
"VertexAIModelGarden",
Expand Down

0 comments on commit 7de1d9a

Please sign in to comment.