Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: 应用设置中配置语音输入和播放 #1120

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Generated by Django 4.2.15 on 2024-09-05 14:35

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('setting', '0006_alter_model_status'),
('application', '0011_application_model_params_setting'),
]

operations = [
migrations.AddField(
model_name='application',
name='stt_model',
field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='stt_model_id', to='setting.model'),
),
migrations.AddField(
model_name='application',
name='stt_model_enable',
field=models.BooleanField(default=False, verbose_name='语音识别模型是否启用'),
),
migrations.AddField(
model_name='application',
name='tts_model',
field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tts_model_id', to='setting.model'),
),
migrations.AddField(
model_name='application',
name='tts_model_enable',
field=models.BooleanField(default=False, verbose_name='语音合成模型是否启用'),
),
]
4 changes: 4 additions & 0 deletions apps/application/models/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ class Application(AppModelMixin):
work_flow = models.JSONField(verbose_name="工作流数据", default=dict)
type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices,
default=ApplicationTypeChoices.SIMPLE, max_length=256)
tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False)
stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False)

@staticmethod
def get_default_model_prompt():
Expand Down
51 changes: 45 additions & 6 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def list(self, with_valid=True):
@staticmethod
def reset_application(application: Dict):
application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
del application['dialogue_number']

if 'dataset_setting' in application:
application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': {
'status': 'ai_questioning',
Expand Down Expand Up @@ -710,21 +710,39 @@ def edit(self, instance: Dict, with_valid=True):
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if instance.get('stt_model_id') is None or len(instance.get('stt_model_id')) == 0:
application.stt_model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('stt_model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if instance.get('tts_model_id') is None or len(instance.get('tts_model_id')) == 0:
application.tts_model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('tts_model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if 'work_flow' in instance:
# 当前用户可修改关联的知识库列表
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
self.list_dataset(with_valid=False)]
self.update_reverse_search_node(instance.get('work_flow'), application_dataset_id_list)
# 找到语音配置相关
self.get_work_flow_model(instance)

update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization',
'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable',
'api_key_is_active', 'icon', 'work_flow', 'model_params_setting']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3)
else:
application.__setattr__(update_key, instance.get(update_key))
application.__setattr__(update_key, instance.get(update_key))
application.save()

if 'dataset_id_list' in instance:
Expand Down Expand Up @@ -823,6 +841,27 @@ def save_other_config(self, data):

application.save()

@staticmethod
def get_work_flow_model(instance):
nodes = instance.get('work_flow')['nodes']
for node in nodes:
if node['id'] == 'base-node':
instance['stt_model_id'] = node['properties']['node_data']['stt_model_id']
instance['tts_model_id'] = node['properties']['node_data']['tts_model_id']
instance['stt_model_enable'] = node['properties']['node_data']['stt_model_enable']
instance['tts_model_enable'] = node['properties']['node_data']['tts_model_enable']
break

def speech_to_text(self, filelist):
# todo 找到模型 mp3转text
print(self.application_id)
print(filelist)

def text_to_speech(self, text):
# todo 找到模型 text转bytes
print(self.application_id)
print(text)

class ApplicationKeySerializerModel(serializers.ModelSerializer):
class Meta:
model = ApplicationApiKey
Expand Down
5 changes: 4 additions & 1 deletion apps/application/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,8 @@
path(
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
views.ChatView.ChatRecord.Improve.Operate.as_view(),
name='')
name=''),
path('application/<str:application_id>/<str:model_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/<str:model_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),

]
32 changes: 32 additions & 0 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,35 @@ def get(self, request: Request, current_page: int, page_size: int):
ApplicationSerializer.Query(
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page(
current_page, page_size))

class SpeechToText(APIView):
authentication_classes = [TokenAuth]

@action(methods=['POST'], detail=False)
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
.speech_to_text(request.FILES.getlist('file')[0]))

class TextToSpeech(APIView):
authentication_classes = [TokenAuth]

@action(methods=['POST'], detail=False)
@has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=keywords.get(
'application_id'))],
compare=CompareConstants.AND))
def post(self, request: Request, application_id: str, model_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
.text_to_speech(request.data.get('text')))
33 changes: 32 additions & 1 deletion ui/src/api/application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,35 @@ const getApplicationRerankerModel: (
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading)
}

/**
* 获取当前用户可使用的模型列表
* @param application_id
* @param loading
* @query { query_text: string, top_number: number, similarity: number }
* @returns
*/
const getApplicationSTTModel: (
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'STT' }, loading)
}

/**
* 获取当前用户可使用的模型列表
* @param application_id
* @param loading
* @query { query_text: string, top_number: number, similarity: number }
* @returns
*/
const getApplicationTTSModel: (
application_id: string,
loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading)
}

/**
* 发布应用
* @param 参数
Expand Down Expand Up @@ -324,5 +353,7 @@ export default {
listFunctionLib,
getFunctionLib,
getModelParamsForm,
getApplicationRerankerModel
getApplicationRerankerModel,
getApplicationSTTModel,
getApplicationTTSModel,
}
4 changes: 4 additions & 0 deletions ui/src/api/type/application.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ interface ApplicationFormType {
type?: string
work_flow?: any
model_params_setting?: any
stt_model_id?: string
tts_model_id?: string
stt_model_enable?: boolean
tts_model_enable?: boolean
}
interface chatType {
id: string
Expand Down
Loading
Loading