-
Notifications
You must be signed in to change notification settings - Fork 109
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Machine Learning Optimization Template (#238)
Squash and merge of ML template PR #238
- Loading branch information
1 parent
f9565d2
commit da1818a
Showing
8 changed files
with
638 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Machine Learning Optimization | ||
Code for reinforcement learning loop with openfasoc generators for optimizing metrics | ||
|
||
## Code Setup | ||
The code is setup as follows: | ||
|
||
The top level directory contains two sub-directories: | ||
* model.py: top level RL script, used to set hyperparameters and run training | ||
* run_training.py: contains all OpenAI Gym environments. These function as the agent in the RL loop and contain information about parameter space, valid action steps and reward. | ||
* eval.py: contains all of the code for evaluation | ||
* gen_spec.py: contains all of the random specification generation | ||
|
||
## Training | ||
Make sure that you have OpenAI Gym and Ray installed. To do this, run the following command: | ||
|
||
To generate the design specifications that the agent trains on, run: | ||
``` | ||
python3.10 gen_specs.py | ||
``` | ||
The result is a yaml file dumped to the ../generators/gdsfactory-gen/. | ||
|
||
To train the agent, open ipython from the top level directory and then: | ||
``` | ||
python3.10 model.py | ||
``` | ||
The training checkpoints will be saved in your home directory under ray\_results. Tensorboard can be used to load reward and loss plots using the command: | ||
|
||
``` | ||
tensorboard --logdir path/to/checkpoint | ||
``` | ||
|
||
## Validation | ||
The evaluation script takes the trained agent and gives it new specs that the agent has never seen before. To generate new design specs, run the gen_specs.py file again with your desired number of specs to validate on. To run validation: | ||
|
||
``` | ||
python3.10 eval.py | ||
``` | ||
|
||
The evaluation result will be saved to the ../generators/gdsfactory-gen/. | ||
|
||
## Results | ||
Please note that results vary greatly based on random seed and spec generation (both for testing and validation). An example spec file is provided that was used to generate the results below. | ||
|
||
<p float="left"> | ||
<img src="image1.png" width="400" /> <img src="image2.png" width="400" /> | ||
</p> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
#training import | ||
import numpy as np | ||
import gym | ||
import ray | ||
import ray.tune as tune | ||
from ray.rllib.algorithms.ppo import PPO | ||
from run_training import Envir | ||
from ../generators/gdsfactory-gen/sky130_nist_tapeout import single_build_and_simulation | ||
import pickle | ||
import yaml | ||
from pathlib import Path | ||
import argparse | ||
|
||
def unlookup(norm_spec, goal_spec): | ||
spec = -1*np.multiply((norm_spec+1), goal_spec)/(norm_spec-1) | ||
return spec | ||
|
||
specs = yaml.safe_load(Path('newnew_eval_3.yaml').read_text()) | ||
|
||
# | ||
#training set up | ||
env_config = { | ||
"generalize":True, | ||
"num_valid":2, | ||
"save_specs":False, | ||
"inputspec":specs, | ||
"run_valid":True, | ||
"horizon":25, | ||
} | ||
|
||
config_eval = { | ||
#"sample_batch_size": 200, | ||
"env": Envir, | ||
"env_config":{ | ||
"generalize":True, | ||
"num_valid":2, | ||
"save_specs":False, | ||
"inputspec":specs, | ||
"run_valid":True, | ||
"horizon":25, | ||
}, | ||
} | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--checkpoint_dir', '-cpd', type=str) | ||
args = parser.parse_args() | ||
env = Envir(env_config=env_config) | ||
|
||
agent = PPO.from_checkpoint("/home/wentian/ray_results/brandnewBound_1/PPO_Envir_7fc09_00000_0_2023-08-18_20-40-42/checkpoint_000015") | ||
|
||
|
||
|
||
norm_spec_ref = env.global_g | ||
spec_num = len(env.specs) | ||
|
||
|
||
rollouts = [] | ||
next_states = [] | ||
obs_reached = [] | ||
obs_nreached = [] | ||
action_array = [] | ||
action_arr_comp = [] | ||
rollout_steps = 0 | ||
reached_spec = 0 | ||
f = open("newnewnew_eval__3.txt", "a") | ||
|
||
while rollout_steps < 100: | ||
rollout_num = [] | ||
state, info = env.reset() | ||
|
||
done = False | ||
truncated = False | ||
reward_total = 0.0 | ||
steps=0 | ||
f.write('new----------------------------------------') | ||
while not done and not truncated: | ||
action = agent.compute_single_action(state) | ||
action_array.append(action) | ||
|
||
next_state, reward, done, truncated, info = env.step(action) | ||
f.write(str(action)+'\n') | ||
f.write(str(reward)+'\n') | ||
f.write(str(done)+'n') | ||
print(next_state) | ||
print(action) | ||
print(reward) | ||
print(done) | ||
reward_total += reward | ||
|
||
rollout_num.append(reward) | ||
next_states.append(next_state) | ||
|
||
state = next_state | ||
|
||
norm_ideal_spec = state[spec_num:spec_num+spec_num] | ||
ideal_spec = unlookup(norm_ideal_spec, norm_spec_ref) | ||
if done == True: | ||
reached_spec += 1 | ||
obs_reached.append(ideal_spec) | ||
action_arr_comp.append(action_array) | ||
action_array = [] | ||
pickle.dump(action_arr_comp, open("action_arr_test", "wb")) | ||
else: | ||
obs_nreached.append(ideal_spec) #save unreached observation | ||
action_array=[] | ||
f.write('done----------------------------------------') | ||
rollouts.append(rollout_num) | ||
print("Episode reward", reward_total) | ||
rollout_steps+=1 | ||
#if out is not None: | ||
#pickle.dump(rollouts, open(str(out)+'reward', "wb")) | ||
pickle.dump(obs_reached, open("opamp_obs_reached_test","wb")) | ||
pickle.dump(obs_nreached, open("opamp_obs_nreached_test","wb")) | ||
|
||
f.write("Specs reached: " + str(reached_spec) + "/" + str(len(obs_nreached))) | ||
print("Specs reached: " + str(reached_spec) + "/" + str(len(obs_nreached))) | ||
|
||
print("Num specs reached: " + str(reached_spec) + "/" + str(1)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#!/usr/bin/env python3 | ||
## Generate the design specifications and then save to a pickle file | ||
|
||
import numpy as np | ||
import random | ||
import yaml | ||
import os | ||
import argparse | ||
|
||
def gen_data(env, num_specs): | ||
|
||
specs_range = { | ||
"gain_min" : [float(1000338000.0), float(3000338000.0)], | ||
"FOM" : [float(5*10**11), float(5*10**11)] | ||
} | ||
specs_range_vals = list(specs_range.values()) | ||
specs_valid = [] | ||
for spec in specs_range_vals: | ||
if isinstance(spec[0],int): | ||
list_val = [random.randint(int(spec[0]),int(spec[1])) for x in range(0,num_specs)] | ||
else: | ||
list_val = [random.uniform(float(spec[0]),float(spec[1])) for x in range(0,num_specs)] | ||
specs_valid.append(tuple(list_val)) | ||
i=0 | ||
for key,value in specs_range.items(): | ||
specs_range[key] = specs_valid[i] | ||
i+=1 | ||
|
||
output = str(specs_range) | ||
with open(env, 'w') as f: | ||
f.write(output.replace('(','[').replace(')',']').replace(',',',\n')) | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--num_specs', type=str) | ||
args = parser.parse_args() | ||
|
||
gen_data("newnew_eval_3.yaml", int(50)) | ||
|
||
if __name__=="__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#training import | ||
import gym | ||
import ray | ||
import ray.tune as tune | ||
from ray.rllib.algorithms.ppo import PPO | ||
from run_training import Envir | ||
from sky130_nist_tapeout import single_build_and_simulation | ||
sky130_nist_tapeout.path.append('../generators/gdsfactory-gen/') | ||
|
||
import argparse | ||
# | ||
#training set up | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--checkpoint_dir', '-cpd', type=str) | ||
args = parser.parse_args() | ||
ray.init(num_cpus=33, num_gpus=0,include_dashboard=True, ignore_reinit_error=True) | ||
|
||
#configures training of the agent with associated hyperparameters | ||
config_train = { | ||
#"sample_batch_size": 200, | ||
"env": Envir, | ||
"train_batch_size": 1000, | ||
#"sgd_minibatch_size": 1200, | ||
#"num_sgd_iter": 3, | ||
#"lr":1e-3, | ||
#"vf_loss_coeff": 0.5, | ||
#"rollout_fragment_length": 63, | ||
"model":{"fcnet_hiddens": [64, 64]}, | ||
"num_workers": 32, | ||
"env_config":{"generalize":True, "run_valid":False, "horizon":20}, | ||
} | ||
|
||
#Runs training and saves the result in ~/ray_results/train_ngspice_45nm | ||
#If checkpoint fails for any reason, training can be restored | ||
trials = tune.run( | ||
"PPO", #You can replace this string with ppo.PPOTrainer if you want / have customized it | ||
name="brandnewBound_1", # The name can be different. | ||
stop={"episode_reward_mean": 12, "training_iteration": 15}, | ||
checkpoint_freq=1, | ||
config=config_train, | ||
#restore="/home/wentian/ray_results/brandnewBound/PPO_Envir_cc8be_00000_0_2023-08-16_01-11-16/checkpoint_000002", | ||
#restore="/home/wentian/ray_results/brandnewBound/PPO_Envir_f6236_00000_0_2023-08-16_04-40-01/checkpoint_000003", | ||
#restore="/home/wentian/ray_results/brandnewBound/PPO_Envir_4615a_00000_0_2023-08-16_06-58-15/checkpoint_000006" | ||
#restore="/home/wentian/ray_results/brandnewBound/PPO_Envir_d8b02_00000_0_2023-08-17_02-07-41/checkpoint_000012", | ||
restore="/home/wentian/ray_results/brandnewBound_1/PPO_Envir_d6a0f_00000_0_2023-08-18_05-19-43/checkpoint_000012", | ||
) | ||
# |
Oops, something went wrong.