-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
75 lines (62 loc) · 2.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Auxiliary helpers not worth their own modules."""
import os
import random
import numpy as np
from pyutilz.numbalib import set_numba_random_seed
def set_random_seed(seed: int = 42, set_hash_seed: bool = False, set_torch_seed: bool = False) -> None:
    """Seed every random number generator relevant for ML experiments.

    Python's ``random`` module is always seeded. NumPy's global generator,
    CuPy's generator (only when ``cupy`` has already been imported by the
    application) and Numba's generator are seeded on a best-effort basis:
    a failure in one backend never prevents seeding the others.

    Args:
        seed: Seed value passed to every generator.
        set_hash_seed: If True, also export ``PYTHONHASHSEED``. Python reads
            this variable at interpreter startup, so it only affects processes
            started afterwards (e.g. subprocesses / worker processes), not the
            current interpreter's ``hash()`` randomization.
        set_torch_seed: If True, import ``torch`` (if available), seed its CPU
            and CUDA generators and request deterministic cuDNN kernels.
    """
    import sys  # pylint: disable=import-outside-toplevel
    from contextlib import suppress  # pylint: disable=import-outside-toplevel

    random.seed(seed)

    # Best-effort seeding: e.g. NumPy rejects seeds outside [0, 2**32 - 1].
    with suppress(Exception):
        np.random.seed(seed)

    # CuPy is not imported by this module. Seed it only when the application
    # has already loaded it, so this helper never pulls CuPy in as a side effect.
    cupy = sys.modules.get("cupy")
    if cupy is not None:
        with suppress(Exception):
            cupy.random.seed(seed)

    with suppress(Exception):
        set_numba_random_seed(seed)

    if set_hash_seed:
        os.environ["PYTHONHASHSEED"] = str(seed)

    if set_torch_seed:
        # torch is optional: a missing install (ImportError) or a seeding
        # failure is tolerated, matching the best-effort policy above.
        with suppress(Exception):
            import torch  # pylint: disable=import-outside-toplevel

            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True  # type: ignore
def get_pipeline_last_element(clf) -> object:
    """Return the final step of a pipeline-like object.

    Args:
        clf: Object exposing a ``named_steps`` mapping of step name -> step
            (e.g. an sklearn ``Pipeline``).

    Returns:
        The value of the last entry in ``clf.named_steps``.

    Raises:
        ValueError: If ``clf.named_steps`` is empty.
    """
    steps = list(clf.named_steps.values())
    if not steps:
        # Fail loudly with a clear message rather than an UnboundLocalError.
        raise ValueError("pipeline has no steps")
    return steps[-1]
def _describe_target_transform(clf: object) -> str:
    """Return a short label for the target transformation of a TransformedTargetRegressor-like object."""
    transformer = getattr(clf, "transformer", None)
    if transformer is None:
        # No transformer object: label by the ``func`` callable's name, or
        # "func" when it is missing or has no ``__name__`` (e.g. None, partial).
        return getattr(getattr(clf, "func", None), "__name__", "func")
    parts = [type(transformer).__name__]
    # Append string-valued variant attributes when present, e.g.
    # PowerTransformer.method or QuantileTransformer.output_distribution.
    for attr in ("method", "output_distribution"):
        value = getattr(transformer, attr, None)
        if isinstance(value, str):
            parts.append(value)
    return " ".join(parts)


def get_full_classifier_name(clf: object) -> str:
    """Build a human-readable name for an estimator, unwrapping common meta-estimators.

    Dispatches on the class name (not ``isinstance``), so sklearn need not be imported:

    * ``TransformedTargetRegressor`` -> ``"<regressor name> -> <transform label>"``
    * ``Pipeline`` -> ``"pipe[<name of last step>]"``
    * ``MultiOutputRegressor`` -> ``"MultiOutputRegressor[<inner estimator name>]"``
    * classes whose name contains ``"Dummy"`` and that have a ``strategy``
      attribute -> ``"<ClassName>[<strategy>]"``
    * anything else -> the plain class name.

    Args:
        clf: Estimator (or meta-estimator) instance to describe.

    Returns:
        The descriptive name.
    """
    clf_name = type(clf).__name__

    if clf_name == "TransformedTargetRegressor":
        return f"{get_full_classifier_name(clf.regressor)} -> {_describe_target_transform(clf)}"

    if clf_name == "Pipeline":
        return f"pipe[{get_full_classifier_name(get_pipeline_last_element(clf))}]"

    if clf_name == "MultiOutputRegressor":
        return f"MultiOutputRegressor[{get_full_classifier_name(clf.estimator)}]"

    if "Dummy" in clf_name:
        # Tolerate user classes that merely contain "Dummy" in their name.
        strategy = getattr(clf, "strategy", None)
        if strategy is not None:
            return f"{clf_name}[{strategy}]"

    return clf_name