Skip to content

Commit

Permalink
refact: move selection_rate & mutation_rate out from GAParams t…
Browse files Browse the repository at this point in the history
…o operators (#478)

## Description

## Linked issues <!-- Please use "Resolves #<issue_no> syntax in case
this PR should be linked to an issue -->

Closes #477
Closes kkafar/master-monorepo#246

## Important implementation details <!-- if any, optional section -->
  • Loading branch information
kkafar authored Apr 29, 2024
1 parent 4a63ed7 commit 992feeb
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 151 deletions.
2 changes: 1 addition & 1 deletion coco/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ fn ecrs_ga_search(problem: &mut Problem, _max_budget: usize, _random_generator:
))
.set_selection_operator(selection::Tournament::new(0.2))
.set_crossover_operator(crossover::Uniform::new())
.set_mutation_operator(mutation::Reversing::new())
.set_mutation_operator(mutation::Reversing::new(0.05))
.set_replacement_operator(replacement::WeakParent::new())
.build();

Expand Down
12 changes: 5 additions & 7 deletions src/ga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ use self::{
};

pub struct GAParams {
pub selection_rate: f64,
pub mutation_rate: f64,
// pub selection_rate: f64,
// pub mutation_rate: f64,
pub population_size: usize,
pub generation_limit: usize,
pub max_duration: std::time::Duration,
Expand Down Expand Up @@ -322,11 +322,9 @@ where
self.metadata.crossover_dur = Some(self.timer.elapsed());

self.timer.start();
children.iter_mut().for_each(|child| {
self.config
.mutation_operator
.apply(&self.metadata, child, self.config.params.mutation_rate)
});
children
.iter_mut()
.for_each(|child| self.config.mutation_operator.apply(&self.metadata, child));
self.metadata.mutation_dur = Some(self.timer.elapsed());

if self.config.replacement_operator.requires_children_fitness() {
Expand Down
35 changes: 4 additions & 31 deletions src/ga/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ impl Error for ConfigError {}
// don't have to write it by hand...
#[derive(Debug, Clone)]
pub struct GAParamsOpt {
pub selection_rate: Option<f64>,
pub mutation_rate: Option<f64>,
pub population_size: Option<usize>,
pub generation_limit: Option<usize>,
pub max_duration: Option<std::time::Duration>,
Expand All @@ -75,8 +73,6 @@ impl GAParamsOpt {
/// Returns new instance of [GAParamsOpt] struct. All fields are `None` initially.
pub fn new() -> Self {
Self {
selection_rate: None,
mutation_rate: None,
population_size: None,
generation_limit: None,
max_duration: None,
Expand All @@ -85,8 +81,6 @@ impl GAParamsOpt {

/// Sets all `None` values to values form `other`
pub fn fill_from(&mut self, other: &GAParams) {
self.selection_rate.get_or_insert(other.selection_rate);
self.mutation_rate.get_or_insert(other.mutation_rate);
self.population_size.get_or_insert(other.population_size);
self.generation_limit.get_or_insert(other.generation_limit);
self.max_duration.get_or_insert(other.max_duration);
Expand All @@ -97,14 +91,6 @@ impl TryFrom<GAParamsOpt> for GAParams {
type Error = ConfigError;

fn try_from(params_opt: GAParamsOpt) -> Result<Self, Self::Error> {
let Some(selection_rate) = params_opt.selection_rate else {
return Err(ConfigError::MissingParam("Unspecified selection rate".to_owned()));
};

let Some(mutation_rate) = params_opt.mutation_rate else {
return Err(ConfigError::MissingParam("Unspecified mutation rate".to_owned()));
};

let Some(population_size) = params_opt.population_size else {
return Err(ConfigError::MissingParam(
"Unspecified population size".to_owned(),
Expand All @@ -122,8 +108,6 @@ impl TryFrom<GAParamsOpt> for GAParams {
};

Ok(GAParams {
selection_rate,
mutation_rate,
population_size,
generation_limit,
max_duration,
Expand Down Expand Up @@ -313,8 +297,8 @@ impl Builder {
/// some of the values.
pub(crate) trait DefaultParams {
const DEFAULT_PARAMS: GAParams = GAParams {
selection_rate: 1.0,
mutation_rate: 0.05,
// selection_rate: 1.0,
// mutation_rate: 0.05,
population_size: 100,
generation_limit: usize::MAX,
max_duration: std::time::Duration::MAX,
Expand All @@ -340,8 +324,8 @@ mod test {
#[test]
fn new_param_opt_is_empty() {
let params = GAParamsOpt::new();
assert!(params.selection_rate.is_none());
assert!(params.mutation_rate.is_none());
// assert!(params.selection_rate.is_none());
// assert!(params.mutation_rate.is_none());
assert!(params.population_size.is_none());
assert!(params.generation_limit.is_none());
assert!(params.max_duration.is_none());
Expand All @@ -350,21 +334,16 @@ mod test {
#[test]
fn param_opt_fills_correctly() {
let mut params_opt = GAParamsOpt::new();
params_opt.selection_rate = Some(0.5);
params_opt.generation_limit = Some(100);

let params = GAParams {
selection_rate: 1.0,
mutation_rate: 1.0,
population_size: 100,
generation_limit: 200,
max_duration: std::time::Duration::from_secs(1),
};

params_opt.fill_from(&params);

assert!(params_opt.selection_rate.is_some() && params_opt.selection_rate.unwrap() == 0.5);
assert!(params_opt.mutation_rate.is_some() && params_opt.mutation_rate.unwrap() == 1.0);
assert!(params_opt.population_size.is_some() && params_opt.population_size.unwrap() == 100);
assert!(params_opt.generation_limit.is_some() && params_opt.generation_limit.unwrap() == 100);
assert!(
Expand All @@ -377,12 +356,6 @@ mod test {
fn conversion_works_as_expected() {
let mut params_opt = GAParamsOpt::new();

params_opt.selection_rate = Some(1.0);
assert!(convert_gaparamsopt_to_ga_params(params_opt.clone()).is_err());

params_opt.mutation_rate = Some(0.0);
assert!(convert_gaparamsopt_to_ga_params(params_opt.clone()).is_err());

params_opt.population_size = Some(200);
assert!(convert_gaparamsopt_to_ga_params(params_opt.clone()).is_err());

Expand Down
26 changes: 3 additions & 23 deletions src/ga/builder/bitstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,28 +57,6 @@ impl<F: Fitness<BitStringIndividual>> BitStringBuilder<F> {
}
}

/// Sets selection rate
///
/// ## Arguments
///
/// * `selection_rate` - Selection rate; must be in [0, 1] interval
pub fn set_selection_rate(mut self, selection_rate: f64) -> Self {
assert!((0f64..=1f64).contains(&selection_rate));
self.config.params.selection_rate = Some(selection_rate);
self
}

/// Sets mutation rate
///
/// ## Arguments
///
/// * `mutation_rate` - Mutation rate; must be in [0, 1] interval
pub fn set_mutation_rate(mut self, mutation_rate: f64) -> Self {
assert!((0.0..=1.0).contains(&mutation_rate));
self.config.params.mutation_rate = Some(mutation_rate);
self
}

/// Sets max duration. If exceeded, the algorithm halts.
///
/// ## Arguments
Expand Down Expand Up @@ -167,7 +145,9 @@ impl<F: Fitness<BitStringIndividual>> BitStringBuilder<F> {
self.config
.crossover_operator
.get_or_insert_with(SinglePoint::new);
self.config.mutation_operator.get_or_insert_with(FlipBit::new);
self.config
.mutation_operator
.get_or_insert_with(|| FlipBit::new(0.05));
self.config
.selection_operator
.get_or_insert_with(|| Tournament::new(0.2));
Expand Down
30 changes: 0 additions & 30 deletions src/ga/builder/generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,36 +75,6 @@ where
}
}

/// Sets selection rate
///
/// ## Arguments
///
/// * `selection_rate` - Selection rate; must be in [0, 1] interval
///
/// ## Panics
///
/// If the `selection_rate` param has invalid value.
pub fn set_selection_rate(mut self, selection_rate: f64) -> Self {
assert!((0f64..=1f64).contains(&selection_rate));
self.config.params.selection_rate = Some(selection_rate);
self
}

/// Sets mutation rate
///
/// ## Arguments
///
/// * `mutation_rate` - Mutation rate; must be in [0, 1] interval
///
/// ## Panics
///
/// If the parameter has invalid value.
pub fn set_mutation_rate(mut self, mutation_rate: f64) -> Self {
assert!((0.0..=1.0).contains(&mutation_rate));
self.config.params.mutation_rate = Some(mutation_rate);
self
}

/// Sets max duration. If exceeded, the algorithm halts.
///
/// ## Arguments
Expand Down
26 changes: 3 additions & 23 deletions src/ga/builder/realvalued.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,6 @@ impl<F: Fitness<RealValueIndividual>> RealValuedBuilder<F> {
}
}

/// Sets selection rate
///
/// ## Arguments
///
/// * `selection_rate` - Selection rate; must be in [0, 1] interval
pub fn set_selection_rate(mut self, selection_rate: f64) -> Self {
debug_assert!((0f64..=1f64).contains(&selection_rate));
self.config.params.selection_rate = Some(selection_rate);
self
}

/// Sets mutation rate
///
/// ## Arguments
///
/// * `mutation_rate` - Mutation rate; must be in [0, 1] interval
pub fn set_mutation_rate(mut self, mutation_rate: f64) -> Self {
assert!((0.0..=1.0).contains(&mutation_rate));
self.config.params.mutation_rate = Some(mutation_rate);
self
}

/// Sets max duration. If exceeded, the algorithm halts.
///
/// ## Arguments
Expand Down Expand Up @@ -166,7 +144,9 @@ impl<F: Fitness<RealValueIndividual>> RealValuedBuilder<F> {
self.config
.crossover_operator
.get_or_insert_with(SinglePoint::new);
self.config.mutation_operator.get_or_insert_with(Interchange::new);
self.config
.mutation_operator
.get_or_insert_with(|| Interchange::new(0.05));
self.config
.selection_operator
.get_or_insert_with(|| Tournament::new(0.2));
Expand Down
2 changes: 1 addition & 1 deletion src/ga/operators/mutation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ pub trait MutationOperator<IndividualT: IndividualTrait> {
///
/// * `individual` - mutable reference to to-be-mutated individual
/// * `mutation_rate` - probability of gene mutation
fn apply(&mut self, metadata: &GAMetadata, individual: &mut IndividualT, mutation_rate: f64);
fn apply(&mut self, metadata: &GAMetadata, individual: &mut IndividualT);
}
Loading

0 comments on commit 992feeb

Please sign in to comment.