diff --git a/hannah/nas/performance_prediction/features/dataset.py b/hannah/nas/performance_prediction/features/dataset.py index c3067264..e976f162 100644 --- a/hannah/nas/performance_prediction/features/dataset.py +++ b/hannah/nas/performance_prediction/features/dataset.py @@ -149,14 +149,13 @@ def get_features(nx_graph): df = unfold_columns(df, columns=get_list_columns(df)) dataframes.append(df) df = pd.concat(dataframes) - # df.dropna(axis = 0, how = 'all', inplace = True) df = pd.get_dummies(df, dummy_na=True) df = df.fillna(0) for col in COLUMNS: if col not in df.columns: df[col] = 0 df = df.reindex(sorted(df.columns), axis=1) # Sort to have consistency - return df.astype(np.float32) + return df def get_list_columns(df): list_cols = [] diff --git a/poetry.lock b/poetry.lock index caeabb88..ee127b02 100644 --- a/poetry.lock +++ b/poetry.lock @@ -964,12 +964,12 @@ files = [ [[package]] name = "dgl" -version = "2.3.0+cu121" +version = "2.3.0" description = "Deep Graph Library" optional = false python-versions = "*" files = [ - {file = "dgl-2.3.0+cu121-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:0423c4e819482077aee0c8c70dec2bfc160ada86ea473313d03bf61045d6ae17"}, + {file = "dgl-2.3.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:01f4e209dacc71a1ca28cfccd9f24f8ff16f80ff27d829fdc641dd50afa5b499"}, ] [package.dependencies] @@ -984,16 +984,16 @@ tqdm = "*" [package.source] type = "url" -url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp310-cp310-manylinux1_x86_64.whl" +url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp310-cp310-manylinux1_x86_64.whl" [[package]] name = "dgl" -version = "2.3.0+cu121" +version = "2.3.0" description = "Deep Graph Library" optional = false python-versions = "*" files = [ - {file = "dgl-2.3.0+cu121-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:634cabad7664f798d02752124ce0b1c5cb23b635721b98e2e5fadba0a160f8df"}, + {file = "dgl-2.3.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:0a6dda40d30ec564a6b55ee32e072660ded702fdedb0126bf339f7f01c35fafc"}, ] [package.dependencies] @@ -1008,16 +1008,16 @@ tqdm = "*" [package.source] type = "url" -url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp311-cp311-manylinux1_x86_64.whl" +url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp311-cp311-manylinux1_x86_64.whl" [[package]] name = "dgl" -version = "2.3.0+cu121" +version = "2.3.0" description = "Deep Graph Library" optional = false python-versions = "*" files = [ - {file = "dgl-2.3.0+cu121-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:ec84ddaf3f4b20abbefffae45a32ca535e2beeffca8033b84230fbf09b5d5802"}, + {file = "dgl-2.3.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:7655ab0d454842c91f1a85e639951dcbe994bed25a71bd2d27f066b962808767"}, ] [package.dependencies] @@ -1032,16 +1032,16 @@ tqdm = "*" [package.source] type = "url" -url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp312-cp312-manylinux1_x86_64.whl" +url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp312-cp312-manylinux1_x86_64.whl" [[package]] name = "dgl" -version = "2.3.0+cu121" +version = "2.3.0" description = "Deep Graph Library" optional = false python-versions = "*" files = [ - {file = "dgl-2.3.0+cu121-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:199cf99826735c8009ee84efe5bc2db03134445c1e1499bc44d458aec5675cfc"}, + {file = "dgl-2.3.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:56f9c13782b202c49d4bf3bd9a47e5c52466a98087eb443a29feeb53934bc82f"}, ] [package.dependencies] @@ -1056,7 +1056,7 @@ tqdm = "*" [package.source] type = "url" -url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp39-cp39-manylinux1_x86_64.whl" +url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp39-cp39-manylinux1_x86_64.whl" [[package]] name = "distlib" @@ -5968,4 +5968,4 @@ vision = ["albumentations", "gdown", "imagecorruptions", "kornia", "pycocotools" [metadata] lock-version = "2.0" python-versions = ">=3.9 <3.13" -content-hash = "0966a8e6cfa7a91b47a9ba5d503e09c990d3c4c6db106e51eb2f2960f62306ae" +content-hash = "949f73c192a0b738244600b4098ca1b1bcca1ef88b5c913172c335295cad4311" diff --git a/pyproject.toml b/pyproject.toml index aae0c54f..b7a798c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,12 +62,18 @@ onnx = "^1.16.0" spox = "^0.12.0" optree = "^0.11.0" pydantic = "^2.8.2" -dgl = [ - {platform = 'linux', python='~3.9', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp39-cp39-manylinux1_x86_64.whl"}, - {platform = 'linux', python='~3.10', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp310-cp310-manylinux1_x86_64.whl"}, - {platform = 'linux', python='~3.11', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp311-cp311-manylinux1_x86_64.whl"}, - {platform = 'linux', python='~3.12', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp312-cp312-manylinux1_x86_64.whl"}, +#dgl = [ +# {platform = 'linux', python='~3.9', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp39-cp39-manylinux1_x86_64.whl"}, +# {platform = 'linux', python='~3.10', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp310-cp310-manylinux1_x86_64.whl"}, +# {platform = 'linux', python='~3.11', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp311-cp311-manylinux1_x86_64.whl"}, +# {platform = 'linux', python='~3.12', url = "https://data.dgl.ai/wheels/torch-2.3/cu121/dgl-2.3.0%2Bcu121-cp312-cp312-manylinux1_x86_64.whl"}, +#] +dgl = [ + {platform = 'linux', python='~3.9', url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp39-cp39-manylinux1_x86_64.whl"}, + {platform = 'linux', python='~3.10', url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp310-cp310-manylinux1_x86_64.whl"}, + {platform = 'linux', python='~3.11', url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp311-cp311-manylinux1_x86_64.whl"}, + {platform = 'linux', python='~3.12', url = "https://data.dgl.ai/wheels/torch-2.3/dgl-2.3.0-cp312-cp312-manylinux1_x86_64.whl"}, ] @@ -75,6 +81,7 @@ dgl = [ + [tool.poetry.dev-dependencies] pytest = ">=7.2.0" pre_commit = ">=2.7.1"