Merge pull request #105 from weaviate/add-ability-to-pass-target-devices
Add ability to pass target devices to sentence transformers pool
antas-marcin authored Jan 21, 2025
2 parents 35788cc + 8f414d1 commit 61c4560
Showing 2 changed files with 23 additions and 15 deletions.
app.py: 24 changes (15 additions, 9 deletions)
@@ -35,10 +35,15 @@ def is_authorized(auth: Optional[HTTPAuthorizationCredentials]) -> bool:
 
 
 def get_worker():
-    global current_worker
-    worker = current_worker % available_workers
-    current_worker += 1
-    return worker
+    if available_workers == 1:
+        return 0
+    else:
+        global current_worker
+        if current_worker >= 1_000_000_000:
+            current_worker = 0
+        worker = current_worker % available_workers
+        current_worker += 1
+        return worker
 
 
 async def lifespan(app: FastAPI):
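The rewritten get_worker() is a round-robin dispatcher over the available model replicas, with a reset at one billion requests so the counter cannot grow without bound. A minimal standalone sketch of the same pattern (the module globals and the worker count of 4 are illustrative):

available_workers = 4  # in app.py this comes from torch.cuda.device_count()
current_worker = 0

def get_worker() -> int:
    if available_workers == 1:
        return 0
    global current_worker
    if current_worker >= 1_000_000_000:  # reset so the counter stays bounded
        current_worker = 0
    worker = current_worker % available_workers
    current_worker += 1
    return worker

# successive calls cycle through the workers: 0, 1, 2, 3, 0, 1
print([get_worker() for _ in range(6)])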
@@ -104,12 +109,16 @@ def log_info_about_onnx(onnx_runtime: bool):
     )
     cuda_support = False
     cuda_core = ""
+    # Use all sentence transformers multi process
+    use_sentence_transformers_multi_process = (
+        get_use_sentence_transformers_multi_process()
+    )
 
     if cuda_env is not None and cuda_env == "true" or cuda_env == "1":
         cuda_support = True
         cuda_core = os.getenv("CUDA_CORE")
         if cuda_core is None or cuda_core == "":
-            if use_sentence_transformers_vectorizer and torch.cuda.is_available():
+            if use_sentence_transformers_vectorizer and use_sentence_transformers_multi_process and torch.cuda.is_available():
                 available_workers = torch.cuda.device_count()
                 cuda_core = ",".join([f"cuda:{i}" for i in range(available_workers)])
             else:
@@ -118,10 +127,7 @@ def log_info_about_onnx(onnx_runtime: bool):
     else:
         logger.info("Running on CPU")
 
-    # Use all available cores
-    use_sentence_transformers_multi_process = (
-        get_use_sentence_transformers_multi_process()
-    )
+
 
     # Batch text tokenization enabled by default
     direct_tokenize = get_t2v_transformers_direct_tokenize()
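With multi-process mode enabled and CUDA_CORE unset, app.py now enumerates every visible GPU into a comma-separated device string. A short sketch of that round trip, assuming two GPUs are visible (the torch calls are real, the GPU count is an assumption):

import torch

available_workers = torch.cuda.device_count()  # e.g. 2
cuda_core = ",".join([f"cuda:{i}" for i in range(available_workers)])
print(cuda_core)             # "cuda:0,cuda:1"
# vectorizer.py's get_cuda_devices() later splits this back into a list
print(cuda_core.split(","))  # ["cuda:0", "cuda:1"]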
vectorizer.py: 14 changes (8 additions, 6 deletions)
@@ -123,7 +123,7 @@ def __init__(
             workers, self.use_sentence_transformers_multi_process
         )
         self.logger.info(
-            f"Sentence transformer vectorizer running with model_name={model_name}, cache_folder={model_path} available_devices:{self.available_devices} trust_remote_code:{trust_remote_code} use_sentence_transformers_multi_process: {self.use_sentence_transformers_multi_process}"
+            f"Sentence transformer vectorizer running with model_name={model_name}, cache_folder={model_path} trust_remote_code:{trust_remote_code}"
         )
         self.workers = []
         for device in self.available_devices:
@@ -136,16 +136,21 @@ def __init__(
            model.eval()  # make sure we're in inference mode, not training
            self.workers.append(model)
 
-        print(f"have a list of {len(self.workers)}")
         if self.use_sentence_transformers_multi_process:
-            self.pool = self.workers[0].start_multi_process_pool()
+            self.pool = self.workers[0].start_multi_process_pool(
+                target_devices=self.get_cuda_devices()
+            )
             self.logger.info(
                 "Sentence transformer vectorizer is set to use all available devices"
             )
             self.logger.info(
                 f"Created pool of {len(self.pool['processes'])} available {'CUDA' if torch.cuda.is_available() else 'CPU'} devices"
             )
 
+    def get_cuda_devices(self) -> List[str] | None:
+        if self.cuda_core is not None and self.cuda_core != "":
+            return self.cuda_core.split(",")
+
     def get_devices(
         self,
         workers: int | None,
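target_devices is a standard parameter of sentence-transformers' start_multi_process_pool(), and the returned pool dict exposes the 'processes' list counted in the log line above. A minimal sketch of the same API outside this codebase (model name and device list are assumptions):

from sentence_transformers import SentenceTransformer

if __name__ == "__main__":  # multi-process pools need the main guard
    model = SentenceTransformer("all-MiniLM-L6-v2")
    # pin the pool to explicit devices instead of letting the library choose
    pool = model.start_multi_process_pool(target_devices=["cuda:0", "cuda:1"])
    embeddings = model.encode_multi_process(["hello", "world"], pool)
    print(embeddings.shape)  # (2, 384) for this model
    SentenceTransformer.stop_multi_process_pool(pool)

Note that get_cuda_devices() returns None when CUDA_CORE is empty, in which case start_multi_process_pool() falls back to the library's own device discovery.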
@@ -169,9 +174,6 @@ def vectorize(self, text: str, config: VectorInputConfig, worker: int = 0):
             )
             return embedding[0]
 
-        print(
-            f"trying to vectorize: worker {worker} using device: {self.available_devices[worker]} available: {len(self.available_devices)}"
-        )
         embedding = self.workers[worker].encode(
             [text],
             device=self.available_devices[worker],
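Without multi-process mode, each request is encoded directly on the device that get_worker() selects. A self-contained sketch of that dispatch path (model name and devices are illustrative):

from itertools import cycle
from sentence_transformers import SentenceTransformer

available_devices = ["cuda:0", "cuda:1"]
workers = [SentenceTransformer("all-MiniLM-L6-v2", device=d) for d in available_devices]
rr = cycle(range(len(workers)))  # stands in for app.py's get_worker()

def vectorize(text: str):
    worker = next(rr)
    # encode on the replica's own device, as in the diff above
    return workers[worker].encode([text], device=available_devices[worker])[0]

print(vectorize("hello").shape)  # (384,) for this model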
