-
Notifications
You must be signed in to change notification settings - Fork 17
/
vizier_verify.py
75 lines (62 loc) · 2.98 KB
/
vizier_verify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from concurrent import futures
import grpc
import portpicker
from vizier.service import clients
from vizier.service import pyvizier as vz
from vizier.service import vizier_server
from vizier.service import vizier_service_pb2_grpc
# Number of trials requested from Vizier in the single suggest() call below.
NUM_TRIALS = 2

# Boolean feature toggles, listed in the order they are registered.
_BOOL_PARAM_NAMES = (
    'bypass',
    'cfu',
    'hardwareDiv',
    'mulDiv',
    'singleCycleShift',
    'singleCycleMulDiv',
    'safe',
)
# Candidate sizes shared by the instruction-cache and data-cache parameters.
_CACHE_SIZES = (0, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384)

problem = vz.ProblemStatement()
root = problem.search_space.select_root()
for bool_name in _BOOL_PARAM_NAMES:
    root.add_bool_param(name=bool_name)
root.add_categorical_param(
    name='prediction',
    feasible_values=['none', 'static', 'dynamic', 'dynamic_target'],
)
for cache_name in ('iCacheSize', 'dCacheSize'):
    # Pass a fresh list per parameter, as the original per-call literals did.
    root.add_discrete_param(name=cache_name, feasible_values=list(_CACHE_SIZES))

# Two objectives, both minimized: 'cycles' first, then 'cells'.
for metric_name in ('cycles', 'cells'):
    problem.metric_information.append(
        vz.MetricInformation(
            name=metric_name, goal=vz.ObjectiveMetricGoal.MINIMIZE))

# Build the study configuration and select Vizier's NSGA2 algorithm.
study_config = vz.StudyConfig.from_problem(problem)
study_config.algorithm = vz.Algorithm.NSGA2
# Run an in-process Vizier gRPC service on a free local port.
port = portpicker.pick_unused_port()
address = f'localhost:{port}'

executor = futures.ThreadPoolExecutor(max_workers=100)
server = grpc.server(executor)
servicer = vizier_server.VizierService()
vizier_service_pb2_grpc.add_VizierServiceServicer_to_server(servicer, server)
# gRPC "local" credentials, intended for same-host connections.
credentials = grpc.local_server_credentials()
server.add_secure_port(address, credentials)
server.start()

# Point the Vizier client library at the server started above, then create
# (or load) the study from the configuration.
clients.environment_variables.service_endpoint = address
study = clients.Study.from_study_config(
    study_config,
    owner='owner',
    study_id='example_study_id',
)
# Names of every parameter declared in the search space; each suggestion is
# read for all of them, so a missing parameter still fails loudly.
PARAMETER_NAMES = (
    'bypass',
    'cfu',
    'dCacheSize',
    'hardwareDiv',
    'iCacheSize',
    'mulDiv',
    'prediction',
    'safe',
    'singleCycleShift',
    'singleCycleMulDiv',
)


def _evaluate_config(config):
    """Return ``(cells, cycles)`` for one suggested configuration.

    Placeholder: always reports zero for both metrics, so every trial ties and
    the optimizer receives no signal. This only verifies the suggest/complete
    round trip.

    Args:
        config: Mapping from each name in ``PARAMETER_NAMES`` to the value
            Vizier suggested for it.

    Returns:
        A ``(cells, cycles)`` tuple of metric values.
    """
    # TODO: replace with a real evaluation of `config`.
    del config  # Unused until a real evaluation is wired in.
    return 0, 0


suggestions = study.suggest(count=NUM_TRIALS)
for suggestion in suggestions:
    # Read the `parameters` property once: it may fetch from the service on
    # every access (NOTE(review): confirm against the clients.Trial docs).
    parameters = suggestion.parameters
    config = {name: parameters[name] for name in PARAMETER_NAMES}
    cells, cycles = _evaluate_config(config)
    final_measurement = vz.Measurement({'cycles': cycles, 'cells': cells})
    suggestion.complete(final_measurement)
# Print every trial the service reports as optimal for this study.
for trial_client in study.optimal_trials():
    # Fetch the full trial so its parameters and final measurement are present.
    trial = trial_client.materialize()
    print("Optimal Trial Suggestion and Objective:", trial.parameters,
          trial.final_measurement)

# All service calls are done; shut down the gRPC server explicitly instead of
# leaving it running until interpreter exit.
server.stop(grace=None)