こんにちは!ぼりたそです!
今回はLasso回帰とRidge回帰について分かりやすく解説したいと思います。
この二つの回帰は線形回帰の仲間で主に過学習を防ぐことを目的に使用される手法になります。データ分析や機械学習にもよく使用される回帰手法なので、少しでも理解の助けになればと思います。
この記事は以下のポイントでまとめています。
また、学習データとしてCSVファイルを入力するだけでLasso、Ridge回帰を実行するPythonコードも紹介していますので、ご興味ある方は以下の記事を参照ください。
それでは順に解説していきます。
過学習について
まず、過学習についてお話ししていきます。
過学習とはモデリングや機械学習において学習データに対して過度に適合し、汎化性能が低下する現象です。
つまり、モデルが訓練データには非常によく適合する一方で、未知のデータに対してはうまく予測できなくなることを指します。
下に線形回帰時に過学習した例を示します。
緑線が真の関数に対して、赤線が予測した関数になっていますが、学習点に対して異常なほどフィッティングしてしまっていますね。真の関数と比較すると歪になっており、これだと予測の精度が下がってしまうのもわかるかと思います。
では、この過学習を防ぐにはどうしたら良いのでしょうか。
具体的な対策としてはいくつかありますが、その内の一つとして正則化という手法があります。
この正則化がLassoやRidge回帰に使用されている仕組みであり、通常の線形回帰と異なる点になります。
それでは、正則化について詳しく説明していきます。
正則化とLasso & Ridge回帰
では、正則化がどのようなものか説明していきます。
正則化は過学習を防ぐ手法の一つであり、モデルの複雑さを制御することで汎化性能を向上させます。
具体的に説明すると、まず、通常の線形回帰分析では変数の係数を決める際は最小二乗法を使用しているかと思います。
最小二乗法とは関数が$y = a_1x_1 + a_2x_2 + \ldots + a_nx_n$ ($a_i$は編回帰係数)であるときに以下の損失関数が最小となるように変数の係数を決定する手法です。
■最小二乗法
$$\sum_{i=1}^{N} \left( y_i – (a_1 x_{i1} + a_2 x_{i2} + … + a_n x_{in}) \right)^2$$
しかし、この最小二乗法では過学習に陥ることがあり、その場合、編回帰係数 $a_i$ が極端に大きくなってしまいます。
この過学習を防ぐために正則化項を導入します。
正則化は大きく2種類ありL1正則化、L2正則化項が存在します。
それぞれ以下の式で表されます。
L1正則化項は係数の絶対値の和の項になっており、L2は係数の二乗和の項になっています。
先ほど説明した最小二乗法に正則化項を組み込んだ損失関数を使用して係数決定する手法がLasso回帰やRidge回帰になるのです。
詳しくは次の章でご説明します。
Lasso & Ridge回帰
では、いよいよLasso & Ridge回帰について詳細に解説していきます。
先ほど、最小二乗法に正則化項を組み込んだ損失関数を使用して係数決定する手法がLasso & Ridge回帰と説明しましたが、Lasso回帰、Ridge回帰の損失関数とその特徴を以下に示します。
Lasso & Ridge回帰は上記の損失関数を最小化するように回帰モデルの係数を決定してくれます。なので、係数の和である正則化項を組み込むことで係数を小さく抑え、過学習を防ぐ役割を果たします。
つまり、正則化項は係数の大きさに対してペナルティを課しているということになりますね。また、正則化項のハイパーパラメータである $\lambda$ はペナルティの大きさを制御しており、 $\lambda$ が大きいほどペナルティが大きくなる(係数が小さくなる)ということになります。
通常はこの $\lambda$ はクロスバリデーションなどで最適化するのが一般的と言えます。
LassoとRidge回帰の使い分けについてケースバイケースなところもありますが、まずはRidge回帰で重要そうなパラメータを確認した後に、よりモデルの解釈性を上げたい(特徴量を絞りたい)場合はLassoを使用するのがちょうどいいと思っています。
Lasso & Ridge回帰をPythonで比較
それでは実際に通常の重回帰分析とLasso、Ridge回帰でどの程度回帰性能が異なるのかをPythonで実行することで比較してみましょう。
今回は y=sin(x) の関数に従って生成したデータを過学習しやすいように25次元の多項式変換した関数で通常の重回帰、Lasso回帰、Ridge回帰を行った結果を出力するようにコードにしました。
今回のコードだと、 $\hat{y} = w_0 + w_1 X + w_2 X^2 + \ldots + w_d X^{25}$ の式で示される25個の回帰係数を決めてy=sin(x)にフィットするように学習してくださいというお題になります。
わかりやすいように予測した特徴量の係数とテストデータに対するMSE(平均二乗誤差)を出力しました。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression, Lasso, Ridge
from sklearn.metrics import mean_squared_error
# データ生成
np.random.seed(0)
X_train = np.linspace(-3, 3, 50) # 10個のトレーニングデータ点を生成
y_train = np.sin(X_train) + np.random.normal(0, 0.2, size=X_train.shape) # 正弦関数にノイズを追加
X_test = np.linspace(-3, 3, 100) # より多くのテストデータ点を生成
y_test = np.sin(X_test) # ノイズのない正弦関数を生成
# 多項式特徴量の追加
poly_degree = 25 # 高次の多項式を使用して過学習を起こす
poly_features = PolynomialFeatures(degree=poly_degree, include_bias=False)
X_train_poly = poly_features.fit_transform(X_train[:, np.newaxis])
X_test_poly = poly_features.transform(X_test[:, np.newaxis])
# モデルの学習
model_normal = LinearRegression()
model_normal.fit(X_train_poly, y_train)
model_lasso = Lasso(alpha=0.01) # alphaは正則化パラメータ
model_lasso.fit(X_train_poly, y_train)
model_ridge = Ridge(alpha=0.1) # alphaは正則化パラメータ
model_ridge.fit(X_train_poly, y_train)
# トレーニングデータとテストデータでの予測
y_train_pred_normal = model_normal.predict(X_train_poly)
y_test_pred_normal = model_normal.predict(X_test_poly)
y_train_pred_lasso = model_lasso.predict(X_train_poly)
y_test_pred_lasso = model_lasso.predict(X_test_poly)
y_train_pred_ridge = model_ridge.predict(X_train_poly)
y_test_pred_ridge = model_ridge.predict(X_test_poly)
# MSEの計算
mse_normal = mean_squared_error(y_test, y_test_pred_normal)
mse_lasso = mean_squared_error(y_test, y_test_pred_lasso)
mse_ridge = mean_squared_error(y_test, y_test_pred_ridge)
# 回帰係数の取得
df_normal_coefs = pd.DataFrame({"Normal Coefficients": model_normal.coef_})
df_lasso_coefs = pd.DataFrame({"Lasso Coefficients": model_lasso.coef_})
df_ridge_coefs = pd.DataFrame({"Ridge Coefficients": model_ridge.coef_})
# 係数データフレームの結合
df_coef = pd.concat([df_normal_coefs, df_lasso_coefs, df_ridge_coefs], axis=1)
# 結果のプロット
plt.figure(figsize=(12, 8))
# 正解のプロット
plt.plot(X_test, y_test, label='True function', color='green')
# 予測のプロット
plt.plot(X_test, y_test_pred_normal, label='Normal model prediction', color='blue', linestyle='--')
plt.plot(X_test, y_test_pred_lasso, label='Lasso model prediction', color='red', linestyle='--')
plt.plot(X_test, y_test_pred_ridge, label='Ridge model prediction', color='orange', linestyle='--')
# トレーニングデータのプロット
plt.scatter(X_train, y_train, label='Training data', color='black')
plt.xlabel('X')
plt.ylabel('y')
plt.title('Overfitting Example with Polynomial Regression')
plt.legend()
#plt.show()
plt.savefig('graph.png')
# 回帰係数の出力
print("Regression Coefficients:\n", df_coef)
# MSEの出力
print("\nNormal Model MSE:", mse_normal)
print("Lasso Model MSE:", mse_lasso)
print("Ridge Model MSE:", mse_ridge)
'''
Regression Coefficients:
Normal Coefficients Lasso Coefficients Ridge Coefficients
0 -0.197344 8.549774e-01 1.136630e+00
1 0.574083 0.000000e+00 -1.156902e-01
2 15.426225 -8.480762e-02 -2.529196e-01
3 -3.238403 -1.037514e-02 -1.159110e-01
4 -51.063047 -6.826189e-03 -2.615471e-01
5 6.035358 1.367322e-03 2.825175e-02
6 76.264491 2.464039e-04 7.167233e-02
7 -6.164472 2.370437e-04 1.354192e-01
8 -64.458579 2.572652e-05 1.198943e-01
9 4.071624 -2.993244e-06 4.272102e-02
10 34.206110 1.438850e-06 1.541273e-02
11 -1.810542 -1.849534e-06 -1.447259e-01
12 -12.032257 1.239650e-07 -9.247012e-02
13 0.542904 -2.224174e-07 8.427362e-02
14 2.878871 1.141321e-08 5.305983e-02
15 -0.108857 -1.773508e-08 -2.407649e-02
16 -0.470794 7.889044e-10 -1.463769e-02
17 0.014311 -8.049821e-10 3.924339e-03
18 0.051784 1.335531e-11 2.295029e-03
19 -0.001181 4.546795e-11 -3.723504e-04
20 -0.003662 -7.319172e-12 -2.092187e-04
21 0.000055 1.859530e-11 1.919651e-05
22 0.000150 -1.677709e-12 1.035714e-05
23 -0.000001 3.321028e-12 -4.165391e-07
24 -0.000003 -2.588140e-13 -2.157051e-07
Normal Model MSE: 0.03492194119927194
Lasso Model MSE: 0.009979363887785631
Ridge Model MSE: 0.025455118945035337
'''
出力されたグラフは以下の通りとなっております。緑線が真の関数の曲線になっており、青線が通常の重回帰、黄線がRidge回帰、赤線がLasso回帰の結果となっております。
重回帰とRidge回帰は若干過学習気味ですね。一方でLasso回帰はきちんと曲線を描いており、最も真の関数に近そうですね。
実際にMSEを見ると「Lasso>Ridge>重回帰」の順で予測性能がよく、過学習はRidgeよりLasso回帰の方が防止する効果があるようですね。また、係数に注目しても通常の重回帰と比較して特にLasso回帰は係数が小さく抑えられているのがわかると思います。
終わりに
以上がLasso & Ridge回帰についての説明になります。過学習はデータ予測をする上で強敵になりますので、少しでもリスクヘッジをしたいところですよね。モデルの解釈性を上げる意味でもLasso回帰などは結構有用性があるかと思いますので、ぜひ覚えていきたいですね。