From a4c98ab83b37b0ff9ee2e52dfbf12760a060973e Mon Sep 17 00:00:00 2001
From: Kristoffer Andersson
Date: Fri, 22 Nov 2024 13:42:55 +0100
Subject: [PATCH] feat: use gpu if available

---
 .../sentiment_analyzer.py                    | 33 +++++++++++++++++--
 1 file changed, 31 insertions(+), 2 deletions(-)

diff --git a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py
index 2171719..bec8d9b 100644
--- a/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py
+++ b/sparv-sbx-sentence-sentiment-kb-sent/src/sbx_sentence_sentiment_kb_sent/sentiment_analyzer.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 from typing import Dict, List, Optional, Union
 
+import torch
 from sparv import api as sparv_api  # type: ignore [import-untyped]
 from transformers import (  # type: ignore [import-untyped]
     AutoModelForSequenceClassification,
@@ -23,6 +24,20 @@
 MAX_LENGTH: int = 700
 
 
+def _get_dtype() -> torch.dtype:
+    if torch.cuda.is_available():
+        logger.info("Using GPU (cuda)")
+        dtype = torch.float16
+    else:
+        logger.warning("Using CPU, is cuda available?")
+        dtype = torch.float32
+    return dtype
+
+
+def _get_device_map() -> Optional[str]:
+    return "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None
+
+
 class SentimentAnalyzer:
     """Sentiment analyzer."""
 
@@ -45,8 +60,19 @@ def __init__(
         self.tokenizer = self._default_tokenizer() if tokenizer is None else tokenizer
         self.model = self._default_model() if model is None else model
         self.num_decimals = num_decimals
+
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            logger.info("Using GPU (cuda)")
+            self.model = self.model.cuda()  # type: ignore
+        else:
+            logger.warning("Using CPU, is cuda available?")
+
         self.classifier = pipeline(
-            "sentiment-analysis", model=self.model, tokenizer=self.tokenizer
+            "sentiment-analysis",
+            model=self.model,
+            tokenizer=self.tokenizer,
+            torch_dtype=_get_dtype(),
+            device_map=_get_device_map(),
         )
 
     @classmethod
@@ -56,7 +82,10 @@ def _default_tokenizer(cls) -> PreTrainedTokenizerFast:
     @classmethod
     def _default_model(cls) -> MegatronBertForSequenceClassification:
         return AutoModelForSequenceClassification.from_pretrained(
-            MODEL_NAME, revision=MODEL_REVISION
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            torch_dtype=_get_dtype(),
+            device_map=_get_device_map(),
         )
 
     @classmethod