Skip to content

Commit

Permalink
feat: Enhance MediaMessage with async media handling and format detec…
Browse files Browse the repository at this point in the history
…tion

Add support for asynchronous media resource loading, format detection using python-magic, and improved media handling methods for URL, path, and data sources
  • Loading branch information
lss233 committed Feb 24, 2025
1 parent bb6ce74 commit 1c6b66b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 14 deletions.
109 changes: 95 additions & 14 deletions kirara_ai/im/message.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import base64
import os
import tempfile
from abc import ABC, abstractmethod
from typing import List, Optional

import aiofiles
import aiohttp
import magic

from kirara_ai.im.sender import ChatSender


Expand Down Expand Up @@ -46,41 +50,118 @@ def __init__(
self.path = path
self.data = data
self.format = format
self.resource_type = "media" # 由子类重写为具体类型

# 根据传入的参数计算其他属性
if url:
self._from_url(url)
self._from_url(url, format)
elif path:
self._from_path(path)
self._from_path(path, format)
elif data and format:
self._from_data(data, format)
else:
raise ValueError("Must provide either url, path, or data + format.")

def _from_url(self, url: str):
async def _load_data_from_path(self) -> None:
    """Read the file at ``self.path`` asynchronously into ``self.data``.

    After the bytes are loaded, ``self._detect_format()`` is awaited so the
    format can be filled in when it has not been set yet.
    """
    async with aiofiles.open(self.path, "rb") as source:
        payload = await source.read()
    self.data = payload
    await self._detect_format()

async def _load_data_from_url(self) -> None:
    """Download ``self.url`` asynchronously and store the body in ``self.data``.

    After a successful download, ``self._detect_format()`` is awaited so the
    format can be filled in when it has not been set yet.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an HTTP
            error status (>= 400).
        aiohttp.ClientError: for other request failures reported by aiohttp.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(self.url) as resp:
            # Fail loudly on 4xx/5xx instead of silently storing the
            # server's error page as if it were the media payload.
            resp.raise_for_status()
            self.data = await resp.read()
    await self._detect_format()

async def _detect_format(self) -> None:
    """Fill in ``self.format`` and ``self.resource_type`` from ``self.data``.

    Nothing happens when ``self.format`` is already set. Otherwise the MIME
    type reported by python-magic for the buffer is split on "/": the last
    component becomes ``self.format`` and the first becomes
    ``self.resource_type``.
    """
    if not self.format:
        mime_parts = magic.from_buffer(self.data, mime=True).split('/')
        self.format = mime_parts[-1]
        self.resource_type = mime_parts[0]

async def get_url(self) -> str:
    """Return a URL for the media resource.

    The stored URL is returned when there is one. Otherwise the bytes
    (loaded from ``self.path`` first if they are not in memory yet) are
    encoded as a base64 ``data:`` URL built from ``self.resource_type`` and
    ``self.format``.

    Raises:
        ValueError: if there is no URL, no data and no path to load from.
    """
    if self.url:
        return self.url

    if not self.data and not self.path:
        raise ValueError("No available media source")
    if not self.data:
        await self._load_data_from_path()

    encoded = base64.b64encode(self.data).decode()
    return f"data:{self.resource_type}/{self.format};base64,{encoded}"

async def get_path(self) -> str:
    """Return a filesystem path for the media resource.

    The stored path is returned when there is one. Otherwise the bytes
    (downloaded from ``self.url`` first if they are not in memory yet) are
    written to a named temporary file created with ``delete=False``; that
    file's name is stored in ``self.path`` and returned.

    Raises:
        ValueError: if there is no path, no data and no URL to load from.
    """
    if self.path:
        return self.path

    if not self.data and not self.url:
        raise ValueError("No available media source")
    if not self.data:
        await self._load_data_from_url()

    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        tmp_file.write(self.data)
    self.path = tmp_file.name
    return self.path

async def get_data(self) -> bytes:
    """Return the raw bytes of the media resource.

    Bytes already held in ``self.data`` are returned directly. Otherwise
    they are loaded from ``self.path`` if it is set, else from ``self.url``.

    Raises:
        ValueError: if there is no data, no path and no URL.
    """
    if self.data:
        return self.data

    # The local file takes precedence over the remote URL.
    if self.path:
        loader = self._load_data_from_path
    elif self.url:
        loader = self._load_data_from_url
    else:
        raise ValueError("No available media source")

    await loader()
    return self.data

def _from_url(self, url: str, format: Optional[str] = None):
    """Set the instance up as a URL-backed media resource.

    Args:
        url: Location of the media; stored in ``self.url``.
        format: Optional explicit format. When omitted, ``self.format``
            stays ``None`` so it can be detected from the downloaded bytes.
    """
    self.url = url
    self.path = None
    self.data = None
    # The format is not guessed from the URL text: a leftover
    # extension-based assignment here was immediately overwritten.
    self.format = format

def _from_path(self, path: str, format: Optional[str] = None):
    """Set the instance up as a file-backed media resource.

    Args:
        path: Filesystem path of the media; stored in ``self.path``.
        format: Optional explicit format. When omitted, ``self.format``
            stays ``None`` so it can be detected from the file's bytes.
    """
    self.path = path
    self.url = None
    self.data = None
    # The format is not derived from the file extension: a leftover
    # extension-based assignment here was immediately overwritten.
    self.format = format

def _from_data(self, data: bytes, format: str):
    """Set the instance up as an in-memory media resource.

    Stores ``data`` and ``format`` and clears ``self.url`` and ``self.path``.
    """
    self.url = self.path = None
    self.data, self.format = data, format

def to_dict(self):
    """Return a plain-dict view of the media resource.

    The ``data`` entry is the base64 text of ``self.data``, or ``None``
    when no bytes are held.
    """
    encoded = None
    if self.data:
        encoded = base64.b64encode(self.data).decode()

    return {
        "type": self.resource_type,
        "url": self.url,
        "path": self.path,
        "data": encoded,
        "format": self.format,
    }


# 定义语音消息
class VoiceMessage(MediaMessage):
resource_type = "audio"

def to_dict(self):
return {
"type": "voice",
Expand All @@ -96,6 +177,8 @@ def to_plain(self):

# 定义图片消息
class ImageMessage(MediaMessage):
resource_type = "image"

def to_dict(self):
return {
"type": "image",
Expand All @@ -114,7 +197,6 @@ def __repr__(self):
# 定义@消息元素
# :deprecated
class AtElement(MessageElement):

def __init__(self, user_id: str, nickname: str = ""):
self.user_id = user_id
self.nickname = nickname
Expand Down Expand Up @@ -161,7 +243,8 @@ def __repr__(self):

# 定义文件消息元素
class FileElement(MediaMessage):

resource_type = "file"

def to_dict(self):
return {
"type": "file",
Expand Down Expand Up @@ -212,12 +295,10 @@ def __repr__(self):


# 定义视频消息元素
class VideoElement(MessageElement):
def __init__(self, file: str):
self.file = file

class VideoElement(MediaMessage):
resource_type = "video"

def to_dict(self):

return {"type": "video", "data": {"file": self.file}}

def to_plain(self):
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,11 @@ dependencies = [
"setuptools",
"tomli>=2.0.0",
"pre-commit",
"python-magic",
"python-magic-bin ; platform_system == 'Windows'",
]


[project.scripts]
kirara_ai = "kirara_ai.__main__:main"

Expand Down
1 change: 1 addition & 0 deletions tests/resources/test_image.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is a test file for MediaElement testing.
101 changes: 101 additions & 0 deletions tests/test_media_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os

import aiohttp
import pytest

from kirara_ai.im.message import MediaMessage

# Path to the fixture file used by the path- and data-based tests.
TEST_RESOURCE_PATH = os.path.join(
    os.path.dirname(__file__),
    "resources",
    "test_image.txt",
)
# A reachable test image URL.
TEST_URL = "https://httpbin.org/image/jpeg"

# 创建测试用的具体子类
class TestMediaMessage(MediaMessage):
    """Concrete ``MediaMessage`` subclass used only by the tests in this module."""

    # Despite the "Test" prefix this is a helper, not a test case, so it is
    # flagged to be left out of test collection.
    __test__ = False

    resource_type = "test"

    def to_plain(self):
        """Return a fixed plain-text placeholder for the media."""
        return "[TestMedia]"

@pytest.mark.asyncio
async def test_media_element_from_path():
    """A path-backed media object exposes its bytes, a data URL and its path."""
    media = TestMediaMessage(path=TEST_RESOURCE_PATH)

    # The raw bytes can be fetched.
    payload = await media.get_data()
    assert payload is not None
    assert isinstance(payload, bytes)

    # The URL is expected in base64 data-URL form.
    data_url = await media.get_url()
    assert data_url.startswith("data:")
    assert "base64" in data_url

    # The returned path must point at an existing regular file.
    file_path = await media.get_path()
    assert os.path.exists(file_path)
    assert os.path.isfile(file_path)

@pytest.mark.asyncio
async def test_media_element_from_url():
    """A URL-backed media object exposes its bytes, its URL and a temp-file path."""
    media = TestMediaMessage(url=TEST_URL)

    # The raw bytes can be downloaded.
    data = await media.get_data()
    assert data is not None
    assert isinstance(data, bytes)

    # The original URL is returned unchanged.
    url = await media.get_url()
    assert url == TEST_URL

    # get_path() writes the bytes to a temporary file that is not deleted
    # automatically, so remove it here to avoid leaking files between runs.
    path = await media.get_path()
    try:
        assert os.path.exists(path)
        assert os.path.isfile(path)
    finally:
        if os.path.exists(path):
            os.remove(path)

@pytest.mark.asyncio
async def test_media_element_from_data():
    """A data-backed media object exposes its bytes, a data URL and a temp-file path."""
    # Use the fixture file's content as the raw test bytes.
    with open(TEST_RESOURCE_PATH, "rb") as f:
        test_data = f.read()

    media = TestMediaMessage(data=test_data, format="txt")

    # The bytes come back unchanged.
    data = await media.get_data()
    assert data == test_data

    # With no URL set, a base64 data URL is expected.
    url = await media.get_url()
    assert url.startswith("data:")
    assert "base64" in url

    # get_path() writes the bytes to a temporary file that is not deleted
    # automatically, so remove it here to avoid leaking files between runs.
    path = await media.get_path()
    try:
        assert os.path.exists(path)
        assert os.path.isfile(path)
    finally:
        if os.path.exists(path):
            os.remove(path)

@pytest.mark.asyncio
async def test_media_element_format_detection():
    """Loading the data of a path-backed object fills in format and resource type."""
    subject = TestMediaMessage(path=TEST_RESOURCE_PATH)

    # Fetching the data triggers format detection.
    await subject.get_data()

    assert subject.format is not None
    assert subject.resource_type is not None

@pytest.mark.asyncio
async def test_media_element_errors():
    """Invalid construction and an invalid URL surface as exceptions."""
    # No arguments at all.
    with pytest.raises(ValueError):
        TestMediaMessage()

    # Data supplied without a format.
    with pytest.raises(ValueError):
        TestMediaMessage(data=b"test")

    # Invalid URL: fetching the data is expected to fail.
    with pytest.raises(aiohttp.ClientError):
        broken = TestMediaMessage(url="https://invalid-url-that-does-not-exist.com/image.jpg")
        await broken.get_data()

0 comments on commit 1c6b66b

Please sign in to comment.