Multi-task Learning on Shifu
Production environments often contain many models. Multi-task Learning (MTL) trains several related tasks in one model, which reduces the effort of building a separate model for each task. Through implicit data augmentation and representation bias, MTL may also improve model performance.
Running MTL in Shifu is straightforward. Compared with Shifu NN, the Init, Stats, Norm and VarSel steps change in three ModelConfig.json fields: "dataSet"->"targetColumnName", "posTags" and "negTags". Separate the values for each task with "|||", as in the example below:
"dataSet" : {
  "source" : "HDFS",
  "dataPath" : "/user/website/cancer-judgement/DataSet1",
  "dataDelimiter" : "|",
  "headerPath" : "/user/website/cancer-judgement/DataSet1/.pig_header",
  "headerDelimiter" : "|",
  "filterExpressions" : "",
  "weightColumnName" : "",
  "targetColumnName" : "diagnosis|||diagnosis|||diagnosis",
  "posTags" : [ "M|||M|||M" ],
  "negTags" : [ "B|||B|||B" ],
  "metaColumnNameFile" : "columns/meta.column.names",
  "categoricalColumnNameFile" : "columns/categorical.column.names",
  "validationDataPath" : null,
  "validationFilterExpressions" : "",
  "missingOrInvalidValues" : [ "", "*", "#", "?", "null", "~" ]
}
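To make the "|||" convention concrete, here is a minimal sketch of how such a field could be split into one value per task. The class `MtlTaskFields` and its methods are hypothetical names used only for illustration and are not part of Shifu. The sketch also assumes that each of the three fields carries exactly one value per task.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

/** Illustrative helper (hypothetical, not part of Shifu): splits "|||"-delimited MTL config values. */
final class MtlTaskFields {
    // "|" is a regex metacharacter, so the delimiter is quoted before it is compiled.
    private static final Pattern TASK_DELIMITER = Pattern.compile(Pattern.quote("|||"));

    private MtlTaskFields() {
    }

    /** Returns one trimmed entry per task, e.g. "a|||b" becomes [a, b]. */
    static List<String> split(String value) {
        List<String> parts = new ArrayList<>();
        // The -1 limit keeps trailing empty entries so a missing value is not silently dropped.
        for (String part : TASK_DELIMITER.split(value, -1)) {
            parts.add(part.trim());
        }
        return parts;
    }

    /** Assumption: target, positive-tag and negative-tag fields must describe the same number of tasks. */
    static int taskCount(String targets, String posTags, String negTags) {
        int n = split(targets).size();
        if (split(posTags).size() != n || split(negTags).size() != n) {
            throw new IllegalArgumentException(
                    "targetColumnName, posTags and negTags must list the same number of tasks");
        }
        return n;
    }

    public static void main(String[] args) {
        // Prints [diagnosis, diagnosis, diagnosis]
        System.out.println(split("diagnosis|||diagnosis|||diagnosis"));
        // Prints 3
        System.out.println(taskCount("diagnosis|||diagnosis|||diagnosis", "M|||M|||M", "B|||B|||B"));
    }
}
```

A check like `taskCount` can catch a config where one field lists fewer tasks than the others before any job is started.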
For training, the change from Shifu NN is to set "train"->"algorithm" to "MTL" in ModelConfig.json, as shown below:
"train" : {
  "baggingNum" : 1,
  "baggingWithReplacement" : false,
  "baggingSampleRate" : 1.0,
  "validSetRate" : 0.2,
  "numTrainEpochs" : 200,
  "isContinuous" : false,
  "workerThreadCount" : 4,
  "algorithm" : "MTL",
  "params" : {
    "Propagation" : "R",
    "LearningRate" : 0.1,
    "NumHiddenNodes" : [ 50 ],
    "NumHiddenLayers" : 1,
    "RegularizedConstant" : 0.0,
    "ActivationFunc" : [ "tanh" ]
  },
  "customPaths" : { }
}
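As a rough picture of what these parameters describe, the sketch below runs a forward pass through one shared tanh hidden layer with 50 nodes and one output per task. This is a common multi-task layout known as hard parameter sharing. It is an assumption made for illustration, not a description of Shifu's MTL implementation. The class name, the sigmoid outputs, the weight initialization and the input width of 30 are all hypothetical.

```java
import java.util.Random;

/** Illustrative sketch of hard parameter sharing; NOT Shifu's MTL implementation. */
final class SharedHiddenMtlSketch {
    private final double[][] hiddenWeights; // [hiddenNodes][inputs + 1], last column is the bias
    private final double[][] taskWeights;   // [tasks][hiddenNodes + 1], last column is the bias

    SharedHiddenMtlSketch(int inputs, int hiddenNodes, int tasks, long seed) {
        Random random = new Random(seed);
        hiddenWeights = randomMatrix(hiddenNodes, inputs + 1, random);
        taskWeights = randomMatrix(tasks, hiddenNodes + 1, random);
    }

    private static double[][] randomMatrix(int rows, int cols, Random random) {
        double[][] m = new double[rows][cols];
        for (double[] row : m) {
            for (int j = 0; j < cols; j++) {
                row[j] = (random.nextDouble() - 0.5) * 0.2; // small initial weights in [-0.1, 0.1)
            }
        }
        return m;
    }

    /** Weighted sum of the inputs plus the bias stored in the last column. */
    private static double affine(double[] weights, double[] inputs) {
        double sum = weights[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            sum += weights[i] * inputs[i];
        }
        return sum;
    }

    /** Returns one score in (0, 1) per task for a single normalized input row. */
    double[] predict(double[] features) {
        double[] hidden = new double[hiddenWeights.length];
        for (int h = 0; h < hidden.length; h++) {
            hidden[h] = Math.tanh(affine(hiddenWeights[h], features)); // representation shared by all tasks
        }
        double[] scores = new double[taskWeights.length];
        for (int t = 0; t < scores.length; t++) {
            scores[t] = 1.0 / (1.0 + Math.exp(-affine(taskWeights[t], hidden))); // task-specific sigmoid output
        }
        return scores;
    }

    public static void main(String[] args) {
        // 30 inputs is an arbitrary, hypothetical feature count; 50 hidden nodes and 3 tasks mirror the configs above.
        SharedHiddenMtlSketch net = new SharedHiddenMtlSketch(30, 50, 3, 42L);
        System.out.println(net.predict(new double[30]).length); // prints 3
    }
}
```

In a layout like this every task updates the shared hidden weights during training, which is one way the representation bias mentioned earlier can arise.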
Train on normalized data with the matching ModelConfig.json. Prepare a separate ColumnConfig for each task. All ColumnConfigs live in a directory named "mtl" at the same level as "ModelConfig.json". The data and configs used by the MTL unit tests show this structure. Remember to modify the "loadConfigs" method in MTLMaster and MTLWorker.
@Test
public void testMultiTaskNN() {
    Properties props = new Properties();
    // Master and worker implementations, plus the result type they exchange.
    props.setProperty(GuaguaConstants.MASTER_COMPUTABLE_CLASS, MTLMaster.class.getName());
    props.setProperty(GuaguaConstants.WORKER_COMPUTABLE_CLASS, MTLWorker.class.getName());
    props.setProperty(GuaguaConstants.GUAGUA_ITERATION_COUNT, "10");
    props.setProperty(GuaguaConstants.GUAGUA_MASTER_RESULT_CLASS, MTLParams.class.getName());
    props.setProperty(GuaguaConstants.GUAGUA_WORKER_RESULT_CLASS, MTLParams.class.getName());
    // Configs and data come from the test resources on the classpath.
    props.setProperty(CommonConstants.MODELSET_SOURCE_TYPE, "LOCAL");
    String modelConfigJson = getClass().getResource("/model/MultiTaskNN/ModelConfig.json").toString();
    props.setProperty(CommonConstants.SHIFU_MODEL_CONFIG, modelConfigJson);
    props.setProperty(GuaguaConstants.GUAGUA_INPUT_DIR,
            getClass().getResource("/data/part-m-00000-mtl-afterNormalized").toString());
    GuaguaUnitDriver<MTLParams, MTLParams> driver = new GuaguaMRUnitDriver<MTLParams, MTLParams>(props);
    driver.run();
}
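When adapting "loadConfigs", the per-task ColumnConfig files have to be found in the "mtl" directory next to ModelConfig.json. The following is a minimal sketch of that lookup using only the standard library. `MtlConfigLocator` is a hypothetical helper, not Shifu code. Sorting by file name to fix the task order is an assumption, and the actual ColumnConfig file names are not specified here.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

/** Illustrative helper (hypothetical, not Shifu's loadConfigs): finds per-task config files. */
final class MtlConfigLocator {
    private MtlConfigLocator() {
    }

    /** Lists the regular files in the "mtl" directory next to ModelConfig.json, sorted by path. */
    static List<Path> listTaskConfigs(Path modelConfigJson) throws IOException {
        Path mtlDir = modelConfigJson.toAbsolutePath().getParent().resolve("mtl");
        if (!Files.isDirectory(mtlDir)) {
            throw new IOException("Expected an mtl directory at " + mtlDir);
        }
        List<Path> configs = new ArrayList<>();
        // Files.list returns a stream that holds a directory handle, so it is closed explicitly.
        try (Stream<Path> entries = Files.list(mtlDir)) {
            entries.filter(Files::isRegularFile).forEach(configs::add);
        }
        // Assumption: a deterministic order is wanted so that file i maps to task i.
        Collections.sort(configs);
        return configs;
    }
}
```

The number of files returned can then be compared with the number of "|||"-separated targets in ModelConfig.json to detect a missing ColumnConfig early.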