-
Notifications
You must be signed in to change notification settings - Fork 3
/
eye_refocusing.py
163 lines (136 loc) · 6.47 KB
/
eye_refocusing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# core packages
import json
import numpy as np
import matlab.engine
# training framework
from framework.helpers import matlab_helper as mlh
from framework.helpers import logging_helper as lh
from framework.model import Model
from framework.data_generator import DataGenerator
from framework.training_framework import TrainingFramework
# module-level logger, obtained from the framework's logging helper
logger = lh.get_main_module_logger()
# global TrainingFramework instance; None here, assigned by the __main__ entry point
the_framework: "TrainingFramework | None" = None
class EyeRefocusingDataGenerator(DataGenerator):
    """Data generator for the eye-refocusing task.

    Registers two datasets through ``DataGenerator.GeneratorCallback``:

    * ``'eye'`` -- random eye parameter samples, repeated for every step of an
      evenly spaced grid of ``FocusDioptres`` values;
    * ``'refocus'`` -- per eye sample, the change in ``LensD`` and
      ``AqueousT`` after the eye has been refocused in MATLAB with ``FocusAt``.
    """

    def __init__(self, framework):
        super().__init__(framework)

    def get_file_name_prefix(self):
        """Return a file-name fragment encoding the data generation settings.

        The fragment records the number of samples (``ns``), focus steps
        (``nf``), passes (``np``), subdivisions (``nsd``) and rays (``nr``)
        taken from ``self.config.data_generator``.
        """
        gen_config = self.config.data_generator
        return "ns[{ns}]_nf[{nf}]_np[{np}]_nsd[{nsd}]_nr[{nr}]".format(
            ns=gen_config.num_samples,
            nf=gen_config.num_focus_steps,
            np=gen_config.num_passes,
            nsd=gen_config.num_subdivisions,
            nr=gen_config.num_rays)

    @DataGenerator.GeneratorCallback(dataset='eye', whole_dataset=True, uses_matlab=False, depends=[])
    def generate_eye_dataset(self):
        """Generate the denormalized eye parameter samples.

        ``num_samples // num_focus_steps`` parameter vectors are drawn uniformly
        in normalized [0, 1) space (after seeding NumPy's global RNG with
        ``random_seed``) and tiled once per focus step. The ``FocusDioptres``
        column of each tile is then overwritten with one value of
        ``np.linspace(0, 1, num_focus_steps)``. Rows are therefore grouped by
        focus step, and every group shares the same remaining parameters.

        Returns:
            The result of ``self.param_info.denormalize_eye_params`` applied to
            a ``((num_samples // num_focus_steps) * num_focus_steps,
            num_eye_params)`` array (presumably shape-preserving -- confirm).
        """
        gen_config = self.config.data_generator
        num_focus_steps = gen_config.num_focus_steps
        num_samples_per_focus = gen_config.num_samples // num_focus_steps
        num_generated = num_samples_per_focus * num_focus_steps
        if num_generated != gen_config.num_samples:
            # The integer division drops the remainder; make the shortfall
            # visible instead of silently generating fewer samples.
            logger.warning(
                "Generating %d eye samples instead of the configured %d "
                "(num_samples is not a multiple of num_focus_steps=%d)",
                num_generated, gen_config.num_samples, num_focus_steps)
        focus_distance_id = self.param_info.params['FocusDioptres'].col_id
        np.random.seed(gen_config.random_seed)
        # one normalized focus value per step, repeated for each sample in the step
        focus_samples_normalized = np.repeat(
            np.linspace(0.0, 1.0, num_focus_steps), num_samples_per_focus)
        # the same random parameter block is reused for every focus step
        param_samples_normalized = np.random.uniform(
            low=0.0, high=1.0,
            size=(num_samples_per_focus, self.param_info.num_eye_params))
        param_samples_normalized = np.tile(param_samples_normalized, (num_focus_steps, 1))
        param_samples_normalized[:, focus_distance_id] = focus_samples_normalized
        return self.param_info.denormalize_eye_params(param_samples_normalized)

    @DataGenerator.GeneratorCallback(dataset='refocus', whole_dataset=False, uses_matlab=True, threaded=True, depends=['eye'])
    def generate_refocus_dataset(self, environment, sample_id, batch_sample_id):
        """Compute the refocusing response of one eye sample in MATLAB.

        Builds a MATLAB ``EyeParametric`` instance from row ``sample_id`` of the
        ``'eye'`` dataset and refocuses it with ``FocusAt``. It then records how
        the lens diameter and aqueous thickness change. The MATLAB objects it
        creates are deleted afterwards, whether or not the computation succeeded.

        Args:
            environment: Generation environment exposing ``data_generator``,
                ``matlab_instance`` and ``num_outputs``.
            sample_id: Row index into the ``'eye'`` dataset.
            batch_sample_id: Not used by this method.

        Returns:
            A ``(1, environment.num_outputs)`` float array. Column 0 holds the
            change in ``LensD`` and column 1 the change in ``AqueousT``
            (refocused minus original). Entries that could not be computed stay
            NaN; the failure is passed to ``self._handle_matlab_error``.
        """
        # data generation parameters
        param_info = environment.data_generator.param_info
        eye_params = environment.data_generator.datasets['eye'].dataframe
        config = environment.data_generator.config.data_generator

        # Pre-fill with NaN so that outputs never computed (because a MATLAB
        # call failed) remain recognisable. Uses np.nan: the np.NaN alias was
        # removed in NumPy 2.0.
        result = np.full((1, environment.num_outputs), np.nan, dtype=float)

        # MATLAB objects created below, tracked outside the ``try`` so the
        # ``finally`` clause can release them even if a later call fails.
        eye = None
        refocused_eye = None
        stage = 'Eye construction'
        try:  # attempt to compute the refocusing deltas through MATLAB
            matlab_instance = environment.matlab_instance
            # construct a new eye instance
            eye = mlh.call_fn(matlab_instance.EyeParametric)
            # copy every eye-domain parameter that the MATLAB class exposes as a property
            params = eye_params[sample_id].tolist()
            eye_properties = set(mlh.call_fn(matlab_instance.properties, eye))
            for param in param_info.param_list:
                if param.domain == 'eye' and param.name in eye_properties:
                    stage = 'Eye parameter: ' + param.name
                    eye = mlh.call_fn(
                        matlab_instance.SetEyeParameter,
                        eye, param.name,
                        params[param.col_id])
            # build the eye's elements from the parameters just set
            stage = 'Make eye elements'
            eye = mlh.call_fn(matlab_instance.MakeElements, eye)
            # extract the inputs for refocusing
            stage = 'Extract refocusing parameters'
            # NOTE(review): a FocusDioptres of exactly 0 makes this division
            # fail when get_param returns a Python float -- confirm the
            # configured range excludes 0.
            focus_distance = 1.0 / param_info.get_param('FocusDioptres', params)
            lens_diameter = param_info.get_param('LensD', params)
            # refocus the eye; only the first of the five outputs is used
            stage = 'Eye refocusing'
            refocused_eye, *_ = mlh.call_fn(
                matlab_instance.FocusAt,
                eye,
                focus_distance,
                587.56,  # NOTE(review): presumably the wavelength in nm (He d-line) -- confirm against FocusAt
                config.num_passes,
                float(config.num_rays),
                config.num_subdivisions,
                # presumably the [min, max] range searched for the refocused lens diameter -- confirm
                matlab.double([lens_diameter * 0.91, lens_diameter]),
                nargout=5)
            # read the quantities of interest before and after refocusing
            stage = 'Get original lens diameter'
            original_ld = mlh.call_fn(matlab_instance.GetEyeParameter, eye, 'LensD')
            stage = 'Get original ACD'
            original_acd = mlh.call_fn(matlab_instance.GetEyeParameter, eye, 'AqueousT')
            stage = 'Get refocused lens diameter'
            refocused_ld = mlh.call_fn(matlab_instance.GetEyeParameter, refocused_eye, 'LensD')
            stage = 'Get refocused ACD'
            refocused_acd = mlh.call_fn(matlab_instance.GetEyeParameter, refocused_eye, 'AqueousT')
            # write out the computed results
            stage = 'Storing results'
            result[0, 0] = refocused_ld - original_ld
            result[0, 1] = refocused_acd - original_acd
        except Exception as exc:  # log the offending eye parameters
            self._handle_matlab_error(exc, environment, sample_id, stage, param_info, eye_params)
        finally:
            # Release the MATLAB objects on every path, including failures.
            self._delete_matlab_object(environment, eye, 'Delete the eye instance')
            self._delete_matlab_object(environment, refocused_eye, 'Delete the refocused eye instance')
        return result

    @staticmethod
    def _delete_matlab_object(environment, handle, stage):
        """Delete a MATLAB object through ``DeleteSelf`` on a best-effort basis.

        Does nothing when ``handle`` is None, i.e. the object was never
        created. Failures are logged instead of raised, so cleanup cannot mask
        the result of the computation it follows, or its original error.
        """
        if handle is None:
            return
        try:
            mlh.call_fn(environment.matlab_instance.DeleteSelf, handle, nargout=0)
        except Exception as exc:  # best-effort cleanup: report, do not propagate
            logger.warning("%s failed: %s", stage, exc)
class EyeRefocusingModel(Model):
    """Model wrapper for the eye-refocusing task.

    Adds no behavior on top of the framework's ``Model``; it gives the task
    its own model class to pass to ``TrainingFramework`` as ``model_cls``.
    """

    def __init__(self, framework):
        super(EyeRefocusingModel, self).__init__(framework)
if __name__ == '__main__':
    # Wire this module's generator and model classes into the framework,
    # keep the instance in the module-level global, and run it.
    the_framework = TrainingFramework(
        model_cls=EyeRefocusingModel,
        data_generator_cls=EyeRefocusingDataGenerator,
    )
    the_framework.run()