The filter code is written in poor style in a variety of ways:

1. It relies on a lot of internal mutation, making it hard to follow.
2. It keeps dictionaries of filter values and variances for different flavors of filters that vary only by what they are orthogonal to.
3. Because of 2, the code is hard to follow and features a lot of repetition.
Here is an example (incomplete) refactoring that avoids the use of mutation, limits code repetition, and generalizes over the differences between filters (noconst, baseline, etc.). Moving the code in this direction is desirable. Now all the core math lives in just a few lines in `compute_filter` and `compute_fourier_filter`.
```python
from dataclasses import dataclass
from typing import Callable, Optional

import numpy as np
from scipy.linalg import toeplitz


@dataclass(frozen=True)
class Filter:
    """A computed filter with metadata about its creation."""

    values: np.ndarray
    variance: float
    factory: "FilterFactory"
    modification_fn: Callable

    def __repr__(self):
        return (
            f"Filter(values=array(shape={self.values.shape}), variance={self.variance}, "
            f"factory=FilterFactory(avg_signal=array(shape={self.factory.avg_signal.shape})), "
            f"modification_fn={self.modification_fn.__name__})"
        )

    @property
    def vdv(self):
        """Compute the signal-to-noise ratio (SNR) of the filter."""
        # Signal: use the norm of the filter values as the "signal"
        signal = np.linalg.norm(self.values)
        # Noise: the variance of the filter is used as the "noise"
        noise = self.variance
        # Signal-to-noise ratio (SNR)
        return signal / np.sqrt(noise) if noise != 0 else np.inf


@dataclass(frozen=True)
class FilterFactory:
    """Factory to compute filters."""

    avg_signal: np.ndarray
    n_pretrigger: int
    sample_time_s: float
    noise_autocorr: Optional[np.ndarray] = None  # Optional
    noise_psd: Optional[np.ndarray] = None  # Required for Fourier-based filter
    whitener: Optional[Callable] = None

    def compute_filter(self, modification_fn):
        """Compute a filter using the provided modification function."""
        n = len(self.avg_signal)
        if self.noise_autocorr is not None:
            noise_corr = self.noise_autocorr[:n]
        else:
            # Default to white noise (R = identity). An all-ones autocorrelation
            # would produce a singular Toeplitz matrix.
            noise_corr = np.zeros(n)
            noise_corr[0] = 1.0
        TS = ToeplitzSolver(noise_corr, symmetric=True)
        # normalized_signal is a staticmethod, so n_pretrigger must be passed explicitly
        Rinv_sig = TS(self.normalized_signal(self.avg_signal, self.n_pretrigger))
        Rinv_1 = TS(np.ones(n))
        # Use the modification function to create the filter values
        filter_values = modification_fn(Rinv_sig, Rinv_1)
        variance = np.dot(filter_values, Rinv_sig)
        return Filter(
            values=filter_values,
            variance=variance,
            factory=self,
            modification_fn=modification_fn,
        )

    def compute_fourier_filter(self):
        """Compute a filter using Fourier methods with noise_psd."""
        if self.noise_psd is None:
            raise ValueError("noise_psd must be provided for compute_fourier_filter.")
        n = len(self.avg_signal)
        normalized_signal = self.normalized_signal(self.avg_signal, self.n_pretrigger)
        # Compute the Fourier transform of the signal and the noise power spectral density
        signal_fft = np.fft.fft(normalized_signal)
        noise_psd_fft = np.fft.fft(self.noise_psd[:n])
        # Compute the filter in the frequency domain
        fourier_filter = np.conj(signal_fft) / (noise_psd_fft + 1e-6)  # Avoid division by zero
        # Transform back to the time domain
        filter_values = np.fft.ifft(fourier_filter).real
        # Compute the variance of the filter
        variance = np.dot(filter_values, normalized_signal)
        return Filter(
            values=filter_values,
            variance=variance,
            factory=self,
            modification_fn=self.compute_fourier_filter,
        )

    def shorten(self, front_samples: int, back_samples: int):
        """Return a new FilterFactory with truncated signal and compatible noise."""
        # Truncate the signal from the front and the back. An explicit end index
        # keeps back_samples=0 from producing an empty slice.
        end = len(self.avg_signal) - back_samples
        shortened_signal = self.avg_signal[front_samples:end]
        # NOTE: n_pretrigger is passed through unchanged, not reduced by front_samples.
        truncated_factory = FilterFactory(
            avg_signal=shortened_signal,
            n_pretrigger=self.n_pretrigger,
            sample_time_s=self.sample_time_s,
            noise_autocorr=self.truncate_noise(self.noise_autocorr, front_samples, back_samples),
            noise_psd=self.truncate_noise(self.noise_psd, front_samples, back_samples),
            whitener=self.whitener,
        )
        return truncated_factory

    def truncate_noise(self, noise, front_samples, back_samples):
        """Helper function to truncate noise_autocorr and noise_psd in a compatible way."""
        if noise is not None:
            return noise[front_samples:len(noise) - back_samples]
        return None

    @staticmethod
    def normalized_signal(avg_signal, n_pretrigger):
        """Static method to compute the normalized signal."""
        pre_avg = avg_signal[: n_pretrigger - 1].mean()
        a = avg_signal.min()
        b = avg_signal.max()
        is_negative = pre_avg - a > b - pre_avg
        peak_signal = a - pre_avg if is_negative else b - pre_avg
        normalized_signal = (avg_signal - pre_avg) / peak_signal
        normalized_signal[:n_pretrigger] = 0.0
        return normalized_signal

    @staticmethod
    def baseline_modification(Rinv_sig, Rinv_1):
        """Modification function for baseline subtraction."""
        return Rinv_sig - (Rinv_sig.sum() / Rinv_1.sum()) * Rinv_1

    @staticmethod
    def noconst_modification(Rinv_sig, Rinv_1):
        """Modification function for no baseline subtraction."""
        # NOTE: as written, this equals Rinv_1.sum() * baseline_modification(Rinv_sig, Rinv_1).
        return Rinv_1.sum() * Rinv_sig - Rinv_sig.sum() * Rinv_1


class ToeplitzSolver:
    """Helper to solve Toeplitz systems."""

    def __init__(self, autocorr, symmetric=True):
        self.autocorr = autocorr
        self.symmetric = symmetric
        self.matrix = toeplitz(autocorr)

    def __call__(self, vector):
        """Solve the system with the given vector."""
        return np.linalg.solve(self.matrix, vector)
```
I am not suggesting that I'll do this. I just thought I'd try having ChatGPT write up a refactoring as a quick test. I spent 20-30 minutes having it revise. Once I'd done that, it was clear the result wasn't good enough to drop right in, but it was good enough to make a good issue.
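To make the idea behind the refactoring concrete, here is a minimal usage sketch of the modification-function pattern. It is not the package's actual API: the helper names `solve_symmetric_toeplitz` and `compute_filter_values`, the toy pulse, and the AR(1)-style noise autocorrelation are all assumptions made for illustration. Only the `baseline_modification` formula is taken from the code above.

```python
import numpy as np


def solve_symmetric_toeplitz(autocorr, vector):
    """Solve R x = vector, where R[i, j] = autocorr[|i - j|] (hypothetical helper)."""
    n = len(vector)
    idx = np.arange(n)
    R = np.asarray(autocorr)[np.abs(idx[:, None] - idx[None, :])]
    return np.linalg.solve(R, vector)


def baseline_modification(Rinv_sig, Rinv_1):
    """Same formula as in the issue: remove the component along R^-1 * ones."""
    return Rinv_sig - (Rinv_sig.sum() / Rinv_1.sum()) * Rinv_1


def compute_filter_values(signal, autocorr, modification_fn):
    """Pure function: returns new arrays and leaves its inputs untouched."""
    Rinv_sig = solve_symmetric_toeplitz(autocorr, signal)
    Rinv_1 = solve_symmetric_toeplitz(autocorr, np.ones(len(signal)))
    return modification_fn(Rinv_sig, Rinv_1)


# Toy inputs (assumed, not real detector data): an exponential pulse that
# starts at sample 16, and an autocorrelation that decays as 0.5**lag.
n = 64
t = np.arange(n)
signal = np.where(t >= 16, np.exp(-(t - 16) / 10.0), 0.0)
autocorr = 0.5 ** t

values = compute_filter_values(signal, autocorr, baseline_modification)

# The filter's response to a constant offset is the sum of its values.
offset_response = values.sum()
print(f"response to a constant offset: {offset_response:.2e}")
```

Because the flavor of filter is chosen by passing a function, adding another flavor means writing one more small function instead of adding another dictionary entry and another copy of the solve step. For this formula the response to a constant offset should be zero up to floating-point rounding.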