Skip to content

Commit

Permalink
Merge pull request #3146 from matsim-org/feature/subpopulationsInAnne…
Browse files Browse the repository at this point in the history
…aler

Feature/subpopulations in annealer
  • Loading branch information
jfbischoff authored Mar 18, 2024
2 parents 73d5d76 + 4c6d7be commit fdac3c1
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class ReplanningAnnealer implements IterationStartsListener, StartupListe
private final ReplanningAnnealerConfigGroup saConfig;
private final int innovationStop;
private final String sep;
private final EnumMap<AnnealParameterOption, Double> currentValues;
private final EnumMap<AnnealParameterOption,Map<String, Double>> currentValuesPerSubpopulation;
private int currentIter;
private List<String> header;
@Inject
Expand All @@ -67,7 +67,7 @@ public class ReplanningAnnealer implements IterationStartsListener, StartupListe
public ReplanningAnnealer(Config config) {
this.config = config;
this.saConfig = ConfigUtils.addOrGetModule(config, ReplanningAnnealerConfigGroup.class);
this.currentValues = new EnumMap<>(AnnealParameterOption.class);
this.currentValuesPerSubpopulation = new EnumMap<>(AnnealParameterOption.class);
this.innovationStop = getInnovationStop(config);
this.sep = config.global().getDefaultDelimiter();
}
Expand All @@ -83,17 +83,19 @@ private static boolean isInnovationStrategy(String strategyName) {
@Override
public void notifyStartup(StartupEvent event) {
header = new ArrayList<>();
for (AnnealingVariable av : this.saConfig.getAnnealingVariables().values()) {
for (AnnealingVariable av : this.saConfig.getAllAnnealingVariables()) {
if (!av.getAnnealType().equals(AnnealOption.disabled)) {
// check and fix initial value if needed
checkAndFixStartValue(av, event);

this.currentValues.put(av.getAnnealParameter(), av.getStartValue());
header.add(av.getAnnealParameter().name());
var mapPerSubpopulation = this.currentValuesPerSubpopulation.computeIfAbsent(av.getAnnealParameter(),a-> new HashMap<>());
mapPerSubpopulation.put(av.getSubpopulation(),av.getStartValue());
String subpopulationString = av.getSubpopulation()!=null? "_"+av.getSubpopulation() :"";
header.add(av.getAnnealParameter().name()+subpopulationString);
if (av.getAnnealParameter().equals(AnnealParameterOption.globalInnovationRate)) {
header.addAll(this.config.replanning().getStrategySettings().stream()
.filter(s -> Objects.equals(av.getDefaultSubpopulation(), s.getSubpopulation()))
.map(ReplanningConfigGroup.StrategySettings::getStrategyName)
.filter(s -> Objects.equals(av.getSubpopulation(), s.getSubpopulation()))
.map(strategySettings -> strategySettings.getStrategyName()+subpopulationString)
.collect(Collectors.toList()));
}
} else { // if disabled, better remove it
Expand All @@ -114,34 +116,35 @@ public void notifyStartup(StartupEvent event) {
public void notifyIterationStarts(IterationStartsEvent event) {
this.currentIter = event.getIteration() - this.config.controller().getFirstIteration();
Map<String, String> annealStats = new HashMap<>();
for (AnnealingVariable av : this.saConfig.getAnnealingVariables().values()) {
List<AnnealingVariable> allVariables = this.saConfig.getAllAnnealingVariables();
for (AnnealingVariable av : allVariables) {
if (this.currentIter > 0) {
switch (av.getAnnealType()) {
case geometric:
this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
v * av.getShapeFactor());
break;
case exponential:
int halfLifeIter = av.getHalfLife() <= 1.0 ?
(int) (av.getHalfLife() * this.innovationStop) : (int) av.getHalfLife();
this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
av.getStartValue() / Math.exp((double) this.currentIter / halfLifeIter));
break;
case msa:
this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
av.getStartValue() / Math.pow(this.currentIter, av.getShapeFactor()));
break;
case sigmoid:
halfLifeIter = av.getHalfLife() <= 1.0 ?
(int) (av.getHalfLife() * this.innovationStop) : (int) av.getHalfLife();
this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
av.getEndValue() + (av.getStartValue() - av.getEndValue()) /
(1 + Math.exp(av.getShapeFactor() * (this.currentIter - halfLifeIter))));
break;
case linear:
double slope = (av.getStartValue() - av.getEndValue())
/ (this.config.controller().getFirstIteration() - this.innovationStop);
this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
this.currentIter * slope + av.getStartValue());
break;
case disabled:
Expand All @@ -150,14 +153,16 @@ public void notifyIterationStarts(IterationStartsEvent event) {
throw new IllegalArgumentException();
}

log.info("Annealling will be performed on parameter " + av.getAnnealParameter() +
". Value: " + this.currentValues.get(av.getAnnealParameter()));
log.info("Annealling will be performed on parameter " + av.getAnnealParameter() +". Subpopulation: "+av.getSubpopulation()+
". Value: " +this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).get(av.getSubpopulation()));

this.currentValues.compute(av.getAnnealParameter(), (k, v) ->
this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).compute(av.getSubpopulation(), (k, v) ->
Math.max(v, av.getEndValue()));
}
double annealValue = this.currentValues.get(av.getAnnealParameter());
annealStats.put(av.getAnnealParameter().name(), String.format(Locale.US, "%.4f", annealValue));
double annealValue = this.currentValuesPerSubpopulation.get(av.getAnnealParameter()).get(av.getSubpopulation());
String subpopulationString = av.getSubpopulation()!=null? "_"+av.getSubpopulation() :"";

annealStats.put(av.getAnnealParameter().name()+subpopulationString, String.format(Locale.US, "%.4f", annealValue));
anneal(event, av, annealValue, annealStats);
}

Expand All @@ -178,6 +183,8 @@ private void writeIterationstats(int currentIter, Map<String, String> annealStat
}

private void anneal(IterationStartsEvent event, AnnealingVariable av, double annealValue, Map<String, String> annealStats) {
String subpopulationString = av.getSubpopulation()!=null? "_"+av.getSubpopulation() :"";

switch (av.getAnnealParameter()) {
case BrainExpBeta:
this.config.scoring().setBrainExpBeta(annealValue);
Expand All @@ -193,16 +200,17 @@ private void anneal(IterationStartsEvent event, AnnealingVariable av, double ann
annealValue = 0.0;
}
List<Double> annealValues = annealReplanning(annealValue,
event.getServices().getStrategyManager(), av.getDefaultSubpopulation());
event.getServices().getStrategyManager(), av.getSubpopulation());
int i = 0;
for (ReplanningConfigGroup.StrategySettings ss : this.config.replanning().getStrategySettings()) {
if (Objects.equals(ss.getSubpopulation(), av.getDefaultSubpopulation())) {
annealStats.put(ss.getStrategyName(), String.format(Locale.US, "%.4f", annealValues.get(i)));
if (Objects.equals(ss.getSubpopulation(), av.getSubpopulation())) {
annealStats.put(ss.getStrategyName()+subpopulationString, String.format(Locale.US, "%.4f", annealValues.get(i)));
i++;
}
}
annealStats.put(av.getAnnealParameter().name(), String.format(Locale.US, "%.4f", // update value in case of switchoff
getStrategyWeights(event.getServices().getStrategyManager(), av.getDefaultSubpopulation(), StratType.allInnovation)));

annealStats.put(av.getAnnealParameter().name()+subpopulationString, String.format(Locale.US, "%.4f", // update value in case of switchoff
getStrategyWeights(event.getServices().getStrategyManager(), av.getSubpopulation(), StratType.allInnovation)));
break;
default:
throw new IllegalArgumentException();
Expand Down Expand Up @@ -328,14 +336,14 @@ private void checkAndFixStartValue(ReplanningAnnealerConfigGroup.AnnealingVariab
configValue = this.config.scoring().getLearningRate();
break;
case globalInnovationRate:
double innovationWeights = getStrategyWeights(this.config, av.getDefaultSubpopulation(), StratType.allInnovation);
double selectorWeights = getStrategyWeights(this.config, av.getDefaultSubpopulation(), StratType.allSelectors);
double innovationWeights = getStrategyWeights(this.config, av.getSubpopulation(), StratType.allInnovation);
double selectorWeights = getStrategyWeights(this.config, av.getSubpopulation(), StratType.allSelectors);
if (innovationWeights + selectorWeights != 1.0) {
log.warn("Initial sum of strategy weights different from 1.0. Rescaling.");
double innovationStartValue = av.getStartValue() == null ? innovationWeights : av.getStartValue();
rescaleStartupWeights(innovationStartValue, this.config, event.getServices().getStrategyManager(), av.getDefaultSubpopulation());
rescaleStartupWeights(innovationStartValue, this.config, event.getServices().getStrategyManager(), av.getSubpopulation());
}
configValue = getStrategyWeights(this.config, av.getDefaultSubpopulation(), StratType.allInnovation);
configValue = getStrategyWeights(this.config, av.getSubpopulation(), StratType.allInnovation);
break;
default:
throw new IllegalArgumentException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
package org.matsim.core.replanning.annealing;

import java.util.EnumMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.matsim.core.config.ConfigGroup;
import org.matsim.core.config.ReflectiveConfigGroup;

Expand Down Expand Up @@ -81,27 +85,35 @@ public void addParameterSet(final ConfigGroup set) {
addAnnealingVariable((AnnealingVariable) set);
}

public Map<AnnealParameterOption, AnnealingVariable> getAnnealingVariables() {
final EnumMap<AnnealParameterOption, AnnealingVariable> map =
new EnumMap<>(AnnealParameterOption.class);
for (ConfigGroup pars : getParameterSets(AnnealingVariable.GROUP_NAME)) {
final AnnealParameterOption name = ((AnnealingVariable) pars).getAnnealParameter();
final AnnealingVariable old = map.put(name, (AnnealingVariable) pars);
if (old != null) {
throw new IllegalStateException("several parameter sets for variable " + name);
}
}
return map;
}
public List<AnnealingVariable> getAllAnnealingVariables(){
return getAnnealingVariablesPerSubpopulation().values().stream().flatMap(a->a.values().stream()).collect(Collectors.toList());
}
public Map<AnnealParameterOption, Map<String,AnnealingVariable>> getAnnealingVariablesPerSubpopulation() {
final EnumMap<AnnealParameterOption, Map<String,AnnealingVariable>> map =
new EnumMap<>(AnnealParameterOption.class);
for (ConfigGroup pars : getParameterSets(AnnealingVariable.GROUP_NAME)) {
AnnealParameterOption name = ((AnnealingVariable) pars).getAnnealParameter();
String subpopulation = ((AnnealingVariable) pars).getSubpopulation();
var paramsPerSubpopulation = map.computeIfAbsent(name,a->new HashMap<>());
final AnnealingVariable old = paramsPerSubpopulation.put(subpopulation, (AnnealingVariable) pars);
if (old != null) {
throw new IllegalStateException("several parameter sets for variable " + name + " and subpopulation "+subpopulation);
}
}
return map;
}

public void addAnnealingVariable(final AnnealingVariable params) {
final AnnealingVariable previous = this.getAnnealingVariables().get(params.getAnnealParameter());
var previousMap = this.getAnnealingVariablesPerSubpopulation().get(params.getAnnealParameter());
if (previousMap!=null){
AnnealingVariable previous = previousMap.get(params.getSubpopulation());
if (previous != null) {
final boolean removed = removeParameterSet(previous);
if (!removed) {
throw new RuntimeException("problem replacing annealing variable");
}
}
}
super.addParameterSet(params);
}

Expand All @@ -117,11 +129,11 @@ public static class AnnealingVariable extends ReflectiveConfigGroup {
private static final String START_VALUE = "startValue";
private static final String END_VALUE = "endValue";
private static final String ANNEAL_TYPE = "annealType";
private static final String DEFAULT_SUBPOP = "defaultSubpopulation";
private static final String SUBPOPULATION = "subpopulation";
private static final String ANNEAL_PARAM = "annealParameter";
private static final String HALFLIFE = "halfLife";
private static final String SHAPE_FACTOR = "shapeFactor";
private String defaultSubpop = null;
private String subpopulation = null;
private Double startValue = null;
private double endValue = 0.0001;
private double shapeFactor = 0.9;
Expand Down Expand Up @@ -167,14 +179,14 @@ public void setAnnealType(AnnealOption annealType) {
this.annealType = annealType;
}

@StringGetter(DEFAULT_SUBPOP)
public String getDefaultSubpopulation() {
return this.defaultSubpop;
@StringGetter(SUBPOPULATION)
public String getSubpopulation() {
return this.subpopulation;
}

@StringSetter(DEFAULT_SUBPOP)
@StringSetter(SUBPOPULATION)
public void setDefaultSubpopulation(String defaultSubpop) {
this.defaultSubpop = defaultSubpop;
this.subpopulation = defaultSubpop;
}

@StringGetter(ANNEAL_PARAM)
Expand Down Expand Up @@ -220,7 +232,7 @@ public Map<String, String> getComments() {
map.put(ANNEAL_TYPE, "options: linear, exponential, geometric, msa, sigmoid and disabled (no annealing).");
map.put(ANNEAL_PARAM,
"list of config parameters that shall be annealed. Currently supported: globalInnovationRate, BrainExpBeta, PathSizeLogitBeta, learningRate. Default is globalInnovationRate");
map.put(DEFAULT_SUBPOP, "subpopulation to have the global innovation rate adjusted. Not applicable when annealing with other parameters.");
map.put(SUBPOPULATION, "subpopulation to have the global innovation rate adjusted. Not applicable when annealing with other parameters.");
map.put(START_VALUE, "start value for annealing.");
map.put(END_VALUE, "final annealing value. When the annealing function reaches this value, further results remain constant.");
return map;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ public class ReplanningAnnealerTest {
"8;0.1000;0.0500;0.0500;0.9000\n" +
"9;0.0500;0.0250;0.0250;0.9500\n" +
"10;0.0000;0.0000;0.0000;1.0000\n";

private String expectedLinearAnnealMultipleSubpopulations = "it;globalInnovationRate_otherAnnealer;ChangeExpBeta_otherAnnealer;TimeAllocationMutator_otherAnnealer;globalInnovationRate_subpop;ReRoute_subpop;SubtourModeChoice_subpop;ChangeExpBeta_subpop\n" +
"0;0.8000;0.2000;0.8000;0.5000;0.2500;0.2500;0.5000\n" +
"1;0.7200;0.2800;0.7200;0.4500;0.2250;0.2250;0.5500\n" +
"2;0.6400;0.3600;0.6400;0.4000;0.2000;0.2000;0.6000\n" +
"3;0.5600;0.4400;0.5600;0.3500;0.1750;0.1750;0.6500\n" +
"4;0.4800;0.5200;0.4800;0.3000;0.1500;0.1500;0.7000\n" +
"5;0.4000;0.6000;0.4000;0.2500;0.1250;0.1250;0.7500\n" +
"6;0.3200;0.6800;0.3200;0.2000;0.1000;0.1000;0.8000\n" +
"7;0.2400;0.7600;0.2400;0.1500;0.0750;0.0750;0.8500\n" +
"8;0.1600;0.8400;0.1600;0.1000;0.0500;0.0500;0.9000\n" +
"9;0.0800;0.9200;0.0800;0.0500;0.0250;0.0250;0.9500\n" +
"10;0.0000;1.0000;0.0000;0.0000;0.0000;0.0000;1.0000\n";
private String expectedMsaAnneal =
"it;globalInnovationRate;ReRoute;SubtourModeChoice;ChangeExpBeta\n" +
"0;0.5000;0.2500;0.2500;0.5000\n" +
Expand Down Expand Up @@ -360,21 +373,65 @@ void testSubpopulationAnneal() throws IOException {
this.saConfigVar.setStartValue(0.5);
this.saConfigVar.setDefaultSubpopulation(targetSubpop);
this.config.replanning().getStrategySettings().forEach(s -> s.setSubpopulation(targetSubpop));
ReplanningConfigGroup.StrategySettings s = new ReplanningConfigGroup.StrategySettings();
s.setStrategyName("TimeAllocationMutator");
s.setWeight(0.25);
s.setSubpopulation("noAnneal");
this.config.replanning().addStrategySettings(s);

String otherAnnealerSubpopulation = "otherAnnealer";
String othertargetSubpop = otherAnnealerSubpopulation;
ReplanningAnnealerConfigGroup.AnnealingVariable saConfigVar2 = new ReplanningAnnealerConfigGroup.AnnealingVariable();
saConfigVar2.setAnnealType("linear");
saConfigVar2.setEndValue(0.0);
saConfigVar2.setStartValue(0.8);
saConfigVar2.setDefaultSubpopulation(othertargetSubpop);
this.config.replanningAnnealer().addParameterSet(saConfigVar2);

ReplanningConfigGroup.StrategySettings s = new ReplanningConfigGroup.StrategySettings();
s.setStrategyName("TimeAllocationMutator");
s.setWeight(0.25);
s.setSubpopulation(otherAnnealerSubpopulation);
ReplanningConfigGroup.StrategySettings s2 = new ReplanningConfigGroup.StrategySettings();
s2.setStrategyName("ChangeExpBeta"); // shouldn't be affected
s2.setWeight(0.5);
s2.setSubpopulation(otherAnnealerSubpopulation);
this.config.replanning().addStrategySettings(s2);
this.config.replanning().addStrategySettings(s);


ReplanningConfigGroup.StrategySettings noAnnealSettings = new ReplanningConfigGroup.StrategySettings();
noAnnealSettings.setStrategyName("TimeAllocationMutator");
noAnnealSettings.setWeight(0.25);
noAnnealSettings.setSubpopulation("noAnneal");
this.config.replanning().addStrategySettings(noAnnealSettings);

Controler controler = new Controler(this.scenario);
controler.run();

Assertions.assertEquals(expectedLinearAnneal, readResult(controler.getControlerIO().getOutputFilename(FILENAME_ANNEAL)));
Assertions.assertEquals(expectedLinearAnnealMultipleSubpopulations, readResult(controler.getControlerIO().getOutputFilename(FILENAME_ANNEAL)));

StrategyManager sm = controler.getInjector().getInstance(StrategyManager.class);
List<Double> weights = sm.getWeights(targetSubpop);
List<Double> weights2 = sm.getWeights(otherAnnealerSubpopulation);

Assertions.assertEquals(1.0, weights.stream().mapToDouble(Double::doubleValue).sum(), 1e-4);
Assertions.assertEquals(1.0, weights2.stream().mapToDouble(Double::doubleValue).sum(), 1e-4);
}

@Test
void testNullSubpopulationAnneal() throws IOException {
String targetSubpop = null;
this.saConfigVar.setAnnealType("linear");
this.saConfigVar.setEndValue(0.0);
this.saConfigVar.setStartValue(0.5);
this.saConfigVar.setDefaultSubpopulation(targetSubpop);
this.config.replanning().getStrategySettings().forEach(s -> s.setSubpopulation(targetSubpop));

Controler controler = new Controler(this.scenario);
controler.run();

Assertions.assertEquals(expectedLinearAnneal, readResult(controler.getControlerIO().getOutputFilename(FILENAME_ANNEAL)));

StrategyManager sm = controler.getInjector().getInstance(StrategyManager.class);
List<Double> weights = sm.getWeights(targetSubpop);

Assertions.assertEquals(1.0, weights.stream().mapToDouble(Double::doubleValue).sum(), 1e-4);
}

}

0 comments on commit fdac3c1

Please sign in to comment.