Skip to content

Commit

Permalink
Add better docstrings and work on PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheepsta300 committed Aug 27, 2024
1 parent 112487e commit f47fc9b
Showing 1 changed file with 54 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import io
import time
from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union

Expand All @@ -15,43 +16,58 @@
class AzureOpenAIWhisperParser(BaseBlobParser):
"""Transcribe and parse audio files.
Audio transcription is with Azure OpenAI Whisper model.
Audio transcription is with the Azure OpenAI Whisper model.
This is different from the OpenAI Whisper parser and requires
an Azure OpenAI API key.
Args:
api_key: Azure OpenAI API key
chunk_duration_threshold: minimum duration of a chunk in seconds
NOTE: According to the OpenAI API, the chunk duration should be at least 0.1
seconds. If the chunk duration is less than or equal to the threshold,
it will be skipped.
"""

def __init__(
self,
api_key: Optional[str] = None,
*,
deployment_id: str,
chunk_duration_threshold: float = 0.1,
base_url: Optional[str] = None,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = None,
language: Union[str, None] = None,
prompt: Union[str, None] = None,
language: Optional[str] = None,
prompt: Optional[str] = None,
response_format: Union[
Literal["json", "text", "srt", "verbose_json", "vtt"], None
] = None,
temperature: Union[float, None] = None,
# input_format: Union[
# Literal["flac", "mp3", "mp4", "mpeg", "mpga", "m4a", "ogg", "wav", "webm"]
# ] = "mp3",
temperature: Optional[float] = None,

deployment_id: str,
chunk_duration_threshold: float = 0.1,
):
self.api_key = api_key
self.base_url = base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
"""Initialize the parser.
Args:
api_key (Optional[str]): Azure OpenAI API key.
deployment_id (str): Identifier for the specific model deployment.
chunk_duration_threshold (float): Minimum duration of a chunk in seconds
NOTE: According to the OpenAI API, the chunk duration should be at least 0.1

Check failure on line 47 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:47:89: E501 Line too long (92 > 88)

Check failure on line 47 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:47:89: E501 Line too long (92 > 88)
seconds. If the chunk duration is less than or equal to the threshold,
it will be skipped.
azure_endpoint (Optional[str]): URL endpoint for the Azure OpenAI service.
api_version (Optional[str]): Version of the OpenAI API to use.
language (Optional[str]): Language for processing the request.
prompt (Optional[str]): Query or instructions for the AI model.
response_format
(Union[Literal["json", "text", "srt", "verbose_json", "vtt"], None]):
Format for the response from the service.
temperature (Optional[float]): Controls the randomness of the AI model’s output.

Check failure on line 57 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:57:89: E501 Line too long (92 > 88)

Check failure on line 57 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:57:89: E501 Line too long (92 > 88)
"""
self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY")
self.azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.api_version = api_version or os.environ.get("OPENAI_API_VERSION")

self.deployment_id = deployment_id
self.chunk_duration_threshold = chunk_duration_threshold
self.language = language
self.prompt = prompt
self.response_format = response_format
self.temperature = temperature
# self.input_format = input_format

self.deployment_id = deployment_id
self.chunk_duration_threshold = chunk_duration_threshold

@property
def _create_params(self) -> Dict[str, Any]:
Expand All @@ -64,9 +80,14 @@ def _create_params(self) -> Dict[str, Any]:
return {k: v for k, v in params.items() if v is not None}

def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""
"""Lazily parse the blob.
Args:
blob (Blob): The file to be parsed.
import io
Returns:
Iterator[Document]: The parsed transcript of the file.
"""

try:
import openai
Expand All @@ -75,50 +96,48 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"openai package not found, please install it with "
"`pip install openai`"
)

try:
from pydub import AudioSegment
except ImportError:
raise ImportError(
"pydub package not found, please install it with " "`pip install pydub`"
"pydub package not found, please install it with "
"`pip install pydub`"
)

if is_openai_v1():
# api_key optional, defaults to `os.environ['AZURE_OPENAI_API_KEY']`
# api_version optional, defaults to `os.environ['OPENAI_API_VERSION']`
# azure_endpoint/base_url optional,
# defaults to `os.environ['AZURE_OPENAI_ENDPOINT']`
client = openai.AzureOpenAI(
api_key=self.api_key,
azure_endpoint=self.base_url,
azure_endpoint=self.azure_endpoint,
api_version=self.api_version,
)
else:
# Set the API key if provided
if self.api_key:
openai.api_key = self.api_key
if self.base_url:
openai.base_url = self.base_url
if self.azure_endpoint:
openai.base_url = self.azure_endpoint

# Audio file from disk
audio = AudioSegment.from_file(blob.path)

file_extension = os.path.splitext(blob.path)[1][1:]
# Define the duration of each chunk in minutes
# Need to meet 25MB size limit for Whisper API
chunk_duration = 20
chunk_duration_ms = chunk_duration * 60 * 1000

print(blob.source)

Check failure on line 128 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (T201)

langchain_community/document_loaders/parsers/audio.py:128:9: T201 `print` found

Check failure on line 128 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (T201)

langchain_community/document_loaders/parsers/audio.py:128:9: T201 `print` found
# Split the audio into chunk_duration_ms chunks
for split_number, i in enumerate(range(0, len(audio), chunk_duration_ms)):
# Audio chunk
chunk = audio[i : i + chunk_duration_ms]
# Skip chunks that are too short to transcribe
if chunk.duration_seconds <= self.chunk_duration_threshold:
continue
file_obj = io.BytesIO(chunk.export(format="mp3").read())
file_obj = io.BytesIO(chunk.export(format=file_extension).read())
if blob.source is not None:
file_obj.name = blob.source + f"_part_{split_number}.mp3"
file_obj.name = os.path.splitext(blob.source)[0] + f"_part_{split_number}.{file_extension}"

Check failure on line 138 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.12

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:138:89: E501 Line too long (107 > 88)

Check failure on line 138 in libs/community/langchain_community/document_loaders/parsers/audio.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.8

Ruff (E501)

langchain_community/document_loaders/parsers/audio.py:138:89: E501 Line too long (107 > 88)
else:
file_obj.name = f"part_{split_number}.mp3"
file_obj.name = f"part_{split_number}.{file_extension}"

# Transcribe
print(f"Transcribing part {split_number + 1}!") # noqa: T201
Expand Down Expand Up @@ -157,7 +176,7 @@ class OpenAIWhisperParser(BaseBlobParser):
Args:
api_key: OpenAI API key
chunk_duration_threshold: minimum duration of a chunk in seconds
chunk_duration_threshold: Minimum duration of a chunk in seconds
NOTE: According to the OpenAI API, the chunk duration should be at least 0.1
seconds. If the chunk duration is less than or equal to the threshold,
it will be skipped.
Expand Down Expand Up @@ -199,8 +218,6 @@ def _create_params(self) -> Dict[str, Any]:
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""

import io

try:
import openai
except ImportError:
Expand Down Expand Up @@ -378,8 +395,6 @@ def __init__(
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""

import io

try:
from pydub import AudioSegment
except ImportError:
Expand Down Expand Up @@ -574,8 +589,6 @@ def __init__(
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob."""

import io

try:
from pydub import AudioSegment
except ImportError:
Expand Down

0 comments on commit f47fc9b

Please sign in to comment.