4
4
from loguru import logger
5
5
6
6
from infinity_emb import AsyncEngineArray , EngineArgs , AsyncEmbeddingEngine
7
+ from infinity_emb .inference .select_model import get_engine_type_from_config
7
8
from gpt_server .model_worker .base .model_worker_base import ModelWorkerBase
8
9
9
10
label_to_category = {
@@ -49,30 +50,26 @@ def __init__(
49
50
bettertransformer = True
50
51
if model_type is not None and "deberta" in model_type :
51
52
bettertransformer = False
52
- self .engine : AsyncEmbeddingEngine = AsyncEngineArray .from_args (
53
- [
54
- EngineArgs (
55
- model_name_or_path = model_path ,
56
- engine = "torch" ,
57
- embedding_dtype = "float32" ,
58
- dtype = "float32" ,
59
- device = device ,
60
- bettertransformer = bettertransformer ,
61
- )
62
- ]
63
- )[0 ]
53
+ engine_args = EngineArgs (
54
+ model_name_or_path = model_path ,
55
+ engine = "torch" ,
56
+ embedding_dtype = "float32" ,
57
+ dtype = "float32" ,
58
+ device = device ,
59
+ bettertransformer = bettertransformer ,
60
+ )
61
+ engine_type = get_engine_type_from_config (engine_args )
62
+ engine_type_str = str (engine_type )
63
+ if "EmbedderEngine" in engine_type_str :
64
+ self .mode = "embedding"
65
+ elif "RerankEngine" in engine_type_str :
66
+ self .mode = "rerank"
67
+ elif "ImageEmbedEngine" in engine_type_str :
68
+ self .mode = "image"
69
+ self .engine : AsyncEmbeddingEngine = AsyncEngineArray .from_args ([engine_args ])[0 ]
64
70
loop = asyncio .get_running_loop ()
65
71
loop .create_task (self .engine .astart ())
66
- self .mode = "embedding"
67
- # rerank
68
- for model_name in model_names :
69
- if "rerank" in model_name :
70
- self .mode = "rerank"
71
- break
72
- if self .mode == "rerank" :
73
- logger .info ("正在使用 rerank 模型..." )
74
- elif self .mode == "embedding" :
75
- logger .info ("正在使用 embedding 模型..." )
72
+ logger .info (f"正在使用 { self .mode } 模型..." )
76
73
logger .info (f"模型:{ model_names [0 ]} " )
77
74
78
75
async def astart (self ):
@@ -83,7 +80,7 @@ async def get_embeddings(self, params):
83
80
logger .info (f"worker_id: { self .worker_id } " )
84
81
self .call_ct += 1
85
82
ret = {"embedding" : [], "token_num" : 0 }
86
- texts = params ["input" ]
83
+ texts : list = params ["input" ]
87
84
if self .mode == "embedding" :
88
85
texts = list (map (lambda x : x .replace ("\n " , " " ), texts ))
89
86
embeddings , usage = await self .engine .embed (sentences = texts )
@@ -105,6 +102,17 @@ async def get_embeddings(self, params):
105
102
embedding = [
106
103
[round (float (score ["relevance_score" ]), 6 )] for score in ranking
107
104
]
105
+ elif self .mode == "image" :
106
+ if (
107
+ isinstance (texts [0 ], bytes )
108
+ or "http" in texts [0 ]
109
+ or "data:image" in texts [0 ]
110
+ ):
111
+ embeddings , usage = await self .engine .image_embed (images = texts )
112
+ else :
113
+ embeddings , usage = await self .engine .embed (sentences = texts )
114
+
115
+ embedding = [embedding .tolist () for embedding in embeddings ]
108
116
ret ["embedding" ] = embedding
109
117
ret ["token_num" ] = usage
110
118
return ret
0 commit comments