Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
chyroc committed Jan 8, 2025
2 parents 3260e8a + a3042be commit a066fa3
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 25 deletions.
4 changes: 2 additions & 2 deletions cozepy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .audio.rooms import CreateRoomResp
from .audio.speech import AudioFormat
from .audio.transcriptions import CreateTranslationResp
from .audio.transcriptions import CreateTranscriptionsResp
from .audio.voices import Voice
from .auth import (
AsyncDeviceOAuthApp,
Expand Down Expand Up @@ -143,7 +143,7 @@
"Voice",
"AudioFormat",
# audio.transcriptions
"CreateTranslationResp",
"CreateTranscriptionsResp",
# auth
"AsyncDeviceOAuthApp",
"AsyncJWTOAuthApp",
Expand Down
18 changes: 9 additions & 9 deletions cozepy/audio/transcriptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from cozepy.util import remove_url_trailing_slash


class CreateTranslationResp(CozeModel):
class CreateTranscriptionsResp(CozeModel):
# The text of the transcription
text: str

Expand All @@ -23,18 +23,18 @@ def create(
*,
file: FileTypes,
**kwargs,
) -> CreateTranslationResp:
) -> CreateTranscriptionsResp:
"""
create translation
create transcriptions
:param file: The file to be transcribed.
:return: create translation result
:return: create transcriptions result
"""
url = f"{self._base_url}/v1/audio/transcriptions"
headers: Optional[dict] = kwargs.get("headers")
files = {"file": _try_fix_file(file)}
return self._requester.request(
"post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files
"post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files
)


Expand All @@ -53,16 +53,16 @@ async def create(
*,
file: FileTypes,
**kwargs,
) -> CreateTranslationResp:
) -> CreateTranscriptionsResp:
"""
create translation
create transcriptions
:param file: The file to be transcribed.
:return: create translation result
:return: create transcriptions result
"""
url = f"{self._base_url}/v1/audio/transcriptions"
files = {"file": _try_fix_file(file)}
headers: Optional[dict] = kwargs.get("headers")
return await self._requester.arequest(
"post", url, stream=False, cast=CreateTranslationResp, headers=headers, files=files
"post", url, stream=False, cast=CreateTranscriptionsResp, headers=headers, files=files
)
14 changes: 7 additions & 7 deletions cozepy/websockets/audio/speech/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ class InputTextBufferCommitEvent(WebsocketsEvent):

# req
class SpeechUpdateEvent(WebsocketsEvent):
class OpusConfig(object):
class OpusConfig(BaseModel):
bitrate: Optional[int] = None
use_cbr: Optional[bool] = None
frame_size_ms: Optional[float] = None

class PCMConfig(object):
class PCMConfig(BaseModel):
sample_rate: Optional[int] = None

class OutputAudio(object):
class OutputAudio(BaseModel):
codec: Optional[str]
pcm_config: Optional["SpeechUpdateEvent.PCMConfig"]
opus_config: Optional["SpeechUpdateEvent.OpusConfig"]
speech_rate: Optional[int]
voice_id: Optional[str]

class Data:
class Data(BaseModel):
output_audio: "SpeechUpdateEvent.OutputAudio"

type: WebsocketsEventType = WebsocketsEventType.SPEECH_UPDATE
Expand Down Expand Up @@ -78,17 +78,17 @@ class SpeechAudioCompletedEvent(WebsocketsEvent):

class AsyncWebsocketsAudioSpeechEventHandler(AsyncWebsocketsBaseEventHandler):
async def on_input_text_buffer_committed(
self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: InputTextBufferCommittedEvent
self, cli: "AsyncWebsocketsAudioSpeechCreateClient", event: InputTextBufferCommittedEvent
):
pass

async def on_speech_audio_update(
self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioUpdateEvent
self, cli: "AsyncWebsocketsAudioSpeechCreateClient", event: SpeechAudioUpdateEvent
):
pass

async def on_speech_audio_completed(
self, cli: "AsyncWebsocketsAudioSpeechEventHandler", event: SpeechAudioCompletedEvent
self, cli: "AsyncWebsocketsAudioSpeechCreateClient", event: SpeechAudioCompletedEvent
):
pass

Expand Down
20 changes: 20 additions & 0 deletions cozepy/websockets/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ class WebsocketsEvent(CozeModel, ABC):


class AsyncWebsocketsBaseClient(abc.ABC):
class State(str, Enum):
"""
initialized, connecting, connected, closing, closed
"""

INITIALIZED = "initialized"
CONNECTING = "connecting"
CONNECTED = "connected"
CLOSING = "closing"
CLOSED = "closed"

def __init__(
self,
base_url: str,
Expand All @@ -74,6 +85,7 @@ def __init__(
wait_events: Optional[List[WebsocketsEventType]] = None,
**kwargs,
):
self._state = self.State.INITIALIZED
self._base_url = remove_url_trailing_slash(base_url)
self._auth = auth
self._requester = requester
Expand All @@ -99,6 +111,9 @@ async def __call__(self):
await self.close()

async def connect(self):
if self._state != self.State.INITIALIZED:
raise ValueError(f"Cannot connect in {self._state.value} state")
self._state = self.State.CONNECTING
headers = {
"Authorization": f"Bearer {self._auth.token}",
"X-Coze-Client-User-Agent": coze_client_user_agent(),
Expand All @@ -110,6 +125,7 @@ async def connect(self):
user_agent_header=user_agent(),
additional_headers=headers,
)
self._state = self.State.CONNECTED
log_info("[%s] connected to websocket", self._path)

self._receive_task = asyncio.create_task(self._receive_loop())
Expand All @@ -126,7 +142,11 @@ def on(self, event_type: WebsocketsEventType, handler: Callable):
self._on_event[event_type] = handler

async def close(self) -> None:
if self._state not in (self.State.CONNECTED, self.State.CONNECTING):
return
self._state = self.State.CLOSING
await self._close()
self._state = self.State.CLOSED

async def _send_loop(self) -> None:
try:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_audio_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tests.test_util import logid_key


def mock_create_translation(respx_mock):
def mock_create_transcriptions(respx_mock):
logid = random_hex(10)
raw_response = httpx.Response(
200,
Expand All @@ -25,11 +25,11 @@ def mock_create_translation(respx_mock):


@pytest.mark.respx(base_url="https://api.coze.com")
class TestAudioTranslation:
def test_sync_translation_create(self, respx_mock):
class TestSyncAudioTranscriptions:
def test_sync_transcriptions_create(self, respx_mock):
coze = Coze(auth=TokenAuth(token="token"))

mock_logid = mock_create_translation(respx_mock)
mock_logid = mock_create_transcriptions(respx_mock)

res = coze.audio.transcriptions.create(file=("filename", "content"))
assert res
Expand All @@ -39,11 +39,11 @@ def test_sync_translation_create(self, respx_mock):

@pytest.mark.respx(base_url="https://api.coze.com")
@pytest.mark.asyncio
class TestAsyncAudioTranslation:
async def test_async_translation_create(self, respx_mock):
class TestAsyncAudioTranscriptions:
async def test_async_transcriptions_create(self, respx_mock):
coze = AsyncCoze(auth=TokenAuth(token="token"))

mock_logid = mock_create_translation(respx_mock)
mock_logid = mock_create_transcriptions(respx_mock)

res = await coze.audio.transcriptions.create(file=("filename", "content"))
assert res
Expand Down

0 comments on commit a066fa3

Please sign in to comment.