diff --git "a/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" "b/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" index 4b33232..a983a1a 100644 --- "a/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" +++ "b/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" @@ -51,6 +51,27 @@ $$W_{i}^{(t)} = |y_{i} - X_{i}\beta^{(t)}|^{p-2}$$   有一点需要说明的是,这段代码中标签和权重的更新并没有参照上面的原理或者说我理解有误。 +### 2.2 训练新的模型 + +```scala + // 使用更新过的样本训练新的模型 + model = new WeightedLeastSquares(fitIntercept, regParam, elasticNetParam = 0.0, + standardizeFeatures = false, standardizeLabel = false).fit(newInstances) + + // 检查是否收敛 + val oldCoefficients = oldModel.coefficients + val coefficients = model.coefficients + BLAS.axpy(-1.0, coefficients, oldCoefficients) + val maxTolOfCoefficients = oldCoefficients.toArray.reduce { (x, y) => + math.max(math.abs(x), math.abs(y)) + } + val maxTol = math.max(maxTolOfCoefficients, math.abs(oldModel.intercept - model.intercept)) + if (maxTol < tol) { + converged = true + } +``` +  训练完新的模型后,重复2.1步,直到参数收敛或者到达迭代的最大次数。 + ## 3 参考文献 【1】[Iteratively reweighted least squares](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares)