From 7db41327729d6dd4261f9cf67381ffbb88bf1e8c Mon Sep 17 00:00:00 2001 From: endy Date: Wed, 25 Jan 2017 17:15:51 +0800 Subject: [PATCH] IRLS --- .../glr.md" | 2 +- .../IRLS.md" | 43 ++++++++++++++++++- 2 files changed, 43 insertions(+), 2 deletions(-) diff --git "a/\345\210\206\347\261\273\345\222\214\345\233\236\345\275\222/\347\272\277\346\200\247\346\250\241\345\236\213/\345\271\277\344\271\211\347\272\277\346\200\247\345\233\236\345\275\222/glr.md" "b/\345\210\206\347\261\273\345\222\214\345\233\236\345\275\222/\347\272\277\346\200\247\346\250\241\345\236\213/\345\271\277\344\271\211\347\272\277\346\200\247\345\233\236\345\275\222/glr.md" index 898d1f9..df8edec 100644 --- "a/\345\210\206\347\261\273\345\222\214\345\233\236\345\275\222/\347\272\277\346\200\247\346\250\241\345\236\213/\345\271\277\344\271\211\347\272\277\346\200\247\345\233\236\345\275\222/glr.md" +++ "b/\345\210\206\347\261\273\345\222\214\345\233\236\345\275\222/\347\272\277\346\200\247\346\250\241\345\236\213/\345\271\277\344\271\211\347\272\277\346\200\247\345\233\236\345\275\222/glr.md" @@ -96,7 +96,7 @@ println(s"Intercept: ${model.intercept}") irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) model.setSummary(Some(trainingSummary)) ``` -  迭代再加权最小二乘的分析见最优化章节:[迭代再加权最小二乘](../分类和回归/线性模型/广义线性回归/IRLS.md)。 +  迭代再加权最小二乘的分析见最优化章节:[迭代再加权最小二乘](../../../最优化算法/IRLS.md)。 ### 3.3 链接函数 diff --git "a/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" "b/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" index 3752e92..4b33232 100644 --- "a/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" +++ "b/\346\234\200\344\274\230\345\214\226\347\256\227\346\263\225/IRLS.md" @@ -1,6 +1,6 @@ # 迭代再加权最小二乘 -## 原理 +## 1 原理   迭代再加权最小二乘(`IRLS`)用于解决特定的最优化问题,这个最优化问题的目标函数如下所示: @@ -13,3 +13,44 @@ $$\beta ^{t+1} = argmin_{\beta} \sum_{i=1}^{n} w_{i}(\beta^{(t)}))|y_{i} - f_{i}   在这个公式中,$W^{(t)}$是权重对角矩阵,它的所有元素都初始化为1。每次迭代中,通过下面的公式更新。 $$W_{i}^{(t)} = |y_{i} - X_{i}\beta^{(t)}|^{p-2}$$ + +## 2 源码分析 + +  在`spark ml`中,迭代再加权最小二乘主要解决广义线性回归问题。下面看看实现代码。 + +### 2.1 更新权重 + +```scala + // Update offsets and weights using reweightFunc + val newInstances = instances.map { instance => + val (newOffset, newWeight) = reweightFunc(instance, oldModel) + Instance(newOffset, newWeight, instance.features) + } +``` +  这里使用`reweightFunc`方法更新权重。具体的实现在广义线性回归的实现中。 + +```scala + /** + * The reweight function used to update offsets and weights + * at each iteration of [[IterativelyReweightedLeastSquares]]. + */ + val reweightFunc: (Instance, WeightedLeastSquaresModel) => (Double, Double) = { + (instance: Instance, model: WeightedLeastSquaresModel) => { + val eta = model.predict(instance.features) + val mu = fitted(eta) + val offset = eta + (instance.label - mu) * link.deriv(mu) + val weight = instance.weight / (math.pow(this.link.deriv(mu), 2.0) * family.variance(mu)) + (offset, weight) + } + } + + def fitted(eta: Double): Double = family.project(link.unlink(eta)) +``` +  这里的`model.predict`利用带权最小二乘模型预测样本的取值,然后调用`fitted`方法计算均值函数$\mu$。`offset`表示 +更新后的标签值,`weight`表示更新后的权重。关于链接函数的相关计算可以参考[广义线性回归](../分类和回归/线性模型/广义线性回归/glr.md)的分析。 + +  有一点需要说明的是,这段代码中标签和权重的更新并没有参照上面的原理或者说我理解有误。 + +## 3 参考文献 + +【1】[Iteratively reweighted least squares](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares)