From f1ee5d41fd3e7e0a2bc3d3cff512a3233dd6f627 Mon Sep 17 00:00:00 2001 From: dmitroprobachay Date: Mon, 6 Jan 2025 12:13:04 +0200 Subject: [PATCH] Repair OCR RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! --- nomeroff_net/nnmodels/ocr_model.py | 6 +++--- nomeroff_net/pipes/number_plate_text_readers/base/ocr.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/nomeroff_net/nnmodels/ocr_model.py b/nomeroff_net/nnmodels/ocr_model.py index 247f3342..5638ef06 100644 --- a/nomeroff_net/nnmodels/ocr_model.py +++ b/nomeroff_net/nnmodels/ocr_model.py @@ -156,10 +156,10 @@ def calculate_loss(self, logits, texts): logits_lens = torch.full(size=(batch_size,), fill_value=input_len, dtype=torch.int32) # calculate ctc loss = self.criterion( - logits, - encoded_texts, + logits.to(device), + encoded_texts.to(device), logits_lens.to(device), - text_lens) + text_lens.to(device)) return loss def step(self, batch): diff --git a/nomeroff_net/pipes/number_plate_text_readers/base/ocr.py b/nomeroff_net/pipes/number_plate_text_readers/base/ocr.py index 7f205218..db25efcb 100644 --- a/nomeroff_net/pipes/number_plate_text_readers/base/ocr.py +++ b/nomeroff_net/pipes/number_plate_text_readers/base/ocr.py @@ -382,10 +382,10 @@ def get_acc(self, predicted: List, decode: List) -> torch.Tensor: logits_lens = torch.full(size=(batch_size,), fill_value=input_len, dtype=torch.int32) acc = functional.ctc_loss( - logits, - encoded_texts, + logits.to(device), + encoded_texts.to(device), logits_lens.to(device), - text_lens + text_lens.to(device) ) return 1 - acc / len(self.letters)