Skip to content

Commit

Permalink
Refactor/improve error handling flow (#36)
Browse files Browse the repository at this point in the history
* refactor: improve error handling

* fix: typo

* chore: add more debug logging
  • Loading branch information
nickcom007 authored Jun 27, 2024
1 parent 7ed7b3d commit 374aa06
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 15 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
src/.DS_Store

# data
lora
43 changes: 43 additions & 0 deletions src/core/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from loguru import logger
from client.fed_ledger import FedLedger
import sys


def handle_os_error(e: OSError):
    """Log an ``OSError`` and terminate the process with a matching exit code.

    A disk-full condition exits with code 101. Every other ``OSError`` exits
    with code 100, which the log message describes as leading to a restart.

    Args:
        e: The OS-level error that was caught by the caller.

    Raises:
        SystemExit: Always; this function never returns normally.
    """
    # Function-scope import so this block does not depend on a file-level change.
    import errno

    # Prefer the numeric errno (independent of locale and message wording) and
    # keep the text match as a fallback for errors re-raised without an errno.
    disk_full = e.errno == errno.ENOSPC or "No space left on device" in str(e)
    if disk_full:
        logger.error("No more disk space, exiting with code 101")
        sys.exit(101)

    # Include the exception text so an "unknown" failure can be diagnosed.
    logger.error(
        f"Unknown OSError detected ({e}), exiting with code 100, will restart..."
    )
    sys.exit(100)


def handle_runtime_error(e: RuntimeError, assignment_id: str, client: FedLedger):
    """Log a ``RuntimeError`` and either fail the assignment or exit the process.

    Behaviour by error message:

    * contains ``"CUDA error: device-side assert triggered"``: exit with code 100.
    * contains ``"out of memory"``: mark the assignment as failed through
      ``client`` and return to the caller.
    * anything else: exit with code 100.

    Args:
        e: The runtime error that was caught by the caller.
        assignment_id: Identifier passed to ``client.mark_assignment_as_failed``.
        client: Ledger client used to report the failed assignment.

    Raises:
        SystemExit: With code 100 for every case except out-of-memory.
    """
    message = str(e)

    if "CUDA error: device-side assert triggered" in message:
        logger.error(
            "CUDA device-side assert triggered error detected, exiting with code 100, will restart..."
        )
        sys.exit(100)
    elif "out of memory" in message:
        logger.error(
            "CUDA out of memory error detected, will mark the assignment as failed"
        )
        # The only branch that returns instead of exiting.
        client.mark_assignment_as_failed(assignment_id)
    else:
        # Include the exception text so an "unknown" failure can be diagnosed.
        logger.error(
            f"Unknown RuntimeError detected ({message}), exiting with code 100, will restart..."
        )
        sys.exit(100)


def handle_value_error(e: ValueError, assignment_id: str, client: FedLedger):
    """Log a ``ValueError`` and terminate the process with a matching exit code.

    An error whose message contains
    ``"FP16 Mixed precision training with AMP or APEX"`` exits with code 101.
    Every other ``ValueError`` exits with code 100.

    Args:
        e: The value error that was caught by the caller.
        assignment_id: Unused here; kept so the signature matches
            ``handle_runtime_error``.
        client: Unused here; kept so the signature matches
            ``handle_runtime_error``.

    Raises:
        SystemExit: Always; this function never returns normally.
    """
    message = str(e)

    if "FP16 Mixed precision training with AMP or APEX" in message:
        logger.error(
            "FP16 Mixed precision training with AMP or APEX error detected, exiting with code 101"
        )
        sys.exit(101)

    # Include the exception text so an "unknown" failure can be diagnosed.
    logger.error(
        f"Unknown ValueError detected ({message}), exiting with code 100, will restart..."
    )
    sys.exit(100)
41 changes: 26 additions & 15 deletions src/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from core.template import template_dict
from core.hf_utils import download_lora_config, download_lora_repo
from core.constant import SUPPORTED_BASE_MODELS
from core.exception import (
handle_os_error,
handle_runtime_error,
handle_value_error,
)
from tenacity import retry, stop_after_attempt, wait_exponential
from client.fed_ledger import FedLedger
from peft import PeftModel
Expand Down Expand Up @@ -267,6 +272,9 @@ def validate(
"assignment_id is required for submitting validation result to the server"
)

model = None
eval_dataset = None

try:
fed_ledger = FedLedger(FLOCK_API_KEY)
parser = HfArgumentParser(TrainingArguments)
Expand Down Expand Up @@ -327,28 +335,24 @@ def validate(
logger.info(
f"Successfully submitted validation result for assignment {assignment_id}"
)
except (OSError, RuntimeError) as e:
# Handle CUDA related error
if "CUDA error: device-side assert triggered" in str(e):
logger.error("CUDA error detected, exiting with code 100")
sys.exit(100)
else:
# log the type of the exception
logger.error(f"An error occurred while validating the model: {e}")
# fail this assignment
fed_ledger.mark_assignment_as_failed(assignment_id)

# raise for other exceptions

# raise for exceptions, will handle at `loop` level
except Exception as e:
raise e
finally:
# offload the model to save memory
gc.collect()
model.cpu()
del model, eval_dataset
if model is not None:
logger.debug("Offloading model to save memory")
model.cpu()
del model
if eval_dataset is not None:
logger.debug("Offloading eval_dataset to save memory")
del eval_dataset
torch.cuda.empty_cache()
# remove lora folder
if os.path.exists("lora"):
logger.debug("Removing lora folder")
os.system("rm -rf lora")


Expand Down Expand Up @@ -442,7 +446,14 @@ def loop(validation_args_file: str, task_id: str = None, auto_clean_cache: bool
)
break # Break the loop if no exception
except KeyboardInterrupt:
break
# directly terminate the process if keyboard interrupt
sys.exit(1)
except OSError as e:
handle_os_error(e)
except RuntimeError as e:
handle_runtime_error(e, assignment_id, fed_ledger)
except ValueError as e:
handle_value_error(e, assignment_id, fed_ledger)
except Exception as e:
logger.error(f"Attempt {attempt + 1} failed: {e}")
if attempt == 2:
Expand Down

0 comments on commit 374aa06

Please sign in to comment.