feat: download only necessary model files #58
base: main
Conversation
Walkthrough

The changes in this pull request primarily focus on enhancements to the …

Changes
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (3)
src/validate.py (3)
Line range hint `324-326`: **Initialize `resp` before using it to avoid a `NameError`**

In the `loop` function, if the `for` loop completes without a successful assignment (no `break` is executed), the variable `resp` might remain undefined. This can lead to a `NameError` when checking `if resp is None or resp.status_code != 200:`. Initialize `resp` before the `for` loop:

```diff
 def loop(
     validation_args_file: str,
     task_id: str = None,
     auto_clean_cache: bool = True,
     lora_only: bool = True,
 ):
     if task_id is None:
         raise ValueError("task_id is required for asking assignment_id")
     if auto_clean_cache:
         logger.info("Auto clean the model cache except for the base model")
     else:
         logger.info("Skip auto clean the model cache")
     repo_path = Path(__file__).resolve().parent.parent
     if not IS_DOCKER_CONTAINER:
         is_latest_version(repo_path)
     else:
         logger.info("Skip checking the latest version in docker container")
         logger.info(
             "Please make sure you are using the latest version of the docker image."
         )
     fed_ledger = FedLedger(FLOCK_API_KEY)
     task_id_list = task_id.split(",")
     logger.info(f"Validating task_id: {task_id_list}")
+    resp = None  # Initialize resp variable
     last_successful_request_time = [time.time()] * len(task_id_list)
     while True:
         clean_model_cache(auto_clean_cache)
```
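The failure mode this comment describes can be reproduced in isolation: a name assigned only inside a loop body is never bound when the loop makes zero successful iterations. A minimal sketch of the pattern and the fix (function and variable names here are illustrative, not from the PR):

```python
def fetch_with_retries(urls, fetch):
    """Try each URL until one yields a response; fall back gracefully.

    Binding `resp` before the loop guarantees the name exists even when
    `urls` is empty or every attempt fails, so the `if resp is None`
    check below can never raise a NameError.
    """
    resp = None  # initialize before the loop, as the review suggests
    for url in urls:
        resp = fetch(url)
        if resp is not None:
            break  # success: stop retrying
    if resp is None:  # safe: `resp` is always bound at this point
        return "no successful response"
    return resp
```

Without the `resp = None` line, calling this with an empty `urls` list would crash at the final check instead of returning the fallback value.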
Line range hint `275-278`: **Add a delay between retries in the exception-handling loop**

In the `loop` function, consider adding a brief delay between retries when handling exceptions during validation attempts. This prevents rapid successive attempts and allows for graceful recovery. Apply this diff to add a delay:

```diff
 for attempt in range(3):
     try:
         ctx = click.Context(validate)
         ctx.invoke(
             validate,
             model_name_or_path=resp["task_submission"]["data"]["hg_repo_id"],
             base_model=resp["data"]["base_model"],
             eval_file=eval_file,
             context_length=resp["data"]["context_length"],
             max_params=resp["data"]["max_params"],
             validation_args_file=validation_args_file,
             assignment_id=resp["id"],
             local_test=False,
             lora_only=lora_only,
             revision=revision,
         )
         break  # Break the loop if no exception
     except KeyboardInterrupt:
         sys.exit(1)
     except OSError as e:
         handle_os_error(e, assignment_id, fed_ledger)
     except RuntimeError as e:
         handle_runtime_error(e, assignment_id, fed_ledger)
     except ValueError as e:
         handle_value_error(e, assignment_id, fed_ledger)
     except Exception as e:
         logger.error(f"Attempt {attempt + 1} failed: {e}")
+        time.sleep(5)  # Add a delay before next attempt
         if attempt == 2:
             logger.error(
                 f"Marking assignment {assignment_id} as failed after 3 attempts"
             )
             fed_ledger.mark_assignment_as_failed(assignment_id)
```
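The suggested shape (a fixed number of attempts, a pause between failures, and a give-up action after the last attempt) can be sketched generically. The helper below is illustrative, not the project's code; the attempt count and delay are parameters, with a short delay used here so the demo runs quickly:

```python
import time


def run_with_retries(task, attempts=3, delay=0.01, on_failure=None):
    """Run `task` up to `attempts` times, sleeping `delay` seconds
    between failed attempts so retries are not fired back-to-back."""
    for attempt in range(attempts):
        try:
            return task()
        except Exception as exc:
            print(f"Attempt {attempt + 1} failed: {exc}")
            if attempt == attempts - 1:  # last attempt: give up
                if on_failure is not None:
                    on_failure(exc)
                return None
            time.sleep(delay)  # brief pause before the next attempt


# Usage: a task that fails twice, then succeeds on the third try.
calls = {"n": 0}


def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient error")
    return "ok"


result = run_with_retries(flaky)
```

A fixed delay is the simplest option; exponential backoff is a common refinement when the failures are caused by a rate-limited remote service.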
Line range hint `235-239`: **Specify the exception to be raised for outdated code**

In the `is_latest_version` function, after logging the error about the local code not being up to date, the bare `raise` statement should specify the exception to be raised for clarity and proper exception handling. Apply this diff to specify the exception:

```diff
 if local_commit.hexsha != remote_commit.hexsha:
     logger.error(
         "The local code is not up to date with the main branch. Please update your version."
     )
-    raise
+    raise Exception("Local code is outdated. Please pull the latest changes.")
```
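Worth noting: a bare `raise` outside an `except` block does not merely read poorly, it fails at runtime with a `RuntimeError` ("No active exception to re-raise"), because there is no active exception to propagate. The standalone sketch below (function names are illustrative) contrasts the two behaviors:

```python
def check_version_bare(up_to_date: bool):
    if not up_to_date:
        raise  # bare raise with no active exception: RuntimeError


def check_version_explicit(up_to_date: bool):
    if not up_to_date:
        raise Exception("Local code is outdated. Please pull the latest changes.")


def error_type(fn):
    """Return the class name of the exception `fn(False)` raises."""
    try:
        fn(False)
    except Exception as exc:
        return type(exc).__name__
    return None
```

So the suggested change fixes a latent bug as well as a readability issue: callers that catch `Exception` for the outdated-code case now see the intended exception rather than a confusing `RuntimeError`.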
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- src/validate.py (1 hunks)
🔇 Additional comments (1)
src/validate.py (1)
Line range hint `190-193`: **Verify the model after merging adapter weights**

After merging and unloading the adapter weights with `model = model.merge_and_unload()`, ensure that the model remains fully functional and all necessary attributes are intact. This helps prevent unexpected behavior during model evaluation.
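One lightweight way to act on this advice is a post-merge sanity check that asserts the merged object still exposes the attributes evaluation relies on. The helper and the stub below are illustrative placeholders (the real merged model is a `transformers` model, and the required attributes would depend on the evaluation code), not part of the PR:

```python
def sanity_check(model, required_attrs=("config", "forward")):
    """Return a list of problems found on a merged model (empty = OK)."""
    problems = []
    for attr in required_attrs:
        if not hasattr(model, attr):
            problems.append(f"missing attribute: {attr}")
    if hasattr(model, "forward") and not callable(model.forward):
        problems.append("forward is not callable")
    return problems


# Stub standing in for the result of model.merge_and_unload():
class MergedStub:
    config = {"model_type": "demo"}

    def forward(self, x):
        return x


issues = sanity_check(MergedStub())
```

Running such a check right after `merge_and_unload()` turns a subtle downstream evaluation failure into an immediate, clearly attributed error.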
```python
    model_name_or_path,
    token=HF_TOKEN,
    **model_kwargs,
    subfolder="",
    allow_patterns=[
        "adapter_config.json",
        "adapter_model.safetensors",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "training_args.bin",
        "tokenizer.model",
    ],
```
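This allow-list is the core of "download only necessary model files": only files matching one of the patterns are fetched, so the multi-gigabyte base-model weight shards are skipped when validating a LoRA adapter. The filtering semantics can be checked locally with the standard library's `fnmatch`, which uses the same glob-style matching; the repo listing below is hypothetical:

```python
from fnmatch import fnmatch

ALLOW_PATTERNS = [
    "adapter_config.json",
    "adapter_model.safetensors",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "training_args.bin",
    "tokenizer.model",
]


def filter_repo_files(files, patterns=ALLOW_PATTERNS):
    """Keep only files matching at least one allow pattern."""
    return [f for f in files if any(fnmatch(f, p) for p in patterns)]


# Hypothetical repo listing: the full base weights are excluded.
repo_files = [
    "adapter_config.json",
    "adapter_model.safetensors",
    "model-00001-of-00002.safetensors",  # full base weights: skipped
    "tokenizer.json",
]
kept = filter_repo_files(repo_files)
```

Because these patterns contain no wildcards, they act as an exact-name whitelist; a pattern like `"*.safetensors"` would instead pull in the base weights as well.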
**Replace `token` with `use_auth_token` in `from_pretrained`**

The `AutoModelForCausalLM.from_pretrained` method does not accept a `token` argument. Instead, use `use_auth_token=HF_TOKEN` to properly authenticate with Hugging Face for private or gated models.
Apply this diff to correct the argument:

```diff
 model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
-    token=HF_TOKEN,
+    use_auth_token=HF_TOKEN,
     **model_kwargs,
     subfolder="",
     allow_patterns=[
         "adapter_config.json",
         "adapter_model.safetensors",
         "special_tokens_map.json",
         "tokenizer.json",
         "tokenizer_config.json",
         "training_args.bin",
         "tokenizer.model",
     ],
 )
```
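A caveat on this suggestion: which keyword is accepted depends on the installed `transformers` version (older releases used `use_auth_token`; newer ones accept `token` and deprecate the former), so neither spelling is safe across all versions. A defensive caller can pick the supported name at runtime by inspecting the callable's declared parameters. This is a standard-library sketch, not project code, and the two demo functions merely stand in for the real API (whose `**kwargs` may hide the parameter from signature inspection):

```python
import inspect


def auth_kwargs(fn, secret):
    """Return {'token': ...} or {'use_auth_token': ...} depending on
    which parameter name `fn` actually declares, else {}."""
    params = inspect.signature(fn).parameters
    if "token" in params:
        return {"token": secret}
    if "use_auth_token" in params:
        return {"use_auth_token": secret}
    return {}


# Demo against plain functions standing in for from_pretrained:
def new_api(model_id, token=None):
    return ("new", token)


def old_api(model_id, use_auth_token=None):
    return ("old", use_auth_token)


result = new_api("some/repo", **auth_kwargs(new_api, "hf_xxx"))
```

Pinning the `transformers` version in the project's requirements is the simpler alternative to this kind of runtime dispatch.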
```python
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "training_args.bin",
```
Do we need `training_args` to run successfully?
No, it is not always needed. It can help those who want to run custom validation scripts, and it is usually just a few kBs. I'll remove it for now, though.
download only necessary model files
Summary by CodeRabbit

New Features
- `lora_only` parameter in the validation process to filter models.

Bug Fixes

Documentation