-
Notifications
You must be signed in to change notification settings - Fork 15
/
random_tuner.py
152 lines (129 loc) · 5 KB
/
random_tuner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import random
import time
import boto3
import re
import pandas as pd
import numpy as np
#################
# Hyperparameters
#################
class CategoricalParameter:
    """Hyperparameter drawn from a fixed list of choices.

    Construct with ``values``, the list of candidate settings. Each call to
    ``get_value`` returns one entry picked with ``random.choice``.
    """

    def __init__(self, values):
        self.values = values

    def get_value(self):
        """Return one randomly chosen entry of ``self.values``."""
        choices = self.values
        return random.choice(choices)
class IntegerParameter:
    """Hyperparameter sampled as an integer from a closed range.

    Construct with ``min_value`` then ``max_value``. Both bounds are inclusive
    because sampling is done with ``random.randint``.
    """

    def __init__(self, min_value, max_value):
        self.min_value, self.max_value = min_value, max_value

    def get_value(self):
        """Return a random integer N with min_value <= N <= max_value."""
        low, high = self.min_value, self.max_value
        return random.randint(low, high)
class ContinuousParameter:
    """Hyperparameter sampled as a float between two bounds.

    Construct with ``min_value`` then ``max_value``. Values are produced by
    ``random.uniform`` over that interval.
    """

    def __init__(self, min_value, max_value):
        self.min_value, self.max_value = min_value, max_value

    def get_value(self):
        """Return a random float between ``min_value`` and ``max_value``."""
        low, high = self.min_value, self.max_value
        return random.uniform(low, high)
###############
# Random search
###############
def _get_random_hyperparameter_values(hyperparameters):
    '''
    Converts a dict using hyperparameter classes to a dict of hyperparameter values.

    Any value exposing a callable ``get_value()`` method (CategoricalParameter,
    IntegerParameter, ContinuousParameter, or a user-defined equivalent) is
    replaced by a freshly sampled value. Every other value is passed through
    unchanged, so fixed hyperparameters can be given as plain values.

    Takes in:
        hyperparameters: dict mapping hyperparameter name -> definition or value.
    Returns a new dict with the same keys; the input dict is not modified.
    '''
    hps = {}
    for hp, definition in hyperparameters.items():
        # Duck-typed rather than an isinstance() whitelist: it accepts custom
        # parameter classes and does not depend on a fixed set of class names
        # (there is no StaticParameter class; fixed values are passed as-is).
        sampler = getattr(definition, 'get_value', None)
        hps[hp] = sampler() if callable(sampler) else definition
    return hps
def random_search(train_fn,
                  hyperparameters,
                  base_name=None,
                  max_jobs=100,
                  max_parallel_jobs=100,
                  client=None,
                  poll_interval=20):
    '''
    Runs random search for hyperparameters.

    Takes in:
        train_fn: A function that kicks off a training job based on two positional
            arguments - job name and hyperparameter dictionary. Note, wait must be
            set to False if using .fit()
        hyperparameters: A dictionary of hyperparameters defined with hyperparameter
            classes.
        base_name: Base name for training jobs. Defaults to 'random-hp-<timestamp>'.
        max_jobs: Total number of training jobs to run.
        max_parallel_jobs: Most training jobs to run concurrently. This does not
            affect the quality of search, just helps stay under account service
            limits.
        client: Optional SageMaker client used to poll job status. Defaults to
            boto3.client('sagemaker').
        poll_interval: Seconds to sleep between status polls while at the
            concurrency limit. Defaults to 20.

    Returns a dictionary of max_jobs job names with associated hyperparameter values.
    The function returns once the last job has been launched; it does not wait for
    the remaining jobs to finish.
    '''
    if base_name is None:
        base_name = 'random-hp-' + time.strftime('%Y-%m-%d-%H-%M-%S-%j', time.gmtime())
    if client is None:
        client = boto3.client('sagemaker')
    jobs = {}
    running_jobs = set()
    for i in range(max_jobs):
        job = base_name + '-' + str(i)
        hps = _get_random_hyperparameter_values(hyperparameters)
        # Record a copy so that train_fn mutating its argument cannot alter it.
        jobs[job] = hps.copy()
        train_fn(job, hps)
        running_jobs.add(job)
        # At the concurrency limit: poll until at least one job leaves InProgress.
        while len(running_jobs) == max_parallel_jobs:
            # Iterate over a snapshot; removing from a set/dict while iterating
            # over it directly raises RuntimeError.
            for running_job in list(running_jobs):
                status = client.describe_training_job(
                    TrainingJobName=running_job)['TrainingJobStatus']
                if status != 'InProgress':
                    running_jobs.discard(running_job)
            # Only wait if no slot was freed by this poll.
            if len(running_jobs) == max_parallel_jobs:
                time.sleep(poll_interval)
    return jobs
################
# Analyze output
################
_TRAINING_LOG_GROUP = '/aws/sagemaker/TrainingJobs'


def _iter_log_stream_names(client, job):
    '''Yields the names of all log streams for a training job, following nextToken pagination.'''
    kwargs = {'logGroupName': _TRAINING_LOG_GROUP,
              'logStreamNamePrefix': job + '/'}
    while True:
        response = client.describe_log_streams(**kwargs)
        for stream in response['logStreams']:
            yield stream['logStreamName']
        token = response.get('nextToken')
        if not token:
            break
        kwargs['nextToken'] = token


def _iter_log_messages(client, stream):
    '''Yields every message in a log stream, oldest first, following forward-token pagination.'''
    kwargs = {'logGroupName': _TRAINING_LOG_GROUP,
              'logStreamName': stream,
              'startFromHead': True}
    while True:
        response = client.get_log_events(**kwargs)
        for event in response['events']:
            yield event['message']
        token = response.get('nextForwardToken')
        # GetLogEvents signals the end of the stream by returning the token that
        # was passed in; pages may be empty before then, so don't stop on those.
        if token is None or token == kwargs.get('nextToken'):
            break
        kwargs['nextToken'] = token


def get_metrics(jobs, regex, client=None):
    '''
    Gets CloudWatch metrics for training jobs.

    Takes in:
        jobs: A dictionary (or other iterable) whose keys are training job names.
        regex: A regular expression (string or compiled pattern) whose first
            capturing group extracts the objective metric value from a log line.
        client: Optional CloudWatch Logs client. Defaults to boto3.client('logs').

    Returns a dictionary of training job names as keys and a corresponding list
    of the captured metric strings, taken from every log stream of the job in
    stream order and, within a stream, oldest message first. All pages of both
    the stream listing and the log events are read.

    Raises ValueError if regex has no capturing group.
    '''
    pattern = re.compile(regex)
    if pattern.groups < 1:
        raise ValueError('regex must contain at least one capturing group: %r' % (regex,))
    if client is None:
        client = boto3.client('logs')
    job_metrics = {}
    for job in jobs:
        stream_metrics = []
        for stream in _iter_log_stream_names(client, job):
            for message in _iter_log_messages(client, stream):
                match = pattern.search(message)
                # Lines that do not contain the metric are skipped.
                if match is not None:
                    stream_metrics.append(match.group(1))
        job_metrics[job] = stream_metrics
    return job_metrics
def table_metrics(jobs, metrics):
    '''
    Returns Pandas DataFrame of jobs, hyperparameter values, and objective metric value.

    Takes in:
        jobs: A dictionary of training job names mapped to hyperparameter
            dictionaries (as returned by random_search). It is not modified.
        metrics: A dictionary of training job names mapped to lists of metric
            strings (as returned by get_metrics). The last entry is converted to
            float and used as the objective; jobs with an empty list, or missing
            from metrics, get NaN.

    The DataFrame is indexed by job name and has one column per hyperparameter
    plus 'objective' and 'job_number' (the integer after the last '-' in the
    job name).
    '''
    rows = {}
    for job, hps in jobs.items():
        values = metrics.get(job) or []
        # Build a new dict per row so the caller's hyperparameter dicts are
        # not mutated (a shallow copy of `jobs` would share them).
        row = dict(hps)
        row['objective'] = float(values[-1]) if values else np.nan
        row['job_number'] = int(job.split('-')[-1])
        rows[job] = row
    return pd.DataFrame.from_dict(rows, orient='index')