Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor filter code #312

Open
ggggggggg opened this issue Nov 15, 2024 · 2 comments
Open

refactor filter code #312

ggggggggg opened this issue Nov 15, 2024 · 2 comments

Comments

@ggggggggg
Copy link
Collaborator

ggggggggg commented Nov 15, 2024

The filter code is written in poor style in several ways:

  1. it relies on a lot of internal mutation, making it hard to follow
  2. it keeps dictionaries of filter values and variances for different flavors of filters that differ only in what they are orthogonal to
  3. because of 2 the code is hard to follow, and features a lot of repetition

Here is an example (incomplete) refactoring that avoids the use of mutation, limits code repetition, and generalizes over the differences between filters (noconst, baseline, etc.). Moving the code in this direction is desirable. In this version, all the core math lives in just a few lines in compute_filter and compute_fourier_filter.

from dataclasses import dataclass, field

import numpy as np
from scipy.linalg import solve_toeplitz, toeplitz


@dataclass(frozen=True, eq=False)
class Filter:
    """A computed filter with metadata about its creation.

    ``eq=False`` is deliberate: ``values`` is a numpy array, so the dataclass-generated
    ``__eq__`` would raise "truth value of an array is ambiguous" when comparing copies,
    and the generated ``__hash__`` (frozen + eq) would raise ``TypeError: unhashable type``.
    Filters therefore compare and hash by identity.
    """

    # Filter coefficients.
    values: np.ndarray
    # Scalar stored by the producing factory; used as the denominator term in ``vdv``.
    variance: float
    # The FilterFactory that produced this filter (only ``avg_signal`` is read here).
    factory: "FilterFactory"
    # The callable used to build ``values`` (a modification function or a bound method).
    modification_fn: callable

    def __repr__(self):
        """Summarize the filter, showing array fields by shape rather than contents."""
        # Not every callable has __name__ (e.g. functools.partial); fall back to its repr.
        fn_name = getattr(self.modification_fn, "__name__", repr(self.modification_fn))
        return (
            f"Filter(values=array(shape={self.values.shape}), variance={self.variance}, "
            f"factory=FilterFactory(avg_signal=array(shape={self.factory.avg_signal.shape})), "
            f"modification_fn={fn_name})"
        )

    @property
    def vdv(self):
        """Return ``norm(values) / sqrt(variance)``, or ``inf`` when ``variance`` is 0.

        Described by the original author as a signal-to-noise ratio of the filter.
        """
        signal = np.linalg.norm(self.values)
        noise = self.variance
        # NOTE(review): a negative variance yields nan (with a RuntimeWarning) here;
        # confirm whether the factories can produce one.
        return signal / np.sqrt(noise) if noise != 0 else np.inf


@dataclass(frozen=True, eq=False)
class FilterFactory:
    """Factory to compute filters from an average signal and a noise model.

    Instances are immutable; ``shorten`` returns a new factory instead of mutating this one.
    ``eq=False`` because the fields hold numpy arrays: the generated ``__eq__``/``__hash__``
    would raise, so factories compare and hash by identity.
    """

    avg_signal: np.ndarray
    # Number of leading samples of ``avg_signal`` that precede the trigger.
    n_pretrigger: int
    # Sample period in seconds (per its name); not read by the methods below.
    sample_time_s: float
    # Noise autocorrelation indexed by lag (element k is used as the k-th off-diagonal of
    # the Toeplitz matrix). Optional: when None, compute_filter assumes white noise.
    noise_autocorr: np.ndarray = None
    # Noise power spectral density; required by compute_fourier_filter.
    noise_psd: np.ndarray = None
    # Not read by any method here; carried over unchanged by ``shorten``.
    whitener: callable = None

    def compute_filter(self, modification_fn):
        """Compute a time-domain filter using the provided modification function.

        Solves ``R x = s`` and ``R x = 1``, where R is the symmetric Toeplitz matrix built from
        the first ``len(avg_signal)`` lags of the noise autocorrelation and s is the normalized
        average signal. ``modification_fn(Rinv_sig, Rinv_1)`` then combines the two solutions
        into the filter values.

        Args:
            modification_fn: callable ``(Rinv_sig, Rinv_1) -> filter_values``, e.g.
                ``FilterFactory.baseline_modification``.

        Returns:
            A ``Filter`` whose ``variance`` is ``dot(filter_values, Rinv_sig)``.

        Raises:
            ValueError: if ``noise_autocorr`` has fewer lags than ``avg_signal`` has samples,
                or if ``normalized_signal`` rejects the signal.
        """
        # Normalize first so invalid signals (empty, flat, bad n_pretrigger) fail early.
        signal = self.normalized_signal(self.avg_signal, self.n_pretrigger)
        n = len(signal)
        if self.noise_autocorr is None:
            # White-noise default: a unit delta at lag 0 makes R the identity matrix.
            # (An all-ones autocorrelation would make every entry of R equal 1: singular.)
            noise_corr = np.zeros(n)
            noise_corr[0] = 1.0
        else:
            if len(self.noise_autocorr) < n:
                raise ValueError(
                    f"noise_autocorr has {len(self.noise_autocorr)} lags but avg_signal "
                    f"has {n} samples; at least {n} lags are required."
                )
            noise_corr = self.noise_autocorr[:n]
        solve = ToeplitzSolver(noise_corr, symmetric=True)
        Rinv_sig = solve(signal)
        Rinv_1 = solve(np.ones(n))

        filter_values = modification_fn(Rinv_sig, Rinv_1)
        # NOTE(review): stored as "variance" but computed as dot(filter, R^-1 s);
        # confirm this is the intended definition.
        variance = np.dot(filter_values, Rinv_sig)

        return Filter(
            values=filter_values,
            variance=variance,
            factory=self,
            modification_fn=modification_fn,
        )

    def compute_fourier_filter(self):
        """Compute a filter in the frequency domain as ``ifft(fft(s) / (psd + 1e-6)).real``.

        s is the normalized average signal. The result is meant to be applied by dot product,
        like ``compute_filter``'s output: with a flat PSD it reduces to (a scaled copy of) s.

        Raises:
            ValueError: if ``noise_psd`` is missing or shorter than ``avg_signal``.
        """
        if self.noise_psd is None:
            raise ValueError("noise_psd must be provided for compute_fourier_filter.")

        normalized_signal = self.normalized_signal(self.avg_signal, self.n_pretrigger)
        n = len(normalized_signal)
        if len(self.noise_psd) < n:
            raise ValueError(
                f"noise_psd has {len(self.noise_psd)} bins but avg_signal has {n} samples."
            )

        signal_fft = np.fft.fft(normalized_signal)
        # The PSD is already a frequency-domain quantity, so it is used directly as the
        # per-bin noise weight (an FFT of it would not be a noise spectrum).
        # NOTE(review): assumes noise_psd[:n] is two-sided and ordered like the bins of
        # np.fft.fft (np.fft.fftfreq(n)); confirm against how the PSD is produced.
        noise_power = np.asarray(self.noise_psd[:n], dtype=float)

        # S(f)/N(f), not conj(S(f))/N(f): with dot-product application the conjugate form
        # would give a time-reversed filter. The 1e-6 guards against division by zero.
        fourier_filter = signal_fft / (noise_power + 1e-6)
        filter_values = np.fft.ifft(fourier_filter).real

        # NOTE(review): "variance" here is dot(filter, s), mirroring compute_filter's
        # convention; confirm the intended definition.
        variance = np.dot(filter_values, normalized_signal)

        return Filter(
            values=filter_values,
            variance=variance,
            factory=self,
            modification_fn=self.compute_fourier_filter,
        )

    def shorten(self, front_samples: int, back_samples: int):
        """Return a new FilterFactory whose signal drops leading and trailing samples.

        ``n_pretrigger`` is reduced by ``front_samples`` because it counts samples from the
        start of ``avg_signal``.

        Raises:
            ValueError: if either count is negative or no samples would remain.
        """
        n = len(self.avg_signal)
        if front_samples < 0 or back_samples < 0:
            raise ValueError("front_samples and back_samples must be non-negative.")
        if front_samples + back_samples >= n:
            raise ValueError(
                f"Cannot remove {front_samples} + {back_samples} samples from a "
                f"{n}-sample signal."
            )
        # Explicit end index: a slice like [front:-back] is empty when back_samples == 0.
        shortened_signal = self.avg_signal[front_samples:n - back_samples]

        return FilterFactory(
            avg_signal=shortened_signal,
            n_pretrigger=self.n_pretrigger - front_samples,
            sample_time_s=self.sample_time_s,
            # The autocorrelation is indexed by lag, not by sample position, so it is kept
            # as-is; compute_filter reads only lags 0..len(signal)-1. Dropping its leading
            # elements would discard lag 0.
            noise_autocorr=self.noise_autocorr,
            # NOTE(review): slicing frequency bins is not equivalent to the PSD of a shorter
            # record (bin spacing depends on record length); presumably it should be
            # recomputed. Kept as a bin slice for compatibility.
            noise_psd=self.truncate_noise(self.noise_psd, front_samples, back_samples),
            whitener=self.whitener,
        )

    def truncate_noise(self, noise, front_samples, back_samples):
        """Drop ``front_samples`` leading and ``back_samples`` trailing elements of ``noise``.

        Returns None when ``noise`` is None.
        """
        if noise is None:
            return None
        # Explicit end index so back_samples == 0 keeps the tail instead of emptying it.
        return noise[front_samples:len(noise) - back_samples]

    @staticmethod
    def normalized_signal(avg_signal, n_pretrigger):
        """Return ``avg_signal`` baseline-subtracted, scaled so its peak is +1, pretrigger zeroed.

        The baseline is the mean of the first ``n_pretrigger - 1`` samples; the last pretrigger
        sample is excluded (presumably because it may already contain pulse onset; confirm).
        The peak is whichever extreme (min or max) lies further from the baseline, so
        negative-going pulses are also normalized to +1.

        Raises:
            ValueError: if ``avg_signal`` is empty, ``n_pretrigger < 2`` (no baseline samples;
                0 would silently use ``[:-1]``), or the signal is flat (zero peak).
        """
        avg_signal = np.asarray(avg_signal)
        if avg_signal.size == 0:
            raise ValueError("avg_signal must not be empty.")
        if n_pretrigger < 2:
            raise ValueError(f"n_pretrigger must be at least 2, got {n_pretrigger}.")
        pre_avg = avg_signal[:n_pretrigger - 1].mean()
        low = avg_signal.min()
        high = avg_signal.max()
        is_negative = pre_avg - low > high - pre_avg
        peak_signal = low - pre_avg if is_negative else high - pre_avg
        if peak_signal == 0:
            raise ValueError("avg_signal is flat; cannot normalize a zero-height pulse.")
        normalized_signal = (avg_signal - pre_avg) / peak_signal
        normalized_signal[:n_pretrigger] = 0.0
        return normalized_signal

    @staticmethod
    def baseline_modification(Rinv_sig, Rinv_1):
        """Return ``Rinv_sig`` minus the multiple of ``Rinv_1`` that makes the result sum to 0.

        A zero-sum filter gives zero response to a constant offset in the data.
        """
        return Rinv_sig - (Rinv_sig.sum() / Rinv_1.sum()) * Rinv_1

    @staticmethod
    def noconst_modification(Rinv_sig, Rinv_1):
        """Return ``Rinv_1.sum() * Rinv_sig - Rinv_sig.sum() * Rinv_1``.

        Algebraically this is ``Rinv_1.sum()`` times ``baseline_modification``'s result, so it
        also sums to zero. NOTE(review): the two differ only by that scale factor; confirm
        this is the intended distinction.
        """
        return Rinv_1.sum() * Rinv_sig - Rinv_sig.sum() * Rinv_1


class ToeplitzSolver:
    """Solve linear systems whose matrix is the Toeplitz matrix generated by ``autocorr``.

    ``autocorr[k]`` fills the k-th sub- and super-diagonal, so ``autocorr`` is both the first
    column and (conjugated) the first row of the matrix.
    """

    def __init__(self, autocorr, symmetric=True):
        self.autocorr = autocorr
        # NOTE(review): `symmetric` is stored but not otherwise used; the matrix is always
        # generated from a single sequence. Kept for interface compatibility.
        self.symmetric = symmetric
        # Dense matrix kept as a public attribute for callers that inspect it; __call__
        # solves directly from `autocorr` and does not factor this matrix.
        self.matrix = toeplitz(autocorr)

    def __call__(self, vector):
        """Solve ``matrix @ x = vector`` for x.

        Uses Levinson recursion (``scipy.linalg.solve_toeplitz``), which is O(n^2) instead of
        the O(n^3) dense solve. It requires nonsingular leading principal minors, which holds
        for a positive-definite autocorrelation.
        """
        return solve_toeplitz(self.autocorr, vector)

With example usage:

if __name__ == "__main__":
    # Example data
    avg_signal = np.array([0.1, 0.2, 0.5, 1.0, 0.8, 0.3])
    n_pretrigger = 2
    sample_time_s = 0.01
    # Optional, but when given it needs at least len(avg_signal) lags: compute_filter builds
    # an n x n Toeplitz matrix from lags 0..n-1. These values are 0.5**k, whose Toeplitz
    # matrix is positive definite.
    noise_autocorr = np.array([1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125])
    noise_psd = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])  # Required for Fourier filter

    # Create a factory
    factory = FilterFactory(
        avg_signal, n_pretrigger, sample_time_s, noise_autocorr=noise_autocorr, noise_psd=noise_psd
    )

    # Compute filters
    baseline_filter = factory.compute_filter(FilterFactory.baseline_modification)
    noconst_filter = factory.compute_filter(FilterFactory.noconst_modification)

    # Print signal-to-noise ratio (vdv)
    print(f"Baseline filter vdv: {baseline_filter.vdv}")
    print(f"No constant filter vdv: {noconst_filter.vdv}")
@joefowler
Copy link
Member

Good idea. Are you suggesting that you'll do this? I am happy to, and I was the one just digging into the filter designs this week.

@ggggggggg
Copy link
Collaborator Author

I am not suggesting that I'll do this. I just thought I'd try having ChatGPT write up a refactoring as a quick test, and I spent 20–30 minutes having it revise the code. After that, it was clear it wasn't good enough to drop right in, but it was good enough to make a good issue.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants