-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmain.py
155 lines (128 loc) · 7.04 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Speckle-free Holography with Partially Coherent Light Sources and Camera-in-the-loop Calibration:
This is the main executive script used for the phase optimization using SGD + camera-in-the-loop (CITL).
This code and data is released under the Creative Commons Attribution-NonCommercial 4.0 International license (CC BY-NC.) In a nutshell:
# The license is only for non-commercial use (commercial licenses can be obtained from Stanford).
# The material is provided as-is, with no warranties whatsoever.
# If you publish any code, data, or scientific work based on this, please cite our work.
@article{Peng:2021:PartiallyCoherent,
author = {Yifan Peng and Suyeon Choi and Jonghyun Kim and Gordon Wetzstein },
title = {Speckle-free holography with partially coherent light sources and camera-in-the-loop calibration},
journal = {Science Advances},
volume = {7},
number = {46},
pages = {eabg5040},
year = {2021},
doi = {10.1126/sciadv.abg5040}
-----
$ python main.py --channel=0 --method=SGD --root_path=./phases
"""
import os
import sys
sys.path.append('neural-holography')
import cv2
import torch
import torch.nn as nn
import configargparse
from torch.utils.tensorboard import SummaryWriter
import utils.utils as utils
from utils.augmented_image_loader import ImageLoader
from utils.modules import SGD, PhysicalProp
from propagation_ASM import propagation_ASM
from propagation_partial import PartialProp
# ---------------------------------------------------------------------------
# Command line argument processing.
# configargparse also accepts a config file via -c/--config_filepath;
# values given on the command line override values from the config file.
# ---------------------------------------------------------------------------
p = configargparse.ArgumentParser()
p.add('-c', '--config_filepath', required=False, is_config_file=True, help='Path to config file.')
p.add_argument('--channel', type=int, default=1, help='Red:0, green:1, blue:2')
p.add_argument('--method', type=str, default='SGD', help='Type of algorithm')
p.add_argument('--prop_model', type=str, default='ASM', help='Type of propagation model, ASM or model')
p.add_argument('--root_path', type=str, default='./phases', help='Directory where optimized phases will be saved.')
p.add_argument('--data_path', type=str, default='./neural-holography/data', help='Directory for the dataset')
p.add_argument('--src_type', type=str, default='sLED', help='sLED or LED')
p.add_argument('--citl', type=utils.str2bool, default=False, help='Use of Camera-in-the-loop optimization with SGD')
p.add_argument('--experiment', type=str, default='', help='Name of experiment')
p.add_argument('--lr', type=float, default=6e-3, help='Learning rate for phase variables (for SGD)')
p.add_argument('--lr_s', type=float, default=1e-3, help='Learning rate for learnable scale (for SGD)')
p.add_argument('--num_iters', type=int, default=1000, help='Number of iterations (SGD)')

# parse arguments
opt = p.parse_args()
# Run identifier, used as the output folder name:
# '{experiment}_{method}_{prop_model}', plus '_citl' when camera-in-the-loop is on.
run_id = f'{opt.experiment}_{opt.method}_{opt.prop_model}'
if opt.citl:
    run_id = f'{run_id}_citl'

channel = opt.channel  # Red:0 / Green:1 / Blue:2
chan_str = ('red', 'green', 'blue')[channel]

print(f' - optimizing phase with {opt.method}/{opt.prop_model} ... ')
if opt.citl:
    print(f' - with camera-in-the-loop ...')

# Hyperparameters setting
cm, mm, um, nm = 1e-2, 1e-3, 1e-6, 1e-9  # unit factors (meters)
prop_dist = (10 * cm, 10 * cm, 10 * cm)[channel]  # propagation distance from SLM plane to target plane
wavelength = (634.8 * nm, 510 * nm, 450 * nm)[channel]  # central wavelength per channel for the sLED source
if opt.src_type == 'LED':
    wavelength = (633 * nm, 532 * nm, 460 * nm)[channel]  # LED central wavelengths override the sLED ones
feature_size = (6.4 * um, 6.4 * um)  # SLM pitch (pixel size, meters)
slm_res = (1080, 1920)  # resolution of SLM
image_res = (1080, 1920)  # resolution of target images loaded below
roi_res = (880, 1600)  # regions of interest (to penalize for SGD)
dtype = torch.float32  # default datatype (Note: the result may be slightly different if you use float64, etc.)
device = torch.device('cuda')  # The gpu you are using
# Options for the algorithm
loss = nn.MSELoss().to(device)  # loss functions to use (try other loss functions!)
s0 = 1.0  # initial scale
root_path = os.path.join(opt.root_path, run_id, chan_str)  # path for saving out optimized phases

# Tensorboard writer
summaries_dir = os.path.join(root_path, 'summaries')
utils.cond_mkdir(summaries_dir)
writer = SummaryWriter(summaries_dir)

# Hardware setup for CITL: PhysicalProp drives the physical SLM + camera rig
# used to replace the simulated forward model with real captures.
# NOTE(review): patterns_path is a hard-coded Windows drive path
# ('F:/citl/calibration') — adjust for your capture machine before use.
if opt.citl:
    camera_prop = PhysicalProp(channel, laser_arduino=True, roi_res=(roi_res[1], roi_res[0]), slm_settle_time=0.12,
                               range_row=(220, 1000), range_col=(300, 1630),
                               patterns_path=f'F:/citl/calibration',
                               show_preview=True)
else:
    camera_prop = None  # pure simulation; no hardware in the loop
# ---------------------------------------------------------------------------
# Simulation model: propagation operator from the SLM plane to the target plane.
#   'ASM'   -> ideal, fully coherent angular-spectrum-method propagation
#   'MODEL' -> PartialProp, the partially coherent source model (an nn.Module)
# Comparison is case-insensitive (the original mixed an exact match for 'ASM'
# with a .upper() match for 'MODEL'; normalizing once makes both consistent).
# ---------------------------------------------------------------------------
prop_model_name = opt.prop_model.upper()
if prop_model_name == 'ASM':
    propagator = propagation_ASM  # Ideal model
elif prop_model_name == 'MODEL':
    propagator = PartialProp(distance=prop_dist, feature_size=feature_size, batch_size=12,
                             wavelength_central=wavelength, num_wvls=15,
                             sample_wavelength_rate=1*nm,
                             randomly_sampled=True,
                             use_sampling_pool=True,
                             f_col=200*mm,
                             source_diameter=75*um,
                             source_amp_sigma=30*um,
                             src_type=opt.src_type,  # 'sLED' or 'LED'
                             device=device).to(device)
    propagator.eval()  # inference only: the propagation model itself is not trained here
else:
    # Fail fast with a clear message instead of hitting a NameError on
    # `propagator` further down when --prop_model is misspelled.
    raise ValueError(f'Unknown prop_model: {opt.prop_model} (expected ASM or model)')
# Select Phase generation method, algorithm. Only 'SGD' (optionally with
# camera-in-the-loop feedback via `camera_prop`) is implemented in this script.
if opt.method == 'SGD':
    phase_only_algorithm = SGD(prop_dist, wavelength, feature_size, opt.num_iters, roi_res, root_path,
                               opt.prop_model, propagator, loss, opt.lr, opt.lr_s, s0, opt.citl, camera_prop, writer, device)
else:
    # Fail fast: without this branch, an unsupported --method left
    # `phase_only_algorithm` unbound and crashed with a NameError inside the
    # dataset loop below.
    raise ValueError(f'Unsupported method: {opt.method} (only SGD is implemented)')

# Augmented image loader (if you want to shuffle, augment dataset, put options accordingly.)
image_loader = ImageLoader(opt.data_path, channel=channel,
                           image_res=image_res, homography_res=roi_res,
                           crop_to_homography=True,
                           shuffle=False, vertical_flips=False, horizontal_flips=False)
# Loop over the dataset: optimize one phase pattern per target image and save it.
for k, target in enumerate(image_loader):
    # get target image: (amplitude tensor, native resolution, source file name)
    target_amp, target_res, target_filename = target
    target_path, target_filename = os.path.split(target_filename[0])
    # output name: trailing '_<idx>' token of the source file name
    target_idx = target_filename.split('_')[-1]
    target_amp = target_amp.to(device)
    print(target_idx)

    # if you want to separate folders by target_idx or whatever, you can do so here.
    # Initial learnable scale: mean amplitude inside the ROI, so optimization
    # starts near the target's overall energy level.
    phase_only_algorithm.init_scale = s0 * utils.crop_image(target_amp, roi_res, stacked_complex=False).mean()
    phase_only_algorithm.phase_path = os.path.join(root_path)

    # run algorithm (See algorithm_modules.py and algorithms.py)
    # iterative methods, initial phase: random guess, uniform in [-0.5, 0.5)
    init_phase = (-0.5 + 1.0 * torch.rand(1, 1, *slm_res)).to(device)
    final_phase = phase_only_algorithm(target_amp, init_phase)

    # save the final result as an 8-bit phase map (inverted for the SLM).
    phase_out_8bit = utils.phasemap_8bit(final_phase.cpu().detach(), inverted=True)
    utils.cond_mkdir(root_path)
    cv2.imwrite(os.path.join(root_path, f'{target_idx}.png'), phase_out_8bit)

print(f' - Done, result: --root_path={root_path}')