
Multi-task Learning on Shifu


Why Multi-task Learning?

Production environments often have to build and maintain many separate models. Multi-task Learning (MTL) lets a single model cover several related tasks, saving the effort of building each model individually. Through implicit data augmentation and the bias introduced by a shared representation, MTL may also improve model performance.

How to use the Multi-task Learning model in Shifu?

Running MTL in Shifu is straightforward. Compared with a standard Shifu NN model, the only change for the Init, Stats, Norm and VarSel steps is in ModelConfig.json, under "dataSet": the "targetColumnName", "posTags" and "negTags" fields carry one value per task, with tasks separated by "|||", such as below:

  "dataSet" : {
    "source" : "HDFS",
    "dataPath" : "/user/website/cancer-judgement/DataSet1",
    "dataDelimiter" : "|",
    "headerPath" : "/user/website/cancer-judgement/DataSet1/.pig_header",
    "headerDelimiter" : "|",
    "filterExpressions" : "",
    "weightColumnName" : "",
    "targetColumnName" : "diagnosis|||diagnosis|||diagnosis",
    "posTags" : [ "M|||M|||M" ],
    "negTags" : [ "B|||B|||B" ],
    "metaColumnNameFile" : "columns/meta.column.names",
    "categoricalColumnNameFile" : "columns/categorical.column.names",
    "validationDataPath" : null,
    "validationFilterExpressions" : "",
    "missingOrInvalidValues" : [ "", "*", "#", "?", "null", "~" ]
  }
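Each "|||"-separated entry corresponds to one task, so "targetColumnName", "posTags" and "negTags" should all list the same number of entries. The example above reuses the same "diagnosis" column for all three tasks; in a realistic multi-task setup each task would normally point at its own label column, for instance (the column names and tags below are hypothetical):

    "targetColumnName" : "is_fraud|||is_chargeback|||is_late_payment",
    "posTags" : [ "1|||1|||1" ],
    "negTags" : [ "0|||0|||0" ],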

For the training step, the only change from a Shifu NN model is in ModelConfig.json: set "train"->"algorithm" to "MTL", such as below:

  "train" : {
    "baggingNum" : 1,
    "baggingWithReplacement" : false,
    "baggingSampleRate" : 1.0,
    "validSetRate" : 0.2,
    "numTrainEpochs" : 200,
    "isContinuous" : false,
    "workerThreadCount" : 4,
    "algorithm" : "MTL",
    "params" : {
      "Propagation" : "R",
      "LearningRate" : 0.1,
      "NumHiddenNodes" : [ 50 ],
      "NumHiddenLayers" : 1,
      "RegularizedConstant" : 0.0,
      "ActivationFunc" : [ "tanh" ]
    },
    "customPaths" : { }
  }
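For intuition, the "params" above describe a hard-parameter-sharing network: one shared hidden layer of 50 tanh units ("NumHiddenLayers", "NumHiddenNodes", "ActivationFunc") feeding a separate sigmoid output for each task. The code below is not Shifu's actual MTL implementation, just a minimal, self-contained Java sketch of that shared-layer / per-task-head forward pass, with made-up dimensions and random weights:

import java.util.Random;

/** Sketch only: shared tanh hidden layer, one sigmoid output ("head") per task. */
public class MtlForwardSketch {
    public static void main(String[] args) {
        int numInputs = 30, numHidden = 50, numTasks = 3;
        Random rnd = new Random(42);

        // Hidden-layer weights shared by every task.
        double[][] wHidden = new double[numHidden][numInputs];
        // One output weight vector per task (the task-specific head).
        double[][] wOut = new double[numTasks][numHidden];
        for (double[] row : wHidden) {
            for (int i = 0; i < row.length; i++) row[i] = rnd.nextGaussian() * 0.1;
        }
        for (double[] row : wOut) {
            for (int i = 0; i < row.length; i++) row[i] = rnd.nextGaussian() * 0.1;
        }

        // One normalized input record.
        double[] x = new double[numInputs];
        for (int i = 0; i < numInputs; i++) x[i] = rnd.nextDouble();

        // Shared representation: tanh hidden layer.
        double[] hidden = new double[numHidden];
        for (int h = 0; h < numHidden; h++) {
            double sum = 0.0;
            for (int i = 0; i < numInputs; i++) sum += wHidden[h][i] * x[i];
            hidden[h] = Math.tanh(sum);
        }

        // Per-task sigmoid score computed from the same shared representation.
        for (int t = 0; t < numTasks; t++) {
            double sum = 0.0;
            for (int h = 0; h < numHidden; h++) sum += wOut[t][h] * hidden[h];
            double score = 1.0 / (1.0 + Math.exp(-sum));
            System.out.printf("task %d score = %.4f%n", t, score);
        }
    }
}

In a setup like this, each training record contributes one loss term per task, and the gradients from all task heads update the shared hidden layer together, which is where the cross-task benefit comes from.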

How to do unit tests of Multi-task Learning?

For unit tests you should use normalized data and a ModelConfig.json set up accordingly, and prepare one ColumnConfig for each task. All per-task ColumnConfigs sit under a directory named "mtl" at the same level as "ModelConfig.json"; see the structure of data and configs for unit tests of MTL for this layout (a rough sketch follows below). Please don't forget to modify the "loadConfigs" method in MTLMaster and MTLWorker accordingly.
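Based on the resource paths used in the test below and assuming a standard Maven layout, the test resources look roughly like this (the ColumnConfig file names under "mtl" are illustrative; use whatever names your loadConfigs implementation expects):

    src/test/resources/
        model/MultiTaskNN/
            ModelConfig.json
            mtl/
                ColumnConfig.json.0    (task 1)
                ColumnConfig.json.1    (task 2)
                ColumnConfig.json.2    (task 3)
        data/
            part-m-00000-mtl-afterNormalized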

    @Test
    public void testMultiTaskNN() {
        Properties props = new Properties();
        // Wire the MTL master/worker into the in-memory Guagua driver.
        props.setProperty(GuaguaConstants.MASTER_COMPUTABLE_CLASS, MTLMaster.class.getName());
        props.setProperty(GuaguaConstants.WORKER_COMPUTABLE_CLASS, MTLWorker.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_ITERATION_COUNT, "10");
        props.setProperty(GuaguaConstants.GUAGUA_MASTER_RESULT_CLASS, MTLParams.class.getName());
        props.setProperty(GuaguaConstants.GUAGUA_WORKER_RESULT_CLASS, MTLParams.class.getName());
        props.setProperty(CommonConstants.MODELSET_SOURCE_TYPE, "LOCAL");
        // ModelConfig.json with the MTL settings; per-task ColumnConfigs sit in the sibling "mtl" directory.
        String modelConfigJson = getClass().getResource("/model/MultiTaskNN/ModelConfig.json").toString();
        props.setProperty(CommonConstants.SHIFU_MODEL_CONFIG, modelConfigJson);
        // Normalized input data prepared for the MTL unit test.
        props.setProperty(GuaguaConstants.GUAGUA_INPUT_DIR,
                getClass().getResource("/data/part-m-00000-mtl-afterNormalized").toString());
        GuaguaUnitDriver<MTLParams, MTLParams> driver = new GuaguaMRUnitDriver<MTLParams, MTLParams>(props);
        driver.run();
    }