Skip to content

Commit

Permalink
Add GPU type reporting to validation results (#53)
Browse files Browse the repository at this point in the history
* add for validate gpu type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Feat/use revision to download (#55)

* Update exception.py (#54)

* support for qwen2.5,llama 3.1,gemma-2,phi-3  (#52)

* support for qwen2.5, rename template name qwen1.5 to qwen. llama 3.1 for llama3,gemma-2 for gemma,phi-3 for phi3. update transformers version.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add version limit

* add support name limit

* Update template.py

* delete moe

* Update requirements.txt

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nick W <[email protected]>

* use revision

---------

Co-authored-by: feng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* fix data path

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nick W <[email protected]>
Co-authored-by: Nick W <[email protected]>
  • Loading branch information
4 people authored Oct 3, 2024
1 parent 4ec132a commit 8366410
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
3 changes: 2 additions & 1 deletion src/client/fed_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def request_validation_assignment(self, task_id: str):
response = requests.post(url, headers=self.headers)
return response

def submit_validation_result(self, assignment_id: str, loss: float):
def submit_validation_result(self, assignment_id: str, loss: float, gpu_type: str):
url = f"{self.url}/tasks/update-validation-assignment/{assignment_id}"
response = requests.post(
url,
Expand All @@ -26,6 +26,7 @@ def submit_validation_result(self, assignment_id: str, loss: float):
"status": "completed",
"data": {
"loss": loss,
"gpu_type": gpu_type,
},
},
)
Expand Down
9 changes: 9 additions & 0 deletions src/core/gpu_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from torch.cuda import get_device_name


def get_gpu_type():
    """Best-effort lookup of the name of CUDA device 0.

    Returns:
        The device name string reported by torch on success; on any
        failure (e.g. no CUDA device present) an error-description
        string is returned instead of raising, so callers always get
        *some* value to report.
    """
    try:
        return get_device_name(0)
    except Exception as e:  # deliberate best-effort: never raise to the caller
        return f"Error retrieving GPU type: {e}"
11 changes: 7 additions & 4 deletions src/core/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
api = HfApi()


def download_lora_config(repo_id: str, revision: str = "main") -> bool:
    """Probe *repo_id* for a LoRA adapter config.

    Attempts to download ``adapter_config.json`` from the Hugging Face
    repo at the given *revision* into the local ``lora`` directory.

    Args:
        repo_id: Hugging Face repository id to check.
        revision: Git revision (branch, tag, or commit sha) to download
            from. Defaults to "main" so existing callers that do not
            pass a revision keep working.

    Returns:
        True when ``adapter_config.json`` exists (the repo is a LoRA
        adapter), False when it is absent (assumed to be a full model).
    """
    try:
        api.hf_hub_download(
            repo_id=repo_id,
            filename="adapter_config.json",
            local_dir="lora",
            revision=revision,
        )
    except EntryNotFoundError:
        # Missing adapter config is an expected outcome, not an error.
        logger.info("No adapter_config.json found in the repo, assuming full model")
        return False
    return True


def download_lora_repo(repo_id: str, revision: str = "main") -> None:
    """Download the full snapshot of *repo_id* at *revision* into ``lora``.

    Args:
        repo_id: Hugging Face repository id holding the adapter weights.
        revision: Git revision (branch, tag, or commit sha) to download.
            Defaults to "main" for backward compatibility with callers
            that do not pass a revision.
    """
    api.snapshot_download(repo_id=repo_id, local_dir="lora", revision=revision)
17 changes: 11 additions & 6 deletions src/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from core.dataset import UnifiedSFTDataset
from core.template import template_dict
from core.hf_utils import download_lora_config, download_lora_repo
from core.gpu_utils import get_gpu_type
from core.constant import SUPPORTED_BASE_MODELS
from core.exception import (
handle_os_error,
Expand Down Expand Up @@ -112,7 +113,7 @@ def load_tokenizer(model_name_or_path: str) -> AutoTokenizer:


def load_model(
model_name_or_path: str, lora_only: bool, val_args: TrainingArguments
model_name_or_path: str, lora_only: bool, revision: str, val_args: TrainingArguments
) -> Trainer:
logger.info(f"Loading model from base model: {model_name_or_path}")

Expand All @@ -127,7 +128,7 @@ def load_model(
device_map=None,
)
# check whether it is a lora weight
if download_lora_config(model_name_or_path):
if download_lora_config(model_name_or_path, revision):
logger.info("Repo is a lora weight, loading model with adapter weights")
with open("lora/adapter_config.json", "r") as f:
adapter_config = json.load(f)
Expand All @@ -136,7 +137,7 @@ def load_model(
base_model, token=HF_TOKEN, **model_kwargs
)
# download the adapter weights
download_lora_repo(model_name_or_path)
download_lora_repo(model_name_or_path, revision)
model = PeftModel.from_pretrained(
model,
"lora",
Expand Down Expand Up @@ -274,6 +275,7 @@ def validate(
assignment_id: str = None,
local_test: bool = False,
lora_only: bool = True,
revision: str = "main",
):
if not local_test and assignment_id is None:
raise ValueError(
Expand All @@ -287,12 +289,13 @@ def validate(
fed_ledger = FedLedger(FLOCK_API_KEY)
parser = HfArgumentParser(TrainingArguments)
val_args = parser.parse_json_file(json_file=validation_args_file)[0]
gpu_type = get_gpu_type()

tokenizer = load_tokenizer(model_name_or_path)
eval_dataset = load_sft_dataset(
eval_file, context_length, template_name=base_model, tokenizer=tokenizer
)
model = load_model(model_name_or_path, lora_only, val_args)
model = load_model(model_name_or_path, lora_only, revision, val_args)
# if model is not loaded, mark the assignment as failed and return
if model is None:
fed_ledger.mark_assignment_as_failed(assignment_id)
Expand All @@ -308,6 +311,7 @@ def validate(
resp = fed_ledger.submit_validation_result(
assignment_id=assignment_id,
loss=LOSS_FOR_MODEL_PARAMS_EXCEED,
gpu_type=gpu_type,
)
# check response is 200
if resp.status_code != 200:
Expand All @@ -330,8 +334,7 @@ def validate(
logger.info("The model can be correctly validated by validators.")
return
resp = fed_ledger.submit_validation_result(
assignment_id=assignment_id,
loss=eval_loss,
assignment_id=assignment_id, loss=eval_loss, gpu_type=gpu_type
)
# check response is 200
if resp.status_code != 200:
Expand Down Expand Up @@ -455,6 +458,7 @@ def loop(
continue
resp = resp.json()
eval_file = download_file(resp["data"]["validation_set_url"])
revision = resp["task_submission"]["data"].get("revision", "main")
assignment_id = resp["id"]

for attempt in range(3):
Expand All @@ -471,6 +475,7 @@ def loop(
assignment_id=resp["id"],
local_test=False,
lora_only=lora_only,
revision=revision,
)
break # Break the loop if no exception
except KeyboardInterrupt:
Expand Down

0 comments on commit 8366410

Please sign in to comment.