forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkpoint.py
73 lines (60 loc) · 2.48 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from pathlib import Path
from typing import Any, Dict, Optional
def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (which contain files such as the
    checkpoint and param files), either:
    1. Uses the path from pkg_resources, which only works with buck2.
    2. Falls back to the `params` directory next to `model_file_path`
       (e.g. `examples/models/llama/params` for the llama model).

    Expected to be called from within a `model.py` file located in an
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """
    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and all
        # resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # The model name is the name of the directory containing the model file,
        # assuming this is called with a path such as
        # examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # 2nd way: use the `params` directory next to the model file.
        # Catch Exception rather than using a bare `except:` so the fallback
        # stays best-effort without swallowing KeyboardInterrupt/SystemExit.
        resource_dir = Path(model_file_path).absolute().parent / "params"
    return resource_dir
def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[Any]:
    """
    Get the dtype of the checkpoint, returning None if the checkpoint is empty.

    The dtype of the first entry (in iteration order) is taken as the
    checkpoint's dtype. If any other entry has a different dtype, a message
    listing the mismatched entries is printed, and the first entry's dtype is
    still returned.

    Args:
        checkpoint: Mapping from parameter name to a value exposing a `.dtype`
            attribute (presumably a state dict of tensors — confirm at callers).

    Returns:
        The `.dtype` of the first entry (presumably a `torch.dtype`, not a
        string), or None when `checkpoint` is empty.
    """
    if not checkpoint:
        return None

    first_key = next(iter(checkpoint))
    first = checkpoint[first_key]
    dtype = first.dtype

    # Collect every entry whose dtype differs from the first entry's dtype.
    mismatched_dtypes = [
        (key, value.dtype)
        for key, value in checkpoint.items()
        if value.dtype != dtype
    ]
    if mismatched_dtypes:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
        )
    return dtype