-
Notifications
You must be signed in to change notification settings - Fork 0
/
线性回归
70 lines (58 loc) · 2.46 KB
/
线性回归
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
from numpy import *
from matplotlib import pyplot as plt
# 加载数据集,
def loadDataSet(fileName): #general function to parse tab -delimited floats
numFeat = len(open(fileName).readline().split('\t')) - 1 #get number of fields
xArr = [] # x数据集
yArr = [] # y数据集
fr = open(fileName)
for line in fr.readlines():
lineArr =[]
curLine = line.strip().split('\t')
for i in range(numFeat):
lineArr.append(float(curLine[i]))
xArr.append(lineArr)
# 最后一列是y的值
yArr.append(float(curLine[-1]))
return xArr, yArr
# 计算回归系数w
def standRegres(xArr,yArr):
'''
计算回归系数
:param xArr: x数据集
:param yArr: y数据集
:return: 回归系数
'''
xMat = mat(xArr)
yMat = mat(yArr).T # 由于yArr是一个列表, 而yMat需要的是一个列向量, 所以需要转置
xTx = xMat.T*xMat
# 前提条件, xTx不可逆
if linalg.det(xTx) == 0.0:
print("This matrix is singular, cannot do inverse")
return
ws = xTx.I * (xMat.T*yMat)
return ws
def plotRegression(xArr, yArr, ws):
"""
函数说明:绘制回归曲线和数据点
"""
xMat = np.mat(xArr) #创建xMat矩阵
yMat = np.mat(yArr) #创建yMat矩阵
xCopy = xMat.copy() #深拷贝xMat矩阵
xCopy.sort(0) #排序 如果直线的数据点次序混乱,绘图的时候会出现问题。所以先将点按照升序排列
yHat = xCopy * ws #计算对应的y值
fig = plt.figure()
ax = fig.add_subplot(111) #添加subplot
ax.plot(xCopy[:, 1], yHat, c = 'red') #绘制回归曲线
ax.scatter(xMat[:,1].flatten().A[0], yMat.flatten().A[0], s = 20, c = 'blue',alpha = .5) #绘制样本点
plt.title('DataSet') #绘制title
plt.xlabel('X')
plt.show()
if __name__ == '__main__':
# 加载数据集
xArr, yArr = loadDataSet('E:\\Python\\data.txt')
# 计算回归系数
ws = standRegres(xArr, yArr)
# 绘制回归曲线
plotRegression(xArr, yArr, ws)