Skip to content

Commit

Permalink
Fixup errors due to updated dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
cgerum committed Aug 23, 2024
1 parent 5ebf2e6 commit b348430
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 202 deletions.
2 changes: 1 addition & 1 deletion hannah/nas/performance_prediction/features/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_features(nx_graph):
if col not in df.columns:
df[col] = 0
df = df.reindex(sorted(df.columns), axis=1) # Sort to have consistency
return df
return df.astype(np.float32)

def get_list_columns(df):
list_cols = []
Expand Down
4 changes: 3 additions & 1 deletion hannah/nas/test/test_nas_graph_dataset_for_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from hannah.nas.graph_conversion import model_to_graph
from hannah.nas.performance_prediction.features.dataset import OnlineNASGraphDataset, get_features, to_dgl_graph
from hannah.models.embedded_vision_net.models import embedded_vision_net, search_space
import pandas as pd


def test_online_dataset():
Expand All @@ -16,6 +17,8 @@ def test_online_dataset():
x = torch.ones(input.shape())
nx_graph = model_to_graph(model, x)
fea = get_features(nx_graph)

fea = fea.astype('float32')
for i, n in enumerate(nx_graph.nodes):
nx_graph.nodes[n]['features'] = fea.iloc[i].to_numpy()
dgl_graph = to_dgl_graph(nx_graph)
Expand All @@ -24,7 +27,6 @@ def test_online_dataset():
labels = [1.0]

dataset = OnlineNASGraphDataset(graphs, labels)
print()


if __name__ == '__main__':
Expand Down
38 changes: 0 additions & 38 deletions hannah/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import importlib
import operator
import os
import types
from importlib.util import find_spec
from typing import Callable, List, Tuple, Union

import pkg_resources
from pkg_resources import DistributionNotFound

try:
from packaging.version import Version
except (ModuleNotFoundError, DistributionNotFound):
Version = None


def _module_available(module_path: str) -> bool:
Expand All @@ -52,34 +42,6 @@ def _module_available(module_path: str) -> bool:
# Sometimes __spec__ can be None and gives a ValueError
return True


def _compare_version(
package: str, op: Callable, version: str, use_base_version: bool = False
) -> bool:
"""Compare package version with some requirements.
>>> _compare_version("torch", operator.ge, "0.1")
True
>>> _compare_version("does_not_exist", operator.ge, "0.0")
False
"""
try:
pkg = importlib.import_module(package)
except (ImportError, DistributionNotFound):
return False
try:
if hasattr(pkg, "__version__"):
pkg_version = Version(pkg.__version__)
else:
# try pkg_resources to infer version
pkg_version = Version(pkg_resources.get_distribution(package).version)
except TypeError:
# this is mocked by Sphinx, so it should return True to generate all summaries
return True
if use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))


_TORCH_AVAILABLE = _module_available("torch")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
Expand Down
Loading

0 comments on commit b348430

Please sign in to comment.