08 Calculating the equation of a regression line
08 Calculating the equation of a regression line
%%html
<!-- Embedded YouTube video (id FGesqq22TCM) rendered inline by the %%html cell magic. -->
<iframe width="700" height="400" src="https://www.youtube.com/embed/FGesqq22TCM/" frameborder="0" allowfullscreen></iframe>
import numpy as np
import pandas as pd
from pandas import Series, DataFrame
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.linear_model import LinearRegression
Reference: https://en.wikipedia.org/wiki/Simple_linear_regression
\[ \hat{y} = mx+b \]
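\[ m = r \frac{S_{y}}{S_{x}} \]
where \( r \) is Pearson's correlation coefficient between \( x \) and \( y \), and \( S_{x} \), \( S_{y} \) are the sample standard deviations of \( x \) and \( y \).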
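The regression line always passes through the point of means \( (\bar{x}, \bar{y}) \), so \( \bar{y} = m\bar{x} + b \), which gives the intercept
\[ b = \bar{y} - m\bar{x} \]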
# Toy sample of four paired (x, y) observations used throughout the notebook.
x, y = np.array([1, 2, 2, 3]), np.array([1, 2, 3, 6])
def reg_line(x, y):
    """Return the slope and intercept of the least-squares regression line.

    Uses the textbook formulas ``m = r * S_y / S_x`` (``r`` is Pearson's
    correlation, ``S`` the sample standard deviation with ``ddof=1``) and
    ``b = y_bar - m * x_bar``, then prints the fitted equation.

    Parameters
    ----------
    x, y : array_like
        Paired observations of equal length (at least two points). Plain
        Python lists are accepted as well as NumPy arrays.

    Returns
    -------
    tuple
        ``(m, b)`` such that ``y_hat = m * x + b``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length, contain fewer than two points,
        or all ``x`` values are identical (the slope is then undefined).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if len(x) != len(y):
        raise ValueError('x and y must have the same length')
    if len(x) < 2:
        raise ValueError('at least two points are required')
    if np.all(x == x[0]):
        # No spread in x: S_x = 0, and a vertical line cannot be written as
        # y = mx + b. Fail loudly instead of returning nan.
        raise ValueError('x values are all identical; slope is undefined')

    x_bar = x.mean()
    y_bar = y.mean()

    if np.all(y == y[0]):
        # Constant y: Pearson's r is undefined (0/0), but the least-squares
        # fit is simply the horizontal line y = y_bar.
        m = 0.0
    else:
        s_x = x.std(ddof=1)
        s_y = y.std(ddof=1)
        r, _ = stats.pearsonr(x, y)
        m = r * s_y / s_x

    # The regression line always passes through (x_bar, y_bar), so
    # y_bar = m * x_bar + b, i.e. b = y_bar - m * x_bar.
    b = y_bar - m * x_bar
    print(f'yhat = {m}x + {b}')
    return m, b
# Fit the toy sample with the hand-rolled formula (also prints the equation).
m, b = reg_line(x=x, y=y)
yhat = 2.5x + -2.0
# Cross-check the hand-rolled reg_line() result (m, b) against two libraries.

# scipy.stats.linregress works directly on the 1-D arrays.
slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
line = intercept + slope * x

# scikit-learn expects a 2-D feature matrix of shape (n_samples, n_features),
# so reshape the 1-D arrays into single-column matrices. These are the same
# training data, not a held-out test set; the "_test" names are kept because
# later plotting cells refer to them.
x_test, y_test = x.reshape(-1, 1), y.reshape(-1, 1)
reg = LinearRegression()
reg.fit(x_test, y_test)
y_pred = reg.predict(x_test)

print('slope', reg.coef_, slope, m)
print('intercept', reg.intercept_, intercept, b)
# Coefficient of determination: sklearn's score() vs. the square of scipy's r.
print('R squared', reg.score(x_test, y_test), r_value ** 2)
print('line', y_pred, line)
slope [[2.5]] 2.5 2.5
intercept [-2.] -2.0 -2.0
line [[0.5]
[3. ]
[3. ]
[5.5]] [0.5 3. 3. 5.5]
# Raw data with the line fitted by scipy.stats.linregress.
y_fit_scipy = intercept + slope * x
plt.scatter(x=x, y=y, label='original data')
plt.plot(x, y_fit_scipy, color='r', label='fitted line')
plt.legend()
plt.show()
# Raw data with the line from the hand-computed slope m and intercept b.
y_fit_manual = b + m * x
plt.scatter(x=x, y=y, label='original data')
plt.plot(x, y_fit_manual, color='r', label='fitted line')
plt.legend()
plt.show()
# Raw data with scikit-learn's predictions; x_test, y_test and y_pred are the
# single-column arrays produced by reshape(-1, 1) and reg.predict().
plt.scatter(x=x_test, y=y_test, label='original data')
plt.plot(x_test, y_pred, color='r', label='fitted line')
plt.legend()
plt.show()
# Pass x/y as keywords: positional data vectors trigger a FutureWarning in
# seaborn 0.11 and, per that warning, error or are misinterpreted from 0.12.
sns.scatterplot(x=x, y=y, label='original data')
sns.lineplot(x=x, y=b + m * x, color='r', label='fitted line')
/opt/hostedtoolcache/Python/3.9.13/x64/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
warnings.warn(
/opt/hostedtoolcache/Python/3.9.13/x64/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
warnings.warn(
<AxesSubplot:>
# Flatten the (n, 1) arrays built for scikit-learn back to 1-D, and pass x/y
# as keywords: positional data vectors trigger a FutureWarning in seaborn 0.11
# and, per that warning, error or are misinterpreted from 0.12.
sns.scatterplot(x=x_test.ravel(), y=y_test.ravel(), label='original data')
sns.lineplot(x=x_test.ravel(), y=y_pred.ravel(), color='r', label='fitted line')
/opt/hostedtoolcache/Python/3.9.13/x64/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
warnings.warn(
/opt/hostedtoolcache/Python/3.9.13/x64/lib/python3.9/site-packages/seaborn/_decorators.py:36: FutureWarning: Pass the following variables as keyword args: x, y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.
warnings.warn(
<AxesSubplot:>