# Generated by Django 4.2.15 on 2024-09-05 14:35

from django.db import migrations, models
import django.db.models.deletion


def _speech_model_reference(related_name):
    """Return the nullable foreign key to ``setting.model`` used by both speech fields.

    The key is created without a database-level constraint and is reset to
    NULL when the referenced model row is deleted.
    """
    return models.ForeignKey(
        to='setting.model',
        related_name=related_name,
        on_delete=django.db.models.deletion.SET_NULL,
        db_constraint=False,
        blank=True,
        null=True,
    )


def _enable_flag(verbose_name):
    """Return a boolean switch that defaults to ``False``."""
    return models.BooleanField(verbose_name=verbose_name, default=False)


def _add_application_field(name, field):
    """Return the ``AddField`` operation that adds ``field`` to the application model."""
    return migrations.AddField(model_name='application', name=name, field=field)


class Migration(migrations.Migration):
    """Add the STT/TTS model references and their enable switches to ``application``."""

    dependencies = [
        ('setting', '0006_alter_model_status'),
        ('application', '0011_application_model_params_setting'),
    ]

    # Order is kept exactly as generated: stt reference, stt switch, tts reference, tts switch.
    operations = [
        _add_application_field('stt_model', _speech_model_reference('stt_model_id')),
        _add_application_field('stt_model_enable', _enable_flag('语音识别模型是否启用')),
        _add_application_field('tts_model', _speech_model_reference('tts_model_id')),
        _add_application_field('tts_model_enable', _enable_flag('语音合成模型是否启用')),
    ]
@staticmethod
def get_work_flow_model(instance):
    """Copy the speech settings stored on the workflow's base node onto ``instance``.

    The first node whose ``id`` is ``'base-node'`` is inspected; whichever of
    ``stt_model_id``, ``tts_model_id``, ``stt_model_enable`` and
    ``tts_model_enable`` exist in its ``properties.node_data`` are written to
    the same keys of ``instance`` (mutated in place). Keys the node does not
    define are left untouched, so a workflow without speech settings, without
    a base node, or without ``nodes`` at all is a no-op instead of raising
    ``KeyError``/``TypeError``.

    :param instance: request payload dict; its ``work_flow`` entry is read.
    :return: None
    """
    # NOTE(review): the values are copied as-is, without any validation here.
    work_flow = instance.get('work_flow') or {}
    for node in work_flow.get('nodes') or []:
        if node.get('id') != 'base-node':
            continue
        node_data = (node.get('properties') or {}).get('node_data') or {}
        for key in ('stt_model_id', 'tts_model_id', 'stt_model_enable', 'tts_model_enable'):
            if key in node_data:
                instance[key] = node_data[key]
        # Only the first base node is considered.
        break


def speech_to_text(self, filelist):
    """Placeholder for speech recognition; currently only prints its inputs.

    :param filelist: the uploaded audio passed in by the caller.
    :return: None
    """
    # TODO: look up the configured model and convert the MP3 audio to text.
    # NOTE(review): reads self.application_id - confirm the serializer exposes that attribute.
    print(self.application_id)
    print(filelist)


def text_to_speech(self, text):
    """Placeholder for speech synthesis; currently only prints its inputs.

    :param text: the text passed in by the caller.
    :return: None
    """
    # TODO: look up the configured model and convert the text to audio bytes.
    # NOTE(review): reads self.application_id - confirm the serializer exposes that attribute.
    print(self.application_id)
    print(text)
class SpeechToText(APIView):
    """POST endpoint that hands the first uploaded ``file`` part to ``speech_to_text``."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @has_permissions(ViewPermission(
        [RoleConstants.ADMIN, RoleConstants.USER],
        [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                        dynamic_tag=keywords.get('application_id'))],
        compare=CompareConstants.AND))
    def post(self, request: Request, application_id: str, model_id: str):
        # Both application_id and model_id are required arguments, so the route must supply them.
        serializer = ApplicationSerializer.Operate(
            data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
        recognize = serializer.speech_to_text
        # NOTE(review): indexing [0] raises IndexError when the request carries no 'file' part.
        audio_file = request.FILES.getlist('file')[0]
        return result.success(recognize(audio_file))


class TextToSpeech(APIView):
    """POST endpoint that hands the request's ``text`` value to ``text_to_speech``."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @has_permissions(ViewPermission(
        [RoleConstants.ADMIN, RoleConstants.USER],
        [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                        dynamic_tag=keywords.get('application_id'))],
        compare=CompareConstants.AND))
    def post(self, request: Request, application_id: str, model_id: str):
        # Both application_id and model_id are required arguments, so the route must supply them.
        serializer = ApplicationSerializer.Operate(
            data={'application_id': application_id, 'user_id': request.user.id, 'model_id': model_id})
        synthesize = serializer.text_to_speech
        # 'text' may be absent, in which case None is forwarded unchanged.
        text = request.data.get('text')
        return result.success(synthesize(text))
string work_flow?: any model_params_setting?: any + stt_model_id?: string + tts_model_id?: string + stt_model_enable?: boolean + tts_model_enable?: boolean } interface chatType { id: string diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index 7ba0a87db3e..e52b7d734f6 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -288,6 +288,147 @@ + + + + + +
+ + {{ item.name }} + 公用 + +
+ + + +
+ + +
+ + {{ item.name }} + {{ + $t('views.application.applicationForm.form.aiModel.unavailable') + }} +
+ + + +
+
+
+
+ + + + + +
+ + {{ item.name }} + 公用 + +
+ + + +
+ + +
+ + {{ item.name }} + {{ + $t('views.application.applicationForm.form.aiModel.unavailable') + }} +
+ + + +
+
+
+
@@ -411,6 +552,10 @@ const applicationForm = ref({ }, model_params_setting: {}, problem_optimization: false, + stt_model_id: '', + tts_model_id: '', + stt_model_enable: false, + tts_model_enable: false, type: 'SIMPLE' }) @@ -440,6 +585,8 @@ const rules = reactive>({ const modelOptions = ref(null) const providerOptions = ref>([]) const datasetList = ref([]) +const sttModelOptions = ref(null) +const ttsModelOptions = ref(null) const submit = async (formEl: FormInstance | undefined) => { if (!formEl) return @@ -508,6 +655,8 @@ function getDetail() { application.asyncGetApplicationDetail(id, loading).then((res: any) => { applicationForm.value = res.data applicationForm.value.model_id = res.data.model + applicationForm.value.stt_model_id = res.data.stt_model + applicationForm.value.tts_model_id = res.data.tts_model }) } @@ -530,6 +679,32 @@ function getModel() { }) } +function getSTTModel() { + loading.value = true + applicationApi + .getApplicationSTTModel(id) + .then((res: any) => { + sttModelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) +} + +function getTTSModel() { + loading.value = true + applicationApi + .getApplicationTTSModel(id) + .then((res: any) => { + ttsModelOptions.value = groupBy(res?.data, 'provider') + loading.value = false + }) + .catch(() => { + loading.value = false + }) +} + function getProvider() { loading.value = true model @@ -552,6 +727,8 @@ onMounted(() => { getModel() getDataset() getDetail() + getSTTModel() + getTTSModel() })