Add warning if no available pretrained weights

james77777778 committed May 15, 2024
1 parent 436e7fb · commit 814aa05
Showing 3 changed files with 16 additions and 1 deletion.
kimm/models/ghostnet.py — 7 changes: 6 additions & 1 deletion
@@ -1,4 +1,5 @@
 import typing
+import warnings
 
 import keras
 from keras import backend
@@ -398,7 +399,7 @@ def __init__(
         dropout_rate: float = 0.2,
         classes: int = 1000,
         classifier_activation: str = "softmax",
-        weights: typing.Optional[str] = None,
+        weights: typing.Optional[str] = "imagenet",
         name: typing.Optional[str] = None,
         **kwargs,
     ):
@@ -409,6 +410,10 @@
         )
         kwargs = self.fix_config(kwargs)
+        if len(getattr(self, "available_weights", [])) == 0:
+            warnings.warn(
f"{self.__class__.__name__} doesn't have pretrained weights "
f"for '{weights}'."
)
weights = None

         super().__init__(
             width=self.width,
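
The same guard is repeated in all three files: when a model subclass exposes no entries in available_weights, the requested weights name triggers a warning and is dropped, so construction falls back to random initialization instead of failing at download time. A minimal standalone sketch of the pattern, assuming a simplified base class (BaseModel, TinyVariant, and BigVariant are illustrative names, not kimm's real classes):

    import warnings


    class BaseModel:
        # Subclasses list the pretrained weight files they can download.
        available_weights = []

        def __init__(self, weights="imagenet"):
            if len(getattr(self, "available_weights", [])) == 0:
                # Nothing published for this variant: warn and fall back
                # to random initialization instead of failing later.
                warnings.warn(
                    f"{self.__class__.__name__} doesn't have pretrained weights "
                    f"for '{weights}'."
                )
                weights = None
            self.weights = weights


    class TinyVariant(BaseModel):
        available_weights = []  # no published weights for this variant


    class BigVariant(BaseModel):
        # (name, origin) entries; the URL here is a placeholder.
        available_weights = [("imagenet", "https://example.com/big.h5")]


    TinyVariant()  # warns: TinyVariant doesn't have pretrained weights ...
    BigVariant()   # silent; self.weights stays "imagenet"
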
kimm/models/mobilenet_v3.py — 5 changes: 5 additions & 0 deletions
@@ -1,5 +1,6 @@
 import math
 import typing
+import warnings
 
 import keras
 from keras import layers
@@ -335,6 +336,10 @@ def __init__(
         if hasattr(self, "padding"):
             kwargs["padding"] = self.padding
+        if len(getattr(self, "available_weights", [])) == 0:
+            warnings.warn(
f"{self.__class__.__name__} doesn't have pretrained weights "
f"for '{weights}'."
)
weights = None

         super().__init__(
             width=self.width,
kimm/models/vision_transformer.py — 5 changes: 5 additions & 0 deletions
@@ -1,4 +1,5 @@
 import typing
+import warnings
 
 import keras
 from keras import backend
@@ -176,6 +177,10 @@ def __init__(
         )
         kwargs = self.fix_config(kwargs)
+        if len(getattr(self, "available_weights", [])) == 0:
+            warnings.warn(
f"{self.__class__.__name__} doesn't have pretrained weights "
f"for '{weights}'."
)
weights = None

         super().__init__(
             patch_size=self.patch_size,
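
Because the new branches only fire for variants without published weights, a regression test has to pick such a variant deliberately. A sketch of how the branch could be exercised with pytest, reusing the simplified TinyVariant and BigVariant classes from the sketch above (not kimm's actual test suite):

    import pytest


    def test_warns_and_drops_weights_when_none_available():
        with pytest.warns(UserWarning, match="doesn't have pretrained weights"):
            model = TinyVariant()
        # The requested name is discarded so no download is attempted.
        assert model.weights is None


    def test_keeps_weights_when_available():
        model = BigVariant()
        assert model.weights == "imagenet"
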
