Example from paper #75
Closed
OwlyWizard
started this conversation in
General
Replies: 1 comment 2 replies
-
Hi, thanks for visiting us. I'm very happy to hear that meent might be helpful for your research. Here is a simple example (not cleaned up, but shared as-is to save time).
import torch
torch.manual_seed(0)
import numpy as np
import meent
import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy
# @title set basic parameters & functions
# ---- design discretisation --------------------------------------------------
N_L = 4    # number of patterned (optimised) layers
N_C = 64   # number of cells per layer along x

# ---- meent solver options ---------------------------------------------------
backend = 2       # Torch
device = 0        # meent device selector (0 presumably means CPU -- confirm)
pol = 0           # 0: TE, 1: TM
grating_type = 0  # grating type: 0 for 1D grating without rotation (phi == 0)
type_complex = torch.complex128
fourier_order = [35]  # Fourier truncation order; a single entry for the 1D grating

# ---- illumination / media ---------------------------------------------------
n_I = 1      # n_incidence
n_II = 1.5   # n_transmission
theta = torch.pi * 0 / 180  # angle of incidence [rad]
phi = torch.pi * 0 / 180    # angle of rotation [rad]

# ---- geometry (presumably nm, matching the wavelengths used below) ----------
# Layer thicknesses from top to bottom: 1500 split evenly over the N_L design
# layers, followed by 500 for one extra bottom layer (the optimisation script
# fills that layer with a uniform index n1).
thickness = torch.tensor([1500. / N_L] * N_L + [500.])
period = torch.tensor([1000.])  # length of the unit cell. Here it's 1D.

# ---- the two refractive indices of the binary design ------------------------
n1 = 1.5
n2 = 2.0
def single_wav_FoM(wavelength, field_E, res_x):
    """Return the fraction of |E|^2 that falls in this wavelength's target region.

    The field is read from ``field_E[-1][0]`` (the last index along the first
    axis, first index along the second), using component 0 only. Along x it is
    split into quarters, and the wavelength selects which part is the target:

    * ``wavelength <= 500``       -> ``[2*res_x//4, 3*res_x//4)``
    * ``500 < wavelength <= 600`` -> ``[res_x//4, 2*res_x//4)`` plus ``[3*res_x//4, -1)``
    * ``wavelength > 600``        -> ``[0, res_x//4)``

    Args:
        wavelength: Selects the target region (presumably nm -- confirm).
        field_E: Field tensor; presumably laid out as (z, y, x, component) --
            confirm against meent's ``calculate_field``.
        res_x: Number of x samples in the field.

    Returns:
        A 0-d tensor: target-region intensity divided by the total intensity.
    """
    row = field_E[-1][0]

    def _power(start, stop):
        # Summed |component 0|^2 over x-indices [start, stop).
        return torch.sum(torch.abs(row[start:stop].T[0]) ** 2)

    # NOTE(review): the ``-1`` stop drops the last x sample from both the total
    # and the fourth quarter. That is correct if the last sample duplicates x=0
    # (a periodic grid that includes its endpoint); otherwise it is an
    # off-by-one -- confirm against meent's field sampling.
    total = _power(0, -1)  # computed once instead of once per term

    if wavelength <= 500:
        return _power(2 * res_x // 4, 3 * res_x // 4) / total
    if wavelength <= 600:  # the > 500 half of the band is implied by the check above
        return _power(res_x // 4, 2 * res_x // 4) / total + _power(3 * res_x // 4, -1) / total
    return _power(0, res_x // 4) / total
def multi_wav_FoM(wavelength_list, mee, res_x, res_y, res_z):
    """Sum the transmission-weighted routing FoM over several wavelengths.

    For every wavelength the solver ``mee`` is re-solved and its field is
    sampled on a ``res_x`` x ``res_y`` x ``res_z`` grid. Then
    ``sum(de_ti) * single_wav_FoM(...)`` is added to a running total.

    Side effect: ``mee.wavelength`` is left at the last entry of
    ``wavelength_list``.

    Returns:
        The unnormalised sum (``0`` if ``wavelength_list`` is empty).
    """
    total = 0
    for wl in wavelength_list:
        mee.wavelength = wl
        # Second output is presumably the transmitted diffraction efficiencies;
        # the first (reflection) is not used.
        _, transmitted = mee.conv_solve()
        fields = mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z)
        weight = torch.sum(transmitted)
        total += weight * single_wav_FoM(mee.wavelength, fields, res_x)
    return total
def binarization(n1, n2, ucell):
    """Snap every index value in ``ucell`` to either ``n1`` or ``n2``.

    Values strictly above the midpoint ``(n1 + n2) / 2`` become ``n2``; all
    others, including the midpoint itself, become ``n1``. The result has the
    same shape as ``ucell``. It is assembled from constant tensors, so no
    gradient flows back to ``ucell``.
    """
    midpoint = (n2 + n1) / 2
    upper = n2 * torch.ones_like(ucell)
    lower = n1 * torch.ones_like(ucell)
    return torch.where(ucell > midpoint, upper, lower)
def plot_intensity(mee, res_x, res_y, res_z, wavelengths=(651., 551., 451)):
    """Plot |E|^2 maps of the current design, one subplot row per wavelength.

    For each wavelength the solver is re-run and ``|field[:, 0, :, 0]|^2`` is
    drawn twice: the full map in the left column, and the first 10 rows of the
    vertically flipped map in the right column. Dashed lines mark the x-quarter
    boundaries. The colour map follows the same bands as ``single_wav_FoM``:
    <= 500 blue, (500, 600] green, > 600 red.

    Args:
        mee: meent solver object. Its ``wavelength`` is overwritten and left at
            the last plotted value.
        res_x, res_y, res_z: Sampling resolution passed to ``mee.calculate_field``.
        wavelengths: Wavelengths to plot. The default reproduces the original
            three rows (651, 551, 451).
    """
    n_rows = len(wavelengths)
    plt.figure(figsize=(20, 4 * max(n_rows, 1)))  # (20, 12) for the default 3 rows
    for row, wavelength in enumerate(wavelengths):
        mee.wavelength = wavelength
        # Efficiencies are unused here. The solve is kept ahead of
        # calculate_field, which presumably depends on it.
        mee.conv_solve()
        field_cell_meent = mee.calculate_field(res_x=res_x, res_y=res_y, res_z=res_z)
        # detach + cpu so .numpy() works even if the field carries autograd
        # history or lives on a GPU.
        intensity = (torch.abs(field_cell_meent[:, 0, :, 0]) ** 2).detach().cpu().numpy()
        # Flip axis 0 so the first sample along it is drawn at the top by pcolor.
        image = np.flip(intensity, axis=0)

        if mee.wavelength <= 500:
            color_mapping = 'Blues'
        elif mee.wavelength <= 600:  # (500, 600], consistent with single_wav_FoM
            color_mapping = 'Greens'
        else:
            color_mapping = 'Reds'

        # Left: full map. Right: the first 10 rows of the flipped map, i.e. the
        # last 10 samples along axis 0 of the field.
        for column, data in ((1, image), (2, image[0:10])):
            plt.subplot(n_rows, 2, 2 * row + column)
            plt.pcolor(data, cmap=color_mapping)
            for quarter in (1, 2, 3):
                plt.axvline(x=quarter * res_x / 4, color='black', linestyle='--')
    plt.show()
# @title optimization progress
# Run `sample_size` independent gradient-based optimisations, each starting
# from a random latent design, and record per (sample, epoch):
#   data_res[s, e]                 -> [continuous FoM, binarisation level, binarised FoM]
#   data_ucell_input[s, e]         -> latent design before the optimiser step
#   data_ucell_output[s, e]        -> latent design after the optimiser step
#   data_ucell_output_binary[s, e] -> binarised index cell evaluated in that epoch
# NOTE(review): the original paste lost its indentation; the loop nesting below
# is reconstructed from the data flow.
sample_size = 20
epoch_size = 500
data_res = np.zeros((sample_size, epoch_size, 3))
data_ucell_input = np.zeros((sample_size, epoch_size, N_L, 1, N_C))
data_ucell_output = np.zeros((sample_size, epoch_size, N_L, 1, N_C))
# N_L + 1: the design layers plus the uniform n1 layer appended below them.
data_ucell_output_binary = np.zeros((sample_size, epoch_size, N_L + 1, 1, N_C))

# Settings that are identical for every sample (hoisted out of the loop).
wavelength = 699.  # just initialization; multi_wav_FoM overwrites mee.wavelength
res_x, res_y, res_z = N_C * 4, 1, 10
wavelength_list = [671., 641., 611., 571., 541., 511., 471., 441., 411.]

for sampling in tqdm(range(sample_size)):
    ucell_1d_latent = torch.rand(N_L, 1, N_C)
    ucell_1d_latent.requires_grad = True
    # Latent -> index in (n1, n2) via a sigmoid, plus one uniform n1 layer at the bottom.
    ucell_1d_m = torch.cat((torch.sigmoid(ucell_1d_latent) * (n2 - n1) + n1, n1 * torch.ones(1, 1, N_C)))
    mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II,
                         theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength,
                         period=period, ucell=ucell_1d_m, thickness=thickness,
                         type_complex=type_complex, device=device, fft_type=0, improve_dft=True)
    opt = torch.optim.SGD([ucell_1d_latent], lr=0.9)

    for epoch in tqdm(range(epoch_size), leave=False):
        # Maximise the wavelength-averaged FoM of the continuous design.
        loss = -multi_wav_FoM(wavelength_list, mee, res_x, res_y, res_z) / len(wavelength_list)

        # Binarisation level: minus the mean |(u - mid) / (n2 - mid)|, which is -1
        # for a fully binary design. It follows the binarization function of
        # "Constraining Continuous Topology Optimizations to Discrete Solutions
        # for Photonic Applications",
        # https://pubs.acs.org/doi/10.1021/acsphotonics.2c00862
        # NOTE(review): this value is only logged and is never added to `loss`.
        # Binarisation is driven by the sigmoid-sharpening schedule at the end of
        # each epoch.
        with torch.no_grad():
            loss_constraint = -torch.sum(
                torch.abs((mee.ucell - (n2 + n1) * 0.5) / (n2 - (n2 + n1) * 0.5))
            ) / mee.ucell.nelement()

        ucell_input = ucell_1d_latent.detach().numpy().copy()
        loss.backward()
        opt.step()
        opt.zero_grad()
        ucell_output = ucell_1d_latent.detach().numpy().copy()

        # Score the binarised version of the design that produced `loss`.
        mee.ucell = binarization(n1, n2, mee.ucell)
        ucell_output_binary = mee.ucell.detach().cpu().numpy().copy()
        with torch.no_grad():  # evaluation only; no autograd graph needed
            Binary_FoM = multi_wav_FoM(wavelength_list, mee, res_x, res_y, res_z).item() / len(wavelength_list)

        data_res[sampling, epoch] = [-loss.item(), -loss_constraint.item(), Binary_FoM]
        data_ucell_input[sampling, epoch] = ucell_input
        data_ucell_output[sampling, epoch] = ucell_output
        data_ucell_output_binary[sampling, epoch] = ucell_output_binary

        # Rebuild the continuous design from the updated latent with a steeper
        # sigmoid (slope 1 + 0.02 * epoch) so it drifts towards a binary design.
        mee.ucell = torch.cat((torch.sigmoid(ucell_1d_latent * (1 + epoch * 0.02)) * (n2 - n1) + n1,
                               n1 * torch.ones(1, 1, N_C)))

    # Checkpoint after every sample so a crash does not lose completed samples.
    np.save('data_res.npy', data_res)
    np.save('data_ucell_input.npy', data_ucell_input)
    np.save('data_ucell_output.npy', data_ucell_output)
    np.save('data_ucell_output_binary.npy', data_ucell_output_binary)

# Binarise the last sample's final design, then evaluate and plot it with a fresh solver.
ucell_1d_m = binarization(n1, n2, mee.ucell)
mee = meent.call_mee(backend=backend, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II,
                     theta=theta, phi=phi, fourier_order=fourier_order, wavelength=wavelength,
                     period=period, ucell=ucell_1d_m, thickness=thickness,
                     type_complex=type_complex, device=device, fft_type=0, improve_dft=True)
print(multi_wav_FoM(wavelength_list, mee, res_x, res_y, res_z) / len(wavelength_list))
plot_intensity(mee, res_x, res_y, res_z)
Beta Was this translation helpful? Give feedback.
2 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Hi everyone,
First, I would like to thank the people developing this tool; it is truly amazing! While reading the paper I found an example of a colour router, which is what I am interested in. Does anyone have a simple example of a colour router coded in meent, to help speed up my learning curve? I have been working with Lumerical so far, but I wanted to try a differentiable RCWA solver, preferably one using PyTorch.
I would be happy to share what I develop; my aim is to optimize a 3D colour router.
thanks!
Beta Was this translation helpful? Give feedback.
All reactions