Skip to content

Commit

Permalink
quality code with sonarqub, update notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
luckyjim committed Feb 21, 2024
1 parent ab4f131 commit c77c32a
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 125 deletions.
233 changes: 169 additions & 64 deletions examples/basis/class_Handling3dTraces.ipynb

Large diffs are not rendered by default.

20 changes: 12 additions & 8 deletions grand/basis/du_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@


def closest_node(node, nodes): # pragma: no cover
"""Simple computation of distance mouse DU
:param node:
:param nodes:
"""
nodes = np.asarray(nodes)
dist_2 = np.sum((nodes - node) ** 2, axis=1)
return np.argmin(dist_2)
Expand Down Expand Up @@ -54,7 +59,7 @@ def init_pos_id(self, du_pos, du_id=None):
self.area_km2 = -1
self.idx2idt = du_id
assert isinstance(self.du_pos, np.ndarray)
assert isinstance(self.idx2idt, list) or isinstance(self.idx2idt, np.ndarray)
assert isinstance(self.idx2idt, (list, np.ndarray))
assert du_pos.shape[0] == len(du_id)
assert du_pos.shape[1] == 3

Expand Down Expand Up @@ -230,10 +235,9 @@ def plot_footprint_4d(
def subplot(plt_axis, a_values, cpnt="", scale="log"):
ax1 = plt_axis
size_circle = 80
cur_idx_plot = -1

ax1.set_title(cpnt)
if type(scale) is str:
if isinstance(scale, str):
my_cmaps = "Blues"
vmin = np.nanmin(a_values)
vmax = np.nanmax(a_values)
Expand Down Expand Up @@ -262,10 +266,10 @@ def subplot(plt_axis, a_values, cpnt="", scale="log"):
# plt.xlabel("[m]")
return scm

fig, ax = plt.subplots(2, 2)
fig, ax2 = plt.subplots(2, 2)

t_max, _ = o_tr.get_tmax_vmax()
ret_scat = subplot(ax[0, 0], t_max, cpnt="Time of max value", scale="lin")
ret_scat = subplot(ax2[0, 0], t_max, cpnt="Time of max value", scale="lin")
fig.colorbar(ret_scat, label="ns")
# same scale for
if same_scale:
Expand All @@ -274,11 +278,11 @@ def subplot(plt_axis, a_values, cpnt="", scale="log"):
norm_user = colors.Normalize(vmin=vmin, vmax=vmax)
else:
norm_user = "lin"
ret_scat = subplot(ax[1, 0], v_plot[:, 0], f"{title} {o_tr.axis_name[0]}", norm_user)
ret_scat = subplot(ax2[1, 0], v_plot[:, 0], f"{title} {o_tr.axis_name[0]}", norm_user)
fig.colorbar(ret_scat, label=unit)
ret_scat = subplot(ax[0, 1], v_plot[:, 1], f"{title} {o_tr.axis_name[1]}", norm_user)
ret_scat = subplot(ax2[0, 1], v_plot[:, 1], f"{title} {o_tr.axis_name[1]}", norm_user)
fig.colorbar(ret_scat, label=unit)
ret_scat = subplot(ax[1, 1], v_plot[:, 2], f"{title} {o_tr.axis_name[2]}", norm_user)
ret_scat = subplot(ax2[1, 1], v_plot[:, 2], f"{title} {o_tr.axis_name[2]}", norm_user)
fig.colorbar(ret_scat, label=unit)

def plot_footprint_time(self, a_time, a3_values, title=""): # pragma: no cover
Expand Down
33 changes: 16 additions & 17 deletions grand/basis/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,22 @@ def find_max_with_parabola_interp(x_trace, y_trace, idx_max, factor_hill=0.8):
logger.debug(f"{x_trace[b_idx]}\t{x_trace[e_idx]}")
if (e_idx - b_idx) <= 2:
return find_max_with_parabola_interp_3pt(x_trace, y_trace, idx_max)
else:
logger.debug(f"Parabola interp: mode hill")
# mode hill
y_hill = y_trace[b_idx : e_idx + 1] - y_trace[b_idx]
x_hill = x_trace[b_idx : e_idx + 1] - x_trace[b_idx]
mat = np.empty((x_hill.shape[0], 3), dtype=np.float32)
mat[:, 2] = 1
mat[:, 1] = x_hill
mat[:, 0] = x_hill * x_hill
sol = np.linalg.lstsq(mat, y_hill, rcond=None)[0]
if -1e-5 < sol[0] and sol[0] < 1e-5:
# very flat case
return x_trace[idx_max], y_trace[idx_max]
x_m = -sol[1] / (2 * sol[0])
x_max = x_trace[b_idx] + x_m
y_max = y_trace[b_idx] + x_m * sol[1] / 2 + sol[2]
return x_max, y_max
logger.debug(f"Parabola interp: mode hill")
# mode hill
y_hill = y_trace[b_idx : e_idx + 1] - y_trace[b_idx]
x_hill = x_trace[b_idx : e_idx + 1] - x_trace[b_idx]
mat = np.empty((x_hill.shape[0], 3), dtype=np.float32)
mat[:, 2] = 1
mat[:, 1] = x_hill
mat[:, 0] = x_hill * x_hill
sol = np.linalg.lstsq(mat, y_hill, rcond=None)[0]
if -1e-5 < sol[0] and sol[0] < 1e-5:
# very flat case
return x_trace[idx_max], y_trace[idx_max]
x_m = -sol[1] / (2 * sol[0])
x_max = x_trace[b_idx] + x_m
y_max = y_trace[b_idx] + x_m * sol[1] / 2 + sol[2]
return x_max, y_max


def get_filter(time, trace, fr_min, fr_max):
Expand Down
50 changes: 17 additions & 33 deletions grand/basis/traces_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""
from logging import getLogger
import copy
from functools import lru_cache

import numpy as np
import scipy.signal as ssig
Expand All @@ -19,7 +18,7 @@


def get_psd(trace, f_samp_mhz, nperseg=0):
"""Estimate power spectrum density by Welch method
"""Reference estimation of power spectrum density by Welch's method
:param trace: floatX(nb_sample,)
:param f_samp_mhz: frequency sampling
Expand Down Expand Up @@ -93,6 +92,9 @@ def __init__(self, name="NotDefined"):
self._color = ["k", "y", "b"]
self.axis_name = self._d_axis_val["idx"]
self.network = None
# computing by user and store in object
self.t_max = None
self.v_max = None

### INTERNAL

Expand Down Expand Up @@ -121,7 +123,7 @@ def init_traces(self, traces, du_id=None, t_start_ns=None, f_samp_mhz=2000):
self.idx2idt = du_id
self.idt2idx = {idt: idx for idx, idt in enumerate(self.idx2idt)}
self.t_start_ns = t_start_ns
if isinstance(f_samp_mhz, int) or isinstance(f_samp_mhz, float):
if isinstance(f_samp_mhz, (int, float)):
self.f_samp_mhz = np.ones(len(du_id)) * f_samp_mhz
else:
self.f_samp_mhz = f_samp_mhz
Expand Down Expand Up @@ -182,11 +184,8 @@ def apply_bandpass(self, fr_min, fr_max, causal=True):
else:
filtered = ssig.filtfilt(coeff_b, coeff_a, self.traces)
self.traces = filtered.real
try:
delattr(self, "t_max")
delattr(self, "v_max")
except:
pass
self.t_max = None
self.v_max = None

def _define_t_samples(self):
"""
Expand Down Expand Up @@ -233,11 +232,8 @@ def keep_only_trace_with_index(self, l_idx):
if self.network:
self.network = copy.deepcopy(self.network)
self.network.keep_only_du_with_index(l_idx)
try:
delattr(self, "t_max")
delattr(self, "v_max")
except:
pass
self.t_max = None
self.v_max = None

def reduce_nb_trace(self, new_nb_du):
"""reduces the number of traces to the first <new_nb_du>
Expand Down Expand Up @@ -309,11 +305,8 @@ def copy(self, new_traces=None, deepcopy=True):
elif new_traces == 0:
new_traces = np.zeros_like(self.traces)
my_copy.traces = new_traces
try:
delattr(self, "t_max")
delattr(self, "v_max")
except:
pass
self.t_max = None
self.v_max = None
return my_copy

def get_delta_t_ns(self):
Expand Down Expand Up @@ -374,7 +367,7 @@ def get_tmax_vmax(self, hilbert=True, interpol="auto"):
self.t_max = tmax
self.v_max = vmax
return tmax, vmax
if not interpol in ["parab", "auto"]:
if interpol not in ["parab", "auto"]:
raise
t_max = np.empty_like(tmax)
v_max = np.empty_like(tmax)
Expand Down Expand Up @@ -479,7 +472,7 @@ def plot_trace_idx(self, idx, to_draw="012"): # pragma: no cover
self._color[idx_axis],
label=axis + r", $\sigma_{noise}\approx$" + f"{m_sig:.1e}",
)
if hasattr(self, "t_max"):
if self.t_max is not None:
snr = self.v_max[idx] / a_sigma.max()
plt.plot(
self.t_max[idx],
Expand Down Expand Up @@ -516,17 +509,10 @@ def plot_psd_trace_idx(self, idx, to_draw="012"): # pragma: no cover
plt.figure()
for idx_axis, axis in enumerate(self.axis_name):
if str(idx_axis) in to_draw:
if True:
freq, pxx_den = ssig.welch(
self.traces[idx, idx_axis],
self.f_samp_mhz[idx] * 1e6,
nperseg=self.nperseg,
window="bartlett",
scaling="density",
)
else:
freq, pxx_den = get_psd(self.traces[idx, idx_axis], self.f_samp_mhz)
plt.semilogy(freq[2:] * 1e-6, pxx_den[2:], self._color[idx_axis], label=axis)
freq, pxx_den = get_psd(
self.traces[idx, idx_axis], self.f_samp_mhz[idx], self.nperseg
)
plt.semilogy(freq[2:], pxx_den[2:], self._color[idx_axis], label=axis)
# plt.plot(freq[2:] * 1e-6, pxx_den[2:], self._color[idx_axis], label=axis)
m_title = f"Power spectrum density of {self.type_trace}, DU {self.idx2idt[idx]} (idx={idx})"
m_title += f"\nPeriodogram has {self.nperseg} samples, delta freq {freq[1]*1e-6:.2f}MHz"
Expand All @@ -536,8 +522,6 @@ def plot_psd_trace_idx(self, idx, to_draw="012"): # pragma: no cover
plt.xlim([0, 400])
plt.grid()
plt.legend()
self.welch_freq = freq
self.welch_pxx_den = pxx_den

def plot_psd_trace_du(self, du_id, to_draw="012"): # pragma: no cover
"""Draw power spectrum for 3 traces associated to DU idx2idt
Expand Down
6 changes: 3 additions & 3 deletions grand/dataio/root_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import os.path
from logging import getLogger
from typing import Optional
import glob

import numpy as np
import ROOT
Expand Down Expand Up @@ -71,7 +70,7 @@ def __init__(self, tt_event, f_name):
self.t_bin_size: Optional[float] = 0.5
self.du_id: Optional[None, np.ndarray, list] = None
self.du_count: Optional[None, int] = None
self.du_xyz: Optional[None, np.ndarray] = None
self.du_xyz: Optional[None, np.ndarray] = None
self.tag: Optional[str] = ""
#
self.f_name = f_name
Expand Down Expand Up @@ -204,7 +203,8 @@ def get_file_event(f_name):
if "tvoltage" in trees_list: # File with voltage info as input
return FileVoltage(f_name)
logger.error(
f"File {f_name} doesn't content TTree teventefield or teventvoltage. It contains {trees_list}."
f"File {f_name} doesn't content TTree teventefield or teventvoltage."
" It contains {trees_list}."
)
raise AssertionError

Expand Down

0 comments on commit c77c32a

Please sign in to comment.