こんにちは、DXCEL WAVEの運営者(@dxcelwave)です!
- Pythonで作成した機械学習の学習済みモデルをローカルPC内に保存したい。
- 学習済みモデルを自由にロードし、再利用できるようにしたい。
【Python】Pickleとは
Pickleとは、Pythonオブジェクトのシリアライズ/デシリアライズを通じ、オブジェクトの保存や復元を行うためのライブラリです。
シリアライズ(直列化)とは、プログラミング言語により作成されたオブジェクトをバイト列などに変換すること、変更後の状態を維持することを指します。
PickleをもとにPythonオブジェクトをシリアライズし、ファイル形式でお手元のPC内に保存します。これにより、別の環境からもPythonオブジェクトが自由に読み込み(ロード)できるようになります。
Pythonオブジェクトをロードする際は、バイト列を元々のオブジェクトに復元する処理を実行します。これをデシリアライズ(非直列化)と言います。
本記事では、Pickleを用いて以下を実現する方法について解説します。
- 機械学習モデルをシリアライズし、お手元のPC内にファイルを保存する
- 機械学習モデルをでデシリアライズし、Python環境に機械学習モデルをロードする
【Python実践】Pickleで機械学習モデルを保存&ロード
機械学習モデルを作成し、学習済みモデルを保存、さらに学習済みモデルをロードし推論に至る実践的な流れを以下に示します。
- データセットの準備
- データの読込
- モデル学習
- 学習済みモデルの保存(Pickle)
- 学習済みモデルのロード(Pickle)
- 学習済みモデルを用いた推論
データセットの説明
データセットには、機械学習のサンプルデータとして有名なIris(アヤメ)データセットを活用します。3種類のアヤメ(Iris Setosa, Iris Versicolor, Iris Virginica)があり、それぞれ50サンプルずつ(合計150サンプル)用意されているデータです。このアヤメの名前を目的変数として利用します。また、説明変数にはアヤメの計測値である萼片(sepals)と花びら(petals)の長さと幅の4つを利用します。
データの読込
まず、前述したアヤメのデータセットを準備します。上記目的変数と説明変数をPandas形式で取り扱うために、下記のコードを実行してみましょう。
import numpy as np
import pandas as pd
from sklearn import datasets
# データロード
iris = datasets.load_iris()
# 説明変数
X = iris.data
X = pd.DataFrame(X, columns=["Sepal Length", "Sepal Width", "Petal Length", "Petal Width"])
# 目的変数
Y = iris.target
Y = iris_target = pd.DataFrame(Y, columns = ["Species"])
モデル学習
前述で読み込んだIrisデータを活用し、簡単な決定木モデルをscikit-learn
で作成します。
from sklearn.tree import DecisionTreeClassifier
# 決定木インスタンス
model = DecisionTreeClassifier()
# モデル学習
model.fit(X,Y)
学習済みモデルの保存
前述で作成した決定木モデルをローカルに保存します。
pickle.dump()
メソッドを活用し、次のように実行しましょう。
import pickle
# モデルを保存
filename = 'ml_model.sav'
pickle.dump(model, open(filename, 'wb'))
学習済みモデルのロード
前述のコードを実行すると、ml_model.sav
という名称で学習済みモデルのファイルがカレントディレクトリに保存されているでしょう。次はそのファイルを再びPython環境で読み込みます。
pickle.load()
メソッドを活用し、次のように実行します。この時ml_model.savはバイナリモードで保存されています。バイナリモードで保存されたpickleファイルをロードする際は、バイナリモード'rb'
を指定し、復元する必要があることを考慮しましょう。
# 保存した学習済みモデルを読込
filename = 'ml_model.sav'
model_load = pickle.load(open(filename, 'rb'))
学習済みモデルを用いた推論
最後に、読み込んだ学習済みモデルを用いて推論を行います。
# サンプルデータ
X_test = pd.DataFrame({"Sepal Length": [5.0],
"Sepal Width" : [3.0],
"Petal Length": [1.3],
"Petal Width" : [0.1]})
# 予測結果
predict = model_load.predict(X_test)
print("予測結果: {}".format(predict))
# 出力イメージ
# 予測結果: [0]
【参考】AI・機械学習における配信情報まとめ
当サイトではAI・機械学習における「基礎」から「最新のプログラミング手法」に至るまで幅広く解説しております。また「おすすめの勉強方法」をはじめ、副業・転職・フリーランスとして始める「AI・機械学習案件の探し方」についても詳しく言及しています。
【仕事探し】副業・転職・フリーランス
【教育】おすすめ勉強法
【参考】記事一覧
最後に
お問い合わせフォーム
上記課題に向けてご気軽にご相談下さい。
お問い合わせはこちら