菜单 搜索

机器学习实战:线性回归模型预测电视,报纸和广播哪个广告效果好

媒体资源网 http://www.allchina.cn 2023/6/16

序言

我是实用主义者,本专栏宗旨亦如此。很多教材、教程、视频把理论讲的已经足够多了,相信大家也都看过不少。但是看过不等于学会,更不等于会用。 本专栏现在开始模型篇,我的重点是实战,大家重点看我的实战代码部分即可。我相信,大家遇到类似案例或场景时,结合我的代码可以立即上手。

专栏限时优惠,欢迎订阅

简单线性回归

做为机器学习入门的经典模型,线性回归是绝对值得大家深入的推导实践的,而在众多的模型中,也是相对的容易。线性回归模型主要是用于线性建模,假设样本数为m个,特征为n个。线性回归的任务就是建立这样一个模型: h(x)=w1x1 w2x2 ... wn*xn b,其中wi是权重,xi是特征。

模型损失函数,它表示的是样本值y和预测值h(x)的距离。

我们的任务是找到参数使J(θ)最小,在机器学习中使用的是学习方法,通过优化的方式得到最优解,下面我们是用梯度下降来进行模型的求解。

得到了梯度之后,我们可以通过梯度下降的方式不断更新参数得到最优解。

梯度下降寻找损失函数最小值的过程,也是拟合效果越来越好的过程。

多元线性回归模型

简单线性回归模型是多元模型的特例,下图的多元模型原理及上面的截图都是引用Jim Liang的《Getting Started with Machine Learning》,特此说明。

下面我们的线性回归实战案例就是多元模型,这里就不多说了。

模型的评价

训练好了模型,我们需要衡量对模型泛化性能进行评估,回归模型常用的有以下几种:

1、均方误差(Mean Squared Error, MSE),MSE的值越小越好。

2、决定系数(coefficient of determination)(R^2),R方的值越大,拟合的效果越好,最优值是1。

3、校正R平方(Adjusted R-squared),多变量的情况时使用,值越大,拟合的效果越好。

Sklearn API详解

实用主义者直接调用sklearn,调用方法和参数详解。

clf.coef_:获取训练会的线性函数X参数的权值clf.intercept_:训练后模型截距clf.predict:根据输出值进行预测clf._decision_function:根据输入进行预测的第二种方法clf.score:对预测结果进行评估打分clf.get_params:获取本次训练模型的参数值clf.set_params:修改模型的参数值线性回归实战案例

我们使用《An Introduction to Statistical Learning》一书中的数据集,即电视广告(TV),报纸广告(Newspaper),和广播广告(Radio)对于产品销量的影响。具体来说,一个公司同时通过这三种广告媒介进行宣传,在不同的广告预算下,产品销量也不同。我们希望通过数据分析了解不同的广告渠道对销量有什么影响,并最大化广告对于销量的增益。

导入库

import numpy as npimport pandas as pd import seaborn as snsimport matplotlib.pyplot as plt from sklearn.linear_model import LinearRegressionfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import r2_score

导入数据

df=pd.read_csv('..Advertising.csv')

X, y = df.iloc[:, :-1], df.iloc[:, -1]

切分训练集和测试集

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

看一下数据集整体情况

df.describe()

特征和数据量都不大,所以我们可以通过散点图直接观察一下

sns.set(style="darkgrid")g_TV = sns.jointplot("TV", "sales", data=df, kind="reg", truncate=False, color="m", height=7)g_radio = sns.jointplot("radio", "sales", data=df, kind="reg", truncate=False, color="m", height=7)g_newspaper = sns.jointplot("newspaper", "sales", data=df, kind="reg", truncate=False, color="m", height=7)

其实更简单点,一行代码即可

g6 = sns.pairplot(df, kind="reg")

{!-- PGC_COLUMN --}

数据可视化大家可以订阅我的另一个专栏《Python数据可视化》

观察上图,直接上多元线性回归模型,训练模型:

lr = LinearRegression()model = lr.fit(X_train, y_train)

看一下模型参数

print('sita参数 =', lr.coef_)print('截距项 = ', lr.intercept_)

sita参数 = [ 0.04630652 0.1825716 -0.00352643]

截距项 = 3.170281182647294

模型评估:

y_hat = lr.predict(X_test)mse = np.average((y_hat - np.array(y_test)) ** 2) rmse = np.sqrt(mse) est_error = r2_score(y_test, y_hat)

MSE = 2.286656053378682

RMSE = 1.5121693203403785

r2_score = 0.9123113461020417

拟和效果不错,再来可视化看一下:

t = np.arange(len(X_test))plt.plot(t, y_test, 'r-', linewidth=2, label='real data')plt.plot(t, y_hat, 'g-', linewidth=2, label='predict data')plt.legend(loc='upper right')

代码、数据、高清导图,请私信:线性回归