RのWeb制作

Webサービス制作のための技術情報を。データ分析(Python、機械学習コンペ他)や自作野球ゲームMeisyoのこと中心。

Python データサイエンス

[Python]グリッドサーチを軽量化し、チューニングしたパラメータも反映する機構を作る

投稿日:

パラメータチューニング方法であるグリッドサーチ、
確かに自動で実行してくれて、すごく便利なのですが問題点があります。

めっちゃ時間がかかる
もし、下記のパラメータ設定のモノを全てグリッドサーチしようとすれば、ゲーミングPCでも余裕で24時間を超えます
・・・特に時間制限のあるColaboratoryでは安易に使えない。

ただ、パラメータ1つ1つグリッドサーチすると
それはそれで「そのパラメータでしかモデルを捉えられない」という問題をはらみます。

そこで、今回は下記の方法の実装を行いました。

グリッドサーチしてパラメータ1を「A」に決める

パラメータ1は「A」として、
グリッドサーチしてパラメータ2を「B」に決める

パラメータ1は「A」、パラメータ2は「B」として、
グリッドサーチしてパラメータ3を「C」に決める

なぜこんなことをしたいのかというと、
適当にパラメータを投入すれば勝手に上手くやってくれるし、時間もかからない機能が欲しかっただけです。

モジュールのインポート

モジュールを必要最低限呼び出します。
scalerはお好きなものをご使用ください。

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Normalizer
scaler = MinMaxScaler()
#scaler = StandardScaler()
#scaler = RobustScaler()
#scaler = Normalizer(copy=True, norm='l2')

学習器のインポート

今回は回帰の問題だったのでSVRや回帰型XGBを使用しましたが、分類も対応可能です。

from sklearn.svm import SVR # サポートベクター回帰
import xgboost as xgb # XGB

データのインポート

pandasで読み込まれており、前処理は終わっているX、yを仮定します。
ただし、正規化(または標準化等)はまだ実施していないものとします。

X = ...
y = ...

パラメータ設定

設定できる項目はまだまだありますが、一つの例として記載します。

cv_list = {
    'XGB':
    {
        'discriminator': xgb.XGBRegressor(objective='reg:linear'),
        'random_state': [0],
        'booster': ['gbtree', 'gblinear'],
        'n_estimators': [1, 3, 5, 10, 20, 30, 50, 100, 200],
        'max_depth': [3, 4, 5, 6, 7, 8, 9, 10],
        'subsample': 1.1 - 10 ** np.linspace(-1, -0.8, 20), # ~1
        'learning_rate': 10 ** np.linspace(-1.2, -0.9, 20),
        'gamma': np.linspace(0, 1.0, 20), # 0~
        'reg_lambda': 1.1 - 10 ** np.linspace(-1, -0.3, 20), # ~1
        'reg_alpha': -0.1 + 10 ** np.linspace(-1, -0.3, 20) # 0~
    },
    'SVR':
    {
        'discriminator': SVR(),
        'kernel': ['poly', 'rbf', 'sigmoid', 'linear'],
        'degree': np.arange(1, 10, 1),
        'C': 10 ** np.linspace(-5, 3, 5),
        'gamma': 10 ** np.linspace(-5, 0, 5),
        'epsilon': 10 ** np.linspace(-5, 0, 5)
    }
}

あるパラメータが必要だと思えば追加すれば良いと思います。

グリッドサーチ実行

複数個のモデルのグリッドサーチに対応しています。
パイプラインを使ってデータ分割→scaler実施→model.fit→判定を自動で行います。
パイプラインを使用すると、param_gridをモデル名__パラメータ名にしなければならないことにも対応しています。
最後に一番良いパラメータを表示します。

for model in cv_list:
    # model作成
    print("#---------------------------------------------------------#")
    print("#", model, "loaded")
    print("#---------------------------------------------------------#")
    model_test = cv_list[model]['discriminator']
    model_name = "model"
    pipe_model = Pipeline([("scaler", scaler), (model_name, model_test)])
    model_param = cv_list[model].copy()
    del model_param['discriminator']
    
    # paramごとにCV
    best_param = {}
    for param in model_param:
        # make
        param_one = model_name + "__" + param
        param_grid = {}
        param_grid[param_one] = model_param[param]
        param_grid.update(best_param) # 以前のチューニング結果を反映
        print(">> Tuning '%s' is..." % (param))
        model_cv = GridSearchCV(pipe_model, param_grid=param_grid, cv=3,
                                return_train_score=False, n_jobs=-1, verbose=0)
        model_cv.fit(X, y)
        # Best
        best_param[param_one] = [getattr(model_cv.best_estimator_.steps[1][1], param)]
        print('Best Params:', best_param[param_one])
    
    # last
    print(">> All Best Params is...")
    print(best_param)

GridSearchCVに自作評価関数を入れることも可能です。

scoring=make_scorer(func_scoring, greater_is_better=True)

最後に

10分もかからずにグリッドサーチが終わると思います。
すべてのパラメータを総当りで行うのはコスト(時間)がかかりすぎるのでおすすめはしません。

お役に立てたのであれば嬉しい限りです。

-Python, データサイエンス

執筆者:


comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

関連記事

【教材紹介】データ解析の実務プロセス入門

「データ分析を会社で初めて行いたい」「データ分析を任されたがどうすればいいかよく分からない」というときはこちらの書籍がおすすめ。良いデータ分析を構成する分析プロセスからデータの収集方法、探索的データ解 …

(VPSでつくる) Python(Flask)でMariaDB(MySQL)へ接続できるアプリをもっと読みやすく改良してみよう

連載第十二回目です。 前回の記事で、Python3.6.8+FlaskでMariaDBに接続・データベースを編集するアプリを動作させる設定を行い、動作確認しました。 今回は、機能は前回と全く同じアプリ …

手書き数字診断士(機械学習)ver 0.0

手書き数字診断士、まずは動くようにしました。 ただ、初っ端から間違えています・・・! 動画 http://webmaking.rei-farms.jp/wp-content/uploads/2018/ …

[Python] tensorflow_datasetsで詰まったとき

「図解速習 DEEPLEARNING」で自己環境(Windows)で学習していました。 tensorflow_datasetsって何だ・・・? import tensorflow_datasets a …

[Meisyo] 打撃・守備のバランス調整(v0.40)

変更概要 守備力を上方修正します。 詳細に言うと、OPSに対する影響度を、守備力=ミートまたは反応の有利な能力値にしました。 これまではOPSに対する影響は、守備力<ミートまたは反応の有利な能力値(2 …