Skip to content

Commit

Permalink
-created a Thread for the fitting process
Browse files Browse the repository at this point in the history
- thereby, the iter_cb parameter of the lmfit model class is now accessible to be used to interrupt the fitting procedure
- finally, this solves #5
  • Loading branch information
Julian-Hochhaus committed Jun 27, 2023
1 parent 369940b commit a552ec5
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 23 deletions.
25 changes: 25 additions & 0 deletions Python/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,31 @@ def accept(self):
remove_idx, remove_text = self.getRemoveOption()
self.removeOptionChanged.emit(remove_idx,remove_text)
super().accept()
class FitThread(QtCore.QThread):
fitting_finished = QtCore.pyqtSignal(object)

def __init__(self, model=None, data=None, params=None, x=None,weights=None, y=None):
super().__init__()
self.fit_interrupted = False
self.model = model
self.data = data
self.params = params
self.x = x
self.weights = weights
self.y= y
self.result=None

def run(self):
self.fit_interrupted = False
self.result = self.model.fit(self.data, params=self.params, x=self.x,weights=self.weights, iter_cb=self.per_iteration, y=self.y)
self.fitting_finished.emit(self.result)

def per_iteration(self, pars, iteration, resid, *args, **kws):
if self.fit_interrupted:
return True
#print(" ITER ", iteration, [f"{p.name} = {p.value:.5f}" for p in pars.values()])
def interrupt_fit(self):
self.fit_interrupted = True

class RemoveAndEditTableWidget(QtWidgets.QTableWidget):
headerTextChanged = QtCore.pyqtSignal(int, str)
Expand Down
79 changes: 56 additions & 23 deletions Python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(self):
self.initUI()

def initUI(self):
self.fit_thread=FitThread(self)
self.version = 'LG4X: LMFit GUI for XPS curve fitting v2.0.4'
self.floating = '.4f'
self.setGeometry(700, 500, 1600, 900)
Expand Down Expand Up @@ -290,7 +291,7 @@ def initUI(self):
btn_undoFit.clicked.connect(self.one_step_back_in_params_history)
fitbuttons_layout.addWidget(btn_undoFit)
# Interrupt fit Button
btn_interrupt = QtWidgets.QPushButton('Interrupt fitting (not implemented)', self)
btn_interrupt = QtWidgets.QPushButton('Interrupt fitting', self)
btn_interrupt.resize(btn_interrupt.sizeHint())
btn_interrupt.clicked.connect(self.interrupt_fit)
fitbuttons_layout.addWidget(btn_interrupt)
Expand Down Expand Up @@ -2037,10 +2038,9 @@ def fit(self):
self.df = np.random.random_sample((points, 2)) + 0.01
self.df[:, 0] = np.linspace(x1, x2, points)
self.ana('sim')

def interrupt_fit(self):
print("does nothing yet")

if self.fit_thread:
self.fit_thread.interrupt_fit()
def one_step_back_in_params_history(self):
"""
Is called if button undo Fit is prest.
Expand Down Expand Up @@ -3149,32 +3149,60 @@ def ana(self, mode):
out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(self.rows_lightened)), y=y)
else:
out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(raw_y) * np.sqrt(self.rows_lightened)), y=y)
self.fitting_finished(out, strmode=strmode, mode=mode, x=x,y=y, zeros_in_data=zeros_in_data, raw_x=raw_x, raw_y=raw_y, pars=pars)
else:
try_me_out = self.history_manager(pars)
if try_me_out is not None:
pars, pre = try_me_out
self.pre = pre
self.setPreset(pre[0], pre[1], pre[2], pre[3])
if zeros_in_data:
out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(self.rows_lightened)), y=raw_y)
self.fit_thread=FitThread(model=mod,data=y, params=pars, x=x, weights=1 / (np.sqrt(self.rows_lightened)),y=raw_y)
self.fit_thread.fitting_finished.connect(
lambda out: self.fitting_finished(out, x=x,y=y, strmode=strmode, mode=mode,
zeros_in_data=zeros_in_data, raw_x=raw_x, raw_y=raw_y,
pars=pars))

self.fit_thread.start()
#out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(self.rows_lightened)), y=raw_y)

else:
out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(raw_y) * np.sqrt(self.rows_lightened)), y=raw_y)
self.fit_thread = FitThread(model=mod, data=y, params=pars, x=x,
weights=1 /(np.sqrt(raw_y) * np.sqrt(self.rows_lightened)),
y=raw_y)
self.fit_thread.fitting_finished.connect(lambda out: self.fitting_finished(out, x=x,y=y, strmode=strmode, mode=mode, zeros_in_data=zeros_in_data, raw_x=raw_x, raw_y=raw_y, pars=pars))

self.fit_thread.start()
#out = mod.fit(y, pars, x=x, weights=1 / (np.sqrt(raw_y) * np.sqrt(self.rows_lightened)), y=raw_y)

except Exception as e:
return self.raise_error(window_title="Error: NaN in Model/data!.",
error_message=e.args[0])

def get_attr(self,obj, attr):
"""Format an attribute of an object for printing."""
val = getattr(obj, attr, None)
if val is None:
return 'unknown'
if isinstance(val, int):
return f'{val}'
if isinstance(val, float):
return str(format(val, self.floating))
return repr(val)

def fitting_finished(self, out, x, y, strmode, mode, zeros_in_data, pars, raw_x,raw_y):
comps = out.eval_components(x=x)
# fit results to be checked
for key in out.params:
print(key, "=", out.params[key].value)

# fit results print

results = strmode + ' done: ' + out.method + ', # data: ' + str(out.ndata) + ', # func evals: ' + str(
out.nfev) + ', # varys: ' + str(out.nvarys) + ', r chi-sqr: ' + str(
format(out.redchi, self.floating)) + ', Akaike info crit: ' + str(
format(out.aic, self.floating)) + ', Last run finished: ' + QTime.currentTime().toString()
if self.get_attr(out,'aic') == 'unknown' or self.get_attr(out,'bic') == 'unknown' or self.get_attr(out,'redchi') == 'unknown' or self.get_attr(out,'chisqr') == 'unknown':
results = 'Fitting interrupted: ' + out.method + ', # data: ' + str(out.ndata) + ', # func evals: ' + str(
out.nfev) + ', # varys: ' + str(out.nvarys) + ', r chi-sqr: ' + self.get_attr(out, 'redchi') + ', Akaike info crit: ' + self.get_attr(out,'aic') + ', Last run finished: ' + QTime.currentTime().toString()
else:
results = strmode + ' done: ' + out.method + ', # data: ' + str(out.ndata) + ', # func evals: ' + str(
out.nfev) + ', # varys: ' + str(out.nvarys) + ', r chi-sqr: ' + self.get_attr(out, 'redchi') + ', Akaike info crit: ' + self.get_attr(out,'aic') + ', Last run finished: ' + QTime.currentTime().toString()
self.statusBar().showMessage(results)

# component results into table
Expand Down Expand Up @@ -3217,16 +3245,17 @@ def ana(self, mode):
self.stats_tab.setItem(4, 0, item)
item = QtWidgets.QTableWidgetItem(str(out.nfree))
self.stats_tab.setItem(5, 0, item)
item = QtWidgets.QTableWidgetItem(str(format(out.chisqr, self.floating)))
item = QtWidgets.QTableWidgetItem(self.get_attr(out,'chisqr'))
self.stats_tab.setItem(6, 0, item)
if zeros_in_data:
item = QtWidgets.QTableWidgetItem(str(format(out.redchi, self.floating))+' not weigthed by sqrt(data)')
item = QtWidgets.QTableWidgetItem(
str(format(out.redchi, self.floating)) + ' not weigthed by sqrt(data)')
else:
item = QtWidgets.QTableWidgetItem(str(format(out.redchi, self.floating)))
item = QtWidgets.QTableWidgetItem(self.get_attr(out,'redchi'))
self.stats_tab.setItem(7, 0, item)
item = QtWidgets.QTableWidgetItem(str(format(out.aic, self.floating)))
item = QtWidgets.QTableWidgetItem(self.get_attr(out,'aic'))
self.stats_tab.setItem(8, 0, item)
item = QtWidgets.QTableWidgetItem(str(format(out.bic, self.floating)))
item = QtWidgets.QTableWidgetItem(self.get_attr(out,'bic'))
self.stats_tab.setItem(9, 0, item)
self.stats_tab.resizeColumnsToContents()
self.stats_tab.resizeRowsToContents()
Expand All @@ -3239,7 +3268,7 @@ def ana(self, mode):
if mode == "sim":
self.ar.set_title(r"Simulation mode", fontsize=11)
if mode == 'eva':
plottitle=self.plottitle.text()
plottitle = self.plottitle.text()
if len(plottitle) == 0:
plottitle = self.comboBox_file.currentText().split('/')[-1]
if plottitle != '':
Expand All @@ -3250,7 +3279,8 @@ def ana(self, mode):
strind = self.fitp1.cellWidget(0, 2 * index_pk + 1).currentText()
strind = strind.split(":", 1)[0]
self.ax.fill_between(x, comps[strind + str(index_pk + 1) + '_'] + sum_background + self.static_bg,
sum_background + self.static_bg, label=self.fitp1.horizontalHeaderItem(2*index_pk+1).text())
sum_background + self.static_bg,
label=self.fitp1.horizontalHeaderItem(2 * index_pk + 1).text())
self.ax.plot(x, comps[strind + str(index_pk + 1) + '_'] + sum_background + self.static_bg)
if index_pk == len_idx_pk - 1:
self.ax.plot(x, + sum_background + self.static_bg, label='BG')
Expand All @@ -3273,8 +3303,9 @@ def ana(self, mode):
for index_pk in range(len_idx_pk):
strind = self.fitp1.cellWidget(0, 2 * index_pk + 1).currentText()
strind = strind.split(":", 1)[0]
self.ax.fill_between(x, comps[strind + str(index_pk + 1) + '_'] + self.static_bg + sum_background,
self.static_bg + sum_background, label=self.fitp1.horizontalHeaderItem(2*index_pk+1).text())
self.ax.fill_between(x, comps[strind + str(index_pk + 1) + '_'] + self.static_bg + sum_background,
self.static_bg + sum_background,
label=self.fitp1.horizontalHeaderItem(2 * index_pk + 1).text())
self.ax.plot(x, comps[strind + str(index_pk + 1) + '_'] + self.static_bg + sum_background)
if index_pk == len_idx_pk - 1:
self.ax.plot(x, + self.static_bg + sum_background, label="BG")
Expand Down Expand Up @@ -3304,16 +3335,17 @@ def ana(self, mode):
df_sum = pd.DataFrame(out.best_fit, columns=['sum_fit'])
df_b = pd.DataFrame(sum_background + self.static_bg, columns=['bg'])
if isinstance(self.static_bg, int):
df_b_static = pd.DataFrame([0]*len(sum_background), columns=['bg_static (not used)'])
df_b_static = pd.DataFrame([0] * len(sum_background), columns=['bg_static (not used)'])
else:
df_b_static = pd.DataFrame(self.static_bg, columns=['bg_static'])
self.result = pd.concat([df_raw_x, df_raw_y,df_corrected_x, df_y, df_pks, df_b, df_b_static, df_sum], axis=1)
self.result = pd.concat([df_raw_x, df_raw_y, df_corrected_x, df_y, df_pks, df_b, df_b_static, df_sum], axis=1)
df_bg_comps = pd.DataFrame.from_dict(self.bg_comps, orient='columns')
self.result = pd.concat([self.result, df_bg_comps], axis=1)
for index_pk in range(int(self.fitp1.columnCount() / 2)):
strind = self.fitp1.cellWidget(0, 2 * index_pk + 1).currentText()
strind = strind.split(":", 1)[0]
df_c = pd.DataFrame(comps[strind + str(index_pk + 1) + '_'], columns=[self.fitp1.horizontalHeaderItem(2*index_pk+1).text()])
df_c = pd.DataFrame(comps[strind + str(index_pk + 1) + '_'],
columns=[self.fitp1.horizontalHeaderItem(2 * index_pk + 1).text()])
self.result = pd.concat([self.result, df_c], axis=1)
print(out.fit_report())
lim_reached = False
Expand Down Expand Up @@ -3341,6 +3373,7 @@ def center(self):
self.move(qr.topLeft())

def closeEvent(self, event):
self.interrupt_fit()
event.accept()
sys.exit(0)

Expand Down

0 comments on commit a552ec5

Please sign in to comment.