RのWeb制作

Webサービス制作のための技術情報を。データ分析(Python、機械学習コンペ他)や自作野球ゲームMeisyoのこと中心。

Python データサイエンス

[Python]グリッドサーチを軽量化し、チューニングしたパラメータも反映する機構を作る

投稿日:

パラメータチューニング方法であるグリッドサーチ、
確かに自動で実行してくれて、すごく便利なのですが問題点があります。

めっちゃ時間がかかる
もし、下記のパラメータ設定のモノを全てグリッドサーチしようとすれば、ゲーミングPCでも余裕で24時間を超えます
・・・特に時間制限のあるColaboratoryでは安易に使えない。

ただ、パラメータ1つ1つグリッドサーチすると
それはそれで「そのパラメータでしかモデルを捉えられない」という問題をはらみます。

そこで、今回は下記の方法の実装を行いました。

グリッドサーチしてパラメータ1を「A」に決める

パラメータ1は「A」として、
グリッドサーチしてパラメータ2を「B」に決める

パラメータ1は「A」、パラメータ2は「B」として、
グリッドサーチしてパラメータ3を「C」に決める

なぜこんなことをしたいのかというと、
適当にパラメータを投入すれば勝手に上手くやってくれるし、時間もかからない機能が欲しかっただけです。

モジュールのインポート

モジュールを必要最低限呼び出します。
scalerはお好きなものをご使用ください。

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Normalizer
scaler = MinMaxScaler()
#scaler = StandardScaler()
#scaler = RobustScaler()
#scaler = Normalizer(copy=True, norm='l2')

学習器のインポート

今回は回帰の問題だったのでSVRや回帰型XGBを使用しましたが、分類も対応可能です。

from sklearn.svm import SVR # サポートベクター回帰
import xgboost as xgb # XGB

データのインポート

pandasで読み込まれており、前処理は終わっているX、yを仮定します。
ただし、正規化(または標準化等)はまだ実施していないものとします。

X = ...
y = ...

パラメータ設定

設定できる項目はまだまだありますが、一つの例として記載します。

cv_list = {
    'XGB':
    {
        'discriminator': xgb.XGBRegressor(objective='reg:linear'),
        'random_state': [0],
        'booster': ['gbtree', 'gblinear'],
        'n_estimators': [1, 3, 5, 10, 20, 30, 50, 100, 200],
        'max_depth': [3, 4, 5, 6, 7, 8, 9, 10],
        'subsample': 1.1 - 10 ** np.linspace(-1, -0.8, 20), # ~1
        'learning_rate': 10 ** np.linspace(-1.2, -0.9, 20),
        'gamma': np.linspace(0, 1.0, 20), # 0~
        'reg_lambda': 1.1 - 10 ** np.linspace(-1, -0.3, 20), # ~1
        'reg_alpha': -0.1 + 10 ** np.linspace(-1, -0.3, 20) # 0~
    },
    'SVR':
    {
        'discriminator': SVR(),
        'kernel': ['poly', 'rbf', 'sigmoid', 'linear'],
        'degree': np.arange(1, 10, 1),
        'C': 10 ** np.linspace(-5, 3, 5),
        'gamma': 10 ** np.linspace(-5, 0, 5),
        'epsilon': 10 ** np.linspace(-5, 0, 5)
    }
}

あるパラメータが必要だと思えば追加すれば良いと思います。

グリッドサーチ実行

複数個のモデルのグリッドサーチに対応しています。
パイプラインを使ってデータ分割→scaler実施→model.fit→判定を自動で行います。
パイプラインを使用すると、param_gridをモデル名__パラメータ名にしなければならないことにも対応しています。
最後に一番良いパラメータを表示します。

for model in cv_list:
    # model作成
    print("#---------------------------------------------------------#")
    print("#", model, "loaded")
    print("#---------------------------------------------------------#")
    model_test = cv_list[model]['discriminator']
    model_name = "model"
    pipe_model = Pipeline([("scaler", scaler), (model_name, model_test)])
    model_param = cv_list[model].copy()
    del model_param['discriminator']
    
    # paramごとにCV
    best_param = {}
    for param in model_param:
        # make
        param_one = model_name + "__" + param
        param_grid = {}
        param_grid[param_one] = model_param[param]
        param_grid.update(best_param) # 以前のチューニング結果を反映
        print(">> Tuning '%s' is..." % (param))
        model_cv = GridSearchCV(pipe_model, param_grid=param_grid, cv=3,
                                return_train_score=False, n_jobs=-1, verbose=0)
        model_cv.fit(X, y)
        # Best
        best_param[param_one] = [getattr(model_cv.best_estimator_.steps[1][1], param)]
        print('Best Params:', best_param[param_one])
    
    # last
    print(">> All Best Params is...")
    print(best_param)

GridSearchCVに自作評価関数を入れることも可能です。

scoring=make_scorer(func_scoring, greater_is_better=True)

最後に

10分もかからずにグリッドサーチが終わると思います。
すべてのパラメータを総当りで行うのはコスト(時間)がかかりすぎるのでおすすめはしません。

お役に立てたのであれば嬉しい限りです。

-Python, データサイエンス

執筆者:


comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

関連記事

決定木分析(Python CHAID)を解釈する

意思決定のために使用される決定木分析 Scikit-learnでの決定木にはCART(指標:giniまたはentropy)他が採用されています。 CARTは下記の2点を含め、さまざまな理由から使われて …

自然言語処理×教師なし学習での温故知新 PythonでBERT-MaskedLM実装

はじめに 自然言語処理(BERT、GPT-3)および画像認識(ViT)等で以前のState of The Artモデルを超える精度を発揮したTransformer(元論文:Attention Is A …

野球ゲームデータで遊ぶデータサイエンス(正規分布の検定編)

名将と呼ばれた者達のデータを使って、データサイエンスを学んでみましょう! 生きた&整えられたデータは中々公開されていないので、今回の野球ゲームのデータは分析に適していると思われます。もちろん、Kagg …

ログがサービス改善の命

Meisyoでは常にログを取って、「ユーザがどこで困ってそうかな」を探し続けています。 探す方法はいたって簡単。 (何か問題があると考えて)ログを眺める 今回のアップデートでは、アイテムの購入数を選択 …

【教材紹介】Python機械学習プログラミング(第3版)* 文量多め

今回の書籍は内容理解する難易度が高めですが、機械学習の基礎(用語・位置付け・アルゴリズム)が網羅できる、Pythonでの機械学習を学ぶためのおすすめ教材を紹介します。 正式名(ISBNコード) [第3 …