Skip to content

Commit

Permalink
add fitting algorithm as a part of the pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
warisa-r committed Sep 16, 2024
1 parent 765941e commit 97be19a
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
5 changes: 3 additions & 2 deletions examples/obstacle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ ENV["JULIA_NUM_THREADS"] = "4"
flow_data = FlowData("examples/data/obstacle/funky_square.celltape")
point_detect_algo = ImageProcessingShockDetectionAlgo(0.2, :prewitt)
dbscan_algo = DBSCANAlgo(0.25, 50, 40)
fitting_algo = FittingAlgo(0.1, false)

# Square case
# What seems to be happening in the first frame is that
Expand All @@ -17,7 +18,7 @@ dbscan_algo = DBSCANAlgo(0.25, 50, 40)
# Decreasing the gradient threshold will increase the shock detects at the edges of the obstacle, making it harder to seperate
# both moving shocks and a line will still be fitted, despite the algorithm having detected more "tail" of the shock

detection = detect(flow_data, point_detect_algo, dbscan_algo)
detection = detect(flow_data, point_detect_algo, dbscan_algo, fitting_algo)

#plot_shock_fits_over_time(flow_data, detection, false)
#create_heatmap_evo_with_shock(flow_data, detection, :density_field, true, true)
create_heatmap_evo_with_shock(flow_data, detection, :density_field, true, true)
2 changes: 1 addition & 1 deletion src/ShockwaveDetection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export GradientShockDetectionAlgo
export ImageProcessingShockDetectionAlgo
export DBSCANAlgo
export cluster_shock_points
export fit_shock_clusters_over_time, calculate_normal_vector
export FittingAlgo, fit_shock_clusters_over_time, calculate_normal_vector
export create_wave_animation, create_wave_animation_with_shock, create_heatmap_evo, create_heatmap_evo_with_shock, plot_shock_fits_over_time
export ShockDetectionResult2D, detect
export ShockDetectionResult1D, detect
Expand Down
45 changes: 36 additions & 9 deletions src/fitting.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,31 @@
using LsqFit
using LinearAlgebra

"""
struct FittingAlgo{T}
A structure that represents the fitting algorithm configuration.
# Fields
- `angle_tolerance::T`: The tolerance value for checking the closeness of angles.
- `use_initial_guess::Bool`: A flag indicating whether to use an initial guess for the fitting algorithm.
# Constructors
- `FittingAlgo(angle_tolerance::T = convert(T, 0.1), use_initial_guess::Bool = false)`:
Creates a new `FittingAlgo` instance with the specified parameters. The default value for `angle_tolerance` is 0.1, and the default value for `use_initial_guess` is `false`.
# Example
```julia
algo = FittingAlgo(0.05, true)
"""
struct FittingAlgo{T}
angle_tolerance::T
use_initial_guess::Bool
function FittingAlgo(angle_tolerance::T = convert(T, 0.1), use_initial_guess::Bool = false) where T
new{T}(angle_tolerance, use_initial_guess)
end
end

struct Fitting{T, U}
model::Function
parameters::Array{T}
Expand Down Expand Up @@ -48,7 +73,7 @@ function vline_model(xy, p)
return x .- c
end

function fit_shock_cluster(cluster)
function fit_shock_cluster(cluster, FittingAlgo)

# Helper function to convert a cluster to a matrix of data points in a form that LsqFit can use
function cluster_to_data_points(shock_cluster)
Expand All @@ -64,8 +89,11 @@ function fit_shock_cluster(cluster)

xy = cluster_to_data_points(cluster)
models = [vline_model, hline_model, line_model, circle_model, parabola_model] # Use only these three firsts
#TODO: make user choose between interpolation guess or random guess
p0s = [rand(1), rand(1), rand(2), rand(3), rand(3)] # Initial parameters for each model
if FittingAlgo.use_initial_guess
p0s = [0, 0, 0, 0, 0] # TODO: call interpolation
else
p0s = [rand(1), rand(1), rand(2), rand(3), rand(3)] # Initial parameters for each model
end

best_fit = nothing
least_error = Inf
Expand Down Expand Up @@ -93,8 +121,7 @@ function fit_shock_cluster(cluster)
range = (minimum(angles), maximum(angles))

# Define a tolerance for angle closeness
#TODO: Make this a parameter? Or fine tune this?
tolerance = 0.1 # Adjust as needed for your application
tolerance = FittingAlgo.angle_tolerance

# Check if the angles span a full circle within the tolerance
if abs(range[1]- (-pi)) < tolerance && abs(range[2] - pi) < tolerance
Expand All @@ -110,11 +137,11 @@ function fit_shock_cluster(cluster)
return best_fit
end

function fit_shock_clusters(shock_clusters)
function fit_shock_clusters(shock_clusters, FittingAlgo)
shock_fits = []
if !isempty(shock_clusters)
@threads for shock_cluster in shock_clusters # Sequence of fit doesn't matter
best_fit = fit_shock_cluster(shock_cluster)
best_fit = fit_shock_cluster(shock_cluster, FittingAlgo)
push!(shock_fits, best_fit)
end
end
Expand All @@ -133,13 +160,13 @@ The fitted shock clusters are stored in the array `shock_fits_over_time`.
# Returns
- `shock_fits_over_time`: An array of fitted shock clusters over time.
"""
function fit_shock_clusters_over_time(shock_clusters_over_time)
function fit_shock_clusters_over_time(shock_clusters_over_time, FittingAlgo)
nsteps = length(shock_clusters_over_time)
shock_fits_over_time = Vector{Any}(undef, nsteps)
@threads for t in 1:nsteps
shock_fits = []
if !isempty(shock_clusters_over_time[t])
shock_fits = fit_shock_clusters(shock_clusters_over_time[t])
shock_fits = fit_shock_clusters(shock_clusters_over_time[t], FittingAlgo)
end
shock_fits_over_time[t] = shock_fits
end
Expand Down
7 changes: 4 additions & 3 deletions src/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ struct ShockDetectionResult1D <: AbstractShockDetectionResult
end

"""
detect(flow_data::FlowData, shock_point_algo::Abstract2DShockDetectionAlgo, cluster_algo::DBSCANAlgo)
detect(flow_data::FlowData, shock_point_algo::Abstract2DShockDetectionAlgo, cluster_algo::DBSCANAlgo, fitting_algo::FittingAlgo)
Detects shocks in 2D flow data, clusters the shockpoints and applies fitting to the cluster. It also shows the runtime and memory allocations required in each subprocess
# Arguments
- `flow_data::FlowData`: A `FlowData` object containing the 2D flow field data.
- `shock_point_algo::Abstract2DShockDetectionAlgo`: An algorithm for detecting shock points in 2D flow data.
- `cluster_algo::DBSCANAlgo`: A clustering algorithm (e.g., DBSCAN) to group detected shock points into clusters.
- `fitting_algo::FittingAlgo`: An algorithm for fitting shock clusters to create a smooth representation of the shock over time.
# Returns
- `ShockDetectionResult2D`: An object containing:
Expand All @@ -44,7 +45,7 @@ Detects shocks in 2D flow data, clusters the shockpoints and applies fitting to
# Description
This function detects shock points in 2D flow data using a specified shock detection algorithm. Detected shock points are clustered using the provided `DBSCANAlgo`, and then the clusters are fitted to create a smooth representation of the shock over time.
"""
function detect(flow_data::FlowData, shock_point_algo::Abstract2DShockDetectionAlgo, cluster_algo::DBSCANAlgo)
function detect(flow_data::FlowData, shock_point_algo::Abstract2DShockDetectionAlgo, cluster_algo::DBSCANAlgo, fitting_algo::FittingAlgo)
to = TimerOutput()

@timeit to "Detect Shock Points(2D)" begin
Expand All @@ -56,7 +57,7 @@ function detect(flow_data::FlowData, shock_point_algo::Abstract2DShockDetectionA
end

@timeit to "Fit Shock Clusters" begin
shock_fits_over_time = fit_shock_clusters_over_time(shock_clusters_over_time)
shock_fits_over_time = fit_shock_clusters_over_time(shock_clusters_over_time, fitting_algo)
end

show(to, sortby = :firstexec)
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ end
flow_data = FlowData(joinpath(DATA_DIR, "sod_shock_right_2d.tape"), false)
point_detect_algo = ImageProcessingShockDetectionAlgo(0.5, :prewitt)
dbscan_algo = DBSCANAlgo(0.25, 3, 10)
detection = detect(flow_data, point_detect_algo, dbscan_algo)
fitting_algo = FittingAlgo(0.1, false)
detection = detect(flow_data, point_detect_algo, dbscan_algo, fitting_algo)
plot_shock_fits_over_time(flow_data, detection, true)
end

Expand Down

0 comments on commit 97be19a

Please sign in to comment.