RのWeb制作

Webサービス制作のための技術情報を。データ分析(Python、機械学習コンペ他)や自作野球ゲームMeisyoのこと中心。

Python データサイエンス

[Python]グリッドサーチを軽量化し、チューニングしたパラメータも反映する機構を作る

投稿日:

パラメータチューニング方法であるグリッドサーチ、
確かに自動で実行してくれて、すごく便利なのですが問題点があります。

めっちゃ時間がかかる
もし、下記のパラメータ設定のモノを全てグリッドサーチしようとすれば、ゲーミングPCでも余裕で24時間を超えます
・・・特に時間制限のあるColaboratoryでは安易に使えない。

ただ、パラメータ1つ1つグリッドサーチすると
それはそれで「そのパラメータでしかモデルを捉えられない」という問題をはらみます。

そこで、今回は下記の方法の実装を行いました。

グリッドサーチしてパラメータ1を「A」に決める

パラメータ1は「A」として、
グリッドサーチしてパラメータ2を「B」に決める

パラメータ1は「A」、パラメータ2は「B」として、
グリッドサーチしてパラメータ3を「C」に決める

なぜこんなことをしたいのかというと、
適当にパラメータを投入すれば勝手に上手くやってくれるし、時間もかからない機能が欲しかっただけです。

モジュールのインポート

モジュールを必要最低限呼び出します。
scalerはお好きなものをご使用ください。

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Normalizer
scaler = MinMaxScaler()
#scaler = StandardScaler()
#scaler = RobustScaler()
#scaler = Normalizer(copy=True, norm='l2')

学習器のインポート

今回は回帰の問題だったのでSVRや回帰型XGBを使用しましたが、分類も対応可能です。

from sklearn.svm import SVR # サポートベクター回帰
import xgboost as xgb # XGB

データのインポート

pandasで読み込まれており、前処理は終わっているX、yを仮定します。
ただし、正規化(または標準化等)はまだ実施していないものとします。

X = ...
y = ...

パラメータ設定

設定できる項目はまだまだありますが、一つの例として記載します。

cv_list = {
    'XGB':
    {
        'discriminator': xgb.XGBRegressor(objective='reg:linear'),
        'random_state': [0],
        'booster': ['gbtree', 'gblinear'],
        'n_estimators': [1, 3, 5, 10, 20, 30, 50, 100, 200],
        'max_depth': [3, 4, 5, 6, 7, 8, 9, 10],
        'subsample': 1.1 - 10 ** np.linspace(-1, -0.8, 20), # ~1
        'learning_rate': 10 ** np.linspace(-1.2, -0.9, 20),
        'gamma': np.linspace(0, 1.0, 20), # 0~
        'reg_lambda': 1.1 - 10 ** np.linspace(-1, -0.3, 20), # ~1
        'reg_alpha': -0.1 + 10 ** np.linspace(-1, -0.3, 20) # 0~
    },
    'SVR':
    {
        'discriminator': SVR(),
        'kernel': ['poly', 'rbf', 'sigmoid', 'linear'],
        'degree': np.arange(1, 10, 1),
        'C': 10 ** np.linspace(-5, 3, 5),
        'gamma': 10 ** np.linspace(-5, 0, 5),
        'epsilon': 10 ** np.linspace(-5, 0, 5)
    }
}

あるパラメータが必要だと思えば追加すれば良いと思います。

グリッドサーチ実行

複数個のモデルのグリッドサーチに対応しています。
パイプラインを使ってデータ分割→scaler実施→model.fit→判定を自動で行います。
パイプラインを使用すると、param_gridをモデル名__パラメータ名にしなければならないことにも対応しています。
最後に一番良いパラメータを表示します。

for model in cv_list:
    # model作成
    print("#---------------------------------------------------------#")
    print("#", model, "loaded")
    print("#---------------------------------------------------------#")
    model_test = cv_list[model]['discriminator']
    model_name = "model"
    pipe_model = Pipeline([("scaler", scaler), (model_name, model_test)])
    model_param = cv_list[model].copy()
    del model_param['discriminator']
    
    # paramごとにCV
    best_param = {}
    for param in model_param:
        # make
        param_one = model_name + "__" + param
        param_grid = {}
        param_grid[param_one] = model_param[param]
        param_grid.update(best_param) # 以前のチューニング結果を反映
        print(">> Tuning '%s' is..." % (param))
        model_cv = GridSearchCV(pipe_model, param_grid=param_grid, cv=3,
                                return_train_score=False, n_jobs=-1, verbose=0)
        model_cv.fit(X, y)
        # Best
        best_param[param_one] = [getattr(model_cv.best_estimator_.steps[1][1], param)]
        print('Best Params:', best_param[param_one])
    
    # last
    print(">> All Best Params is...")
    print(best_param)

GridSearchCVに自作評価関数を入れることも可能です。

scoring=make_scorer(func_scoring, greater_is_better=True)

最後に

10分もかからずにグリッドサーチが終わると思います。
すべてのパラメータを総当りで行うのはコスト(時間)がかかりすぎるのでおすすめはしません。

お役に立てたのであれば嬉しい限りです。

-Python, データサイエンス

執筆者:


comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

関連記事

TensorFlow RNNで詰まるの巻

DeeplearningのフレームワークTensorFlowの学習まで漕ぎ着けました。 CNN(画像認識用と言っても過言ではない)はゼロから始めるディープラーニングでだいたいOK。 何度か読み返してわ …

[Python:Predict Gollira]2枚の画像でどちらがゴリラっぽいかを人間が予想する。

「あーこの人ゴリラっぽい。」と思うことはありませんか? ゴリラっぽさってどこから来るんだろうかと悩んでいました。 前回、[Python] ディープラーニングのモデル「VGG16」を使って画像認識をし、 …

[Anaconda]Anacondaが動かない!TypeError: expected str, bytes or os.PathLike object, not NoneType

Anaconda Navigatorが起動できません。 昨日まで動いていたのに・・・。 エラー文はこちら TypeError: expected str, bytes or os.PathLike o …

ヒストグラムの階級数を決める方法論

データ分析業務ははっきり言って泥臭い。 分析の設計を行い、可視化を行ってから使えるデータかどうか判断できる。 そもそもそれはデータ分析前の話なのだが。 今回は、可視化の中でもデータの傾向を把握するのに …

(VPSでつくる) Pythonのバージョンを2.7.5から3.6.8にする

連載第四回目です。 CentOS7にインストールされているPythonのバージョンが2.7なので、バージョンアップをします。 そうしないとPythonのアプリが動きません。 なぜなら、Python2と …