Data Is My Staple Food

A data engineer's notebook: analyses, notes on books I've read, and more.

Getting Started with MLflow

In June 2019, Databricks released MLflow v1.0.

The background of MLflow and the problems it solves

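In a machine learning project, people in many different roles (data engineers, data scientists, application engineers, and so on) carry out tasks. For example, several data scientists may compete to improve a model's accuracy by refining the algorithm, application engineers embed that model into an application, and data engineers use it as a reference when designing the analytics platform. Inevitably, when many people are involved, various problems arise.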


Figure source: Accelerating Production Machine Learning with MLflow - Databricks

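  • It is hard for several people to build the same environment (for example, the same versions of dependent libraries).
  • Sharing each model's scores and parameters among several people is tedious.
  • The environment a data scientist experimented in cannot be released as-is.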

What MLflow does for you

MLflow is being developed by Databricks, the company known for developing Spark.

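  • MLflow manages ML projects: experiment tracking, packaging of code and environments, and models.
  • MLflow does not depend on any particular ML library (such as scikit-learn).

These are the key points of its features. MLflow consists of three modules.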

MLflow Tracking

  • An API for logging, plus a UI for viewing what was logged
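import mlflow

# Log the metric "quality" once per epoch; step records which epoch each value belongs to
with mlflow.start_run():
    for epoch in range(0, 3):
        mlflow.log_metric(key="quality", value=2*epoch, step=epoch)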

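Written like this, the results are recorded on the local machine (under ./mlruns by default), and you can start a tracking UI locally with mlflow ui and browse the results in a web browser.

  • You can also use a tracking server hosted by Databricks.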
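To make concrete what the tracking snippet above records: each log_metric call adds a (step, value) point to the history of a metric key, so the loop leaves "quality" with the values 0, 2, 4 at steps 0, 1, 2. The ToyRun class below is a hypothetical, in-memory stand-in that only mimics this data model; it is not MLflow's implementation and it does not talk to any tracking server.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


class ToyRun:
    """Hypothetical in-memory stand-in for a tracking run (not MLflow code)."""

    def __init__(self) -> None:
        # metric key -> list of (step, value) points, in logging order
        self.metrics: Dict[str, List[Tuple[int, float]]] = defaultdict(list)

    def log_metric(self, key: str, value: float, step: int = 0) -> None:
        # Append rather than overwrite, so the whole history is kept
        self.metrics[key].append((step, float(value)))

    def latest(self, key: str) -> float:
        # Value of the most recently logged point for this key
        return self.metrics[key][-1][1]


run = ToyRun()
for epoch in range(0, 3):
    run.log_metric(key="quality", value=2 * epoch, step=epoch)

print(run.metrics["quality"])  # [(0, 0.0), (1, 2.0), (2, 4.0)]
print(run.latest("quality"))   # 4.0
```

With real MLflow, the full history of a logged metric can be read back with MlflowClient().get_metric_history(run_id, key).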

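MLflow Projects

  • By describing a project in a standard packaging format (an MLproject file plus its environment definition), you can automate environment setup and script execution.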

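MLflow Models

  • Defines a standard format for packaging trained models, including their inputs and outputs; models that follow it can have deployment and similar tasks automated.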
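The benefit of a standard model interface is that serving code only needs to know one call. The sketch below is a hypothetical, stdlib-only illustration of that idea: two differently built models both expose predict(model_input) over a list of rows, so a single serve_batch function works for either. MLflow's actual generic interface is the python_function (pyfunc) flavor, whose predict takes a pandas DataFrame; the class and function names here are assumptions made for illustration only.

```python
from typing import Callable, Dict, List, Protocol

Row = Dict[str, float]


class PredictsRows(Protocol):
    """Hypothetical standard interface: rows in, one prediction per row out."""

    def predict(self, model_input: List[Row]) -> List[float]: ...


class ThresholdModel:
    """A 'model' with its own internals (hypothetical)."""

    def __init__(self, feature: str, threshold: float) -> None:
        self.feature = feature
        self.threshold = threshold

    def predict(self, model_input: List[Row]) -> List[float]:
        return [1.0 if row[self.feature] >= self.threshold else 0.0 for row in model_input]


class FunctionModel:
    """Wraps an arbitrary per-row function behind the same interface."""

    def __init__(self, fn: Callable[[Row], float]) -> None:
        self.fn = fn

    def predict(self, model_input: List[Row]) -> List[float]:
        return [self.fn(row) for row in model_input]


def serve_batch(model: PredictsRows, rows: List[Row]) -> List[float]:
    # Serving code depends only on the interface, not on how the model was built
    return model.predict(rows)


rows = [{"x": -1.0}, {"x": 0.5}, {"x": 2.0}]
print(serve_batch(ThresholdModel("x", 0.0), rows))             # [0.0, 1.0, 1.0]
print(serve_batch(FunctionModel(lambda r: r["x"] * 2), rows))  # [-2.0, 1.0, 4.0]
```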

Trying it out

Following the instructions on mlflow.org, I'll start by just getting it to run. On macOS there seem to be some environment constraints, so I ran it with Docker (on a Mac).

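# On the host (Mac): pull the image and start an interactive container
docker pull conda/c3i-linux-64
docker run -it conda/c3i-linux-64
# Inside the container: install MLflow and fetch the example project
cd
pip install mlflow
git clone https://github.com/mlflow/mlflow
cd mlflow/examples/sklearn_logistic_regression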

That completes the environment setup. Now run the project:

mlflow run .

Output

2019/06/16 13:41:05 INFO mlflow.projects: === Creating conda environment mlflow-7fe9f24d100ade6523081ed2190877c5976a985c ===
Warning: you have pip-installed dependencies in your environment file, but you do not list pip itself as one of your conda dependencies.  Conda may not use the correct pip to install your packages, and they may end up in the wrong place.  Please add an explicit pip dependency.  I'm adding one for you, but still nagging you.
Collecting package metadata: done
Solving environment: done

Downloading and Extracting Packages
mkl_random-1.0.1     | 373 KB    | ############################################################### | 100%
python-3.6.0         | 16.3 MB   | ############################################################### | 100%
wheel-0.33.4         | 40 KB     | ############################################################### | 100%
setuptools-41.0.1    | 656 KB    | ############################################################### | 100%
libgfortran-ng-7.3.0 | 1.3 MB    | ############################################################### | 100%
libstdcxx-ng-9.1.0   | 4.0 MB    | ############################################################### | 100%
readline-6.2         | 606 KB    | ############################################################### | 100%
openssl-1.0.2s       | 3.1 MB    | ############################################################### | 100%
numpy-base-1.15.4    | 4.2 MB    | ############################################################### | 100%
certifi-2019.3.9     | 155 KB    | ############################################################### | 100%
mkl_fft-1.0.6        | 150 KB    | ############################################################### | 100%
mkl-2018.0.3         | 198.7 MB  | ############################################################### | 100%
intel-openmp-2019.4  | 876 KB    | ############################################################### | 100%
scipy-1.1.0          | 18.0 MB   | ############################################################### | 100%
libgcc-ng-9.1.0      | 8.1 MB    | ############################################################### | 100%
sqlite-3.13.0        | 4.0 MB    | ############################################################### | 100%
numpy-1.15.4         | 35 KB     | ############################################################### | 100%
scikit-learn-0.19.1  | 5.2 MB    | ############################################################### | 100%
pip-19.1.1           | 1.9 MB    | ############################################################### | 100%
blas-1.0             | 6 KB      | ############################################################### | 100%
tk-8.5.18            | 1.9 MB    | ############################################################### | 100%
Preparing transaction: done
Verifying transaction: done
Executing transaction: done
Ran pip subprocess with arguments:
['/opt/conda/envs/mlflow-7fe9f24d100ade6523081ed2190877c5976a985c/bin/python', '-m', 'pip', 'install', '-U', '-r', '/root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt']
Pip subprocess output:
Collecting mlflow>=1.0 (from -r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/01/ec/8c9448968d4662e8354b9c3a62e635f8929ed507a45af3d9fdb84be51270/mlflow-1.0.0-py3-none-any.whl
Collecting requests>=2.17.3 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/51/bd/23c926cd341ea6b7dd0b2a00aba99ae0f828be89d72b2190f27c11d4b7fb/requests-2.22.0-py2.py3-none-any.whl (57kB)
Collecting click>=7.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/fa/37/45185cb5abbc30d7257104c434fe0b07e5a195a6847506c074527aa599ec/Click-7.0-py2.py3-none-any.whl
Collecting databricks-cli>=0.8.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/5f/38/f83bc71c5e7351a03e8d44aaf04647d076bbf8f097e3f93b921704b7a74c/databricks_cli-0.8.7-py3-none-any.whl
Collecting six>=1.10.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/73/fb/00a976f728d0d1fecfe898238ce23f502a721c0ac0ecfedb80e0d88c64e9/six-1.12.0-py2.py3-none-any.whl
Collecting gunicorn (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/8c/da/b8dd8deb741bff556db53902d4706774c8e1e67265f69528c14c003644e6/gunicorn-19.9.0-py2.py3-none-any.whl
Collecting python-dateutil (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/41/17/c62faccbfbd163c7f57f3844689e3a78bae1f403648a6afb1d0866d87fbb/python_dateutil-2.8.0-py2.py3-none-any.whl (226kB)
Collecting simplejson (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/e3/24/c35fb1c1c315fc0fffe61ea00d3f88e85469004713dab488dee4f35b0aff/simplejson-3.16.0.tar.gz
Collecting sqlparse (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/ef/53/900f7d2a54557c6a37886585a91336520e5539e3ae2423ff1102daf4f3a7/sqlparse-0.3.0-py2.py3-none-any.whl
Collecting alembic (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
Requirement already satisfied, skipping upgrade: numpy in /opt/conda/envs/mlflow-7fe9f24d100ade6523081ed2190877c5976a985c/lib/python3.6/site-packages (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1)) (1.15.4)
Collecting Flask (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/9a/74/670ae9737d14114753b8c8fdf2e8bd212a05d3b361ab15b44937dfd40985/Flask-1.0.3-py2.py3-none-any.whl
Collecting protobuf>=3.6.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/d2/fb/29de8d08967f0cce1bb10b39846d836b0f3bf6776ddc36aed7c73498ca7e/protobuf-3.8.0-cp36-cp36m-manylinux1_x86_64.whl (1.2MB)
Collecting gitpython>=2.1.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/fe/e5/fafe827507644c32d6dc553a1c435cdf882e0c28918a5bab29f7fbebfb70/GitPython-2.1.11-py2.py3-none-any.whl
Collecting cloudpickle (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/09/f4/4a080c349c1680a2086196fcf0286a65931708156f39568ed7051e42ff6a/cloudpickle-1.2.1-py2.py3-none-any.whl
Collecting pyyaml (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/a3/65/837fefac7475963d1eccf4aa684c23b95aa6c1d033a2c5965ccb11e22623/PyYAML-5.1.1.tar.gz (274kB)
Collecting docker>=3.6.0 (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/91/93/310fe092039f6b0759a1f8524e9e2c56f8012804fa2a8da4e4289bb74d7c/docker-4.0.1-py2.py3-none-any.whl
Collecting pandas (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/19/74/e50234bc82c553fecdbd566d8650801e3fe2d6d8c8d940638e3d8a7c5522/pandas-0.24.2-cp36-cp36m-manylinux1_x86_64.whl (10.1MB)
Collecting entrypoints (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/ac/c6/44694103f8c221443ee6b0041f69e2740d89a25641e62fb4f2ee568f2f9c/entrypoints-0.3-py2.py3-none-any.whl
Collecting sqlalchemy (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/ba/37/094ecf4b218f20572986dc90fe8c6aed32e2a711bcd02ce8ef251fde2011/SQLAlchemy-1.3.4.tar.gz
Collecting querystring-parser (from mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/57/64/3086a9a991ff3aca7b769f5b0b51ff8445a06337ae2c58f215bcee48f527/querystring_parser-1.2.3.tar.gz
Collecting idna<2.9,>=2.5 (from requests>=2.17.3->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/14/2c/cd551d81dbe15200be1cf41cd03869a46fe7226e7450af7a6545bfc474c9/idna-2.8-py2.py3-none-any.whl (58kB)
Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /opt/conda/envs/mlflow-7fe9f24d100ade6523081ed2190877c5976a985c/lib/python3.6/site-packages (from requests>=2.17.3->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1)) (2019.3.9)
Collecting urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 (from requests>=2.17.3->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/e6/60/247f23a7121ae632d62811ba7f273d0e58972d75e58a94d329d51550a47d/urllib3-1.25.3-py2.py3-none-any.whl (150kB)
Collecting chardet<3.1.0,>=3.0.2 (from requests>=2.17.3->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/bc/a9/01ffebfb562e4274b6487b4bb1ddec7ca55ec7510b22e4c51f14098443b8/chardet-3.0.4-py2.py3-none-any.whl (133kB)
Collecting configparser>=0.3.5 (from databricks-cli>=0.8.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/ba/05/6c96328e92e625fc31445d24d75a2c92ef9ba34fc5b037fe69693c362a0d/configparser-3.7.4-py2.py3-none-any.whl
Collecting tabulate>=0.7.7 (from databricks-cli>=0.8.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/c2/fd/202954b3f0eb896c53b7b6f07390851b1fd2ca84aa95880d7ae4f434c4ac/tabulate-0.8.3.tar.gz
Collecting python-editor>=0.3 (from alembic->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/c6/d3/201fc3abe391bbae6606e6f1d598c15d367033332bd54352b12f35513717/python_editor-1.0.4-py3-none-any.whl
Collecting Mako (from alembic->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/0a/af/a6d8aa7b8909a36074f517b15222e3a2fbd5ef3452c0a686e3d43043dd3b/Mako-1.0.12.tar.gz
Collecting Jinja2>=2.10 (from Flask->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/1d/e7/fd8b501e7a6dfe492a433deb7b9d833d39ca74916fa8bc63dd1a4947a671/Jinja2-2.10.1-py2.py3-none-any.whl (124kB)
Collecting itsdangerous>=0.24 (from Flask->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/76/ae/44b03b253d6fade317f32c24d100b3b35c2239807046a4c953c7b89fa49e/itsdangerous-1.1.0-py2.py3-none-any.whl
Collecting Werkzeug>=0.14 (from Flask->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/9f/57/92a497e38161ce40606c27a86759c6b92dd34fcdb33f64171ec559257c02/Werkzeug-0.15.4-py2.py3-none-any.whl
Requirement already satisfied, skipping upgrade: setuptools in /opt/conda/envs/mlflow-7fe9f24d100ade6523081ed2190877c5976a985c/lib/python3.6/site-packages (from protobuf>=3.6.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1)) (41.0.1)
Collecting gitdb2>=2.0.0 (from gitpython>=2.1.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/da/30/a407568aa8d8f25db817cf50121a958722f3fc5f87e3a6fba1f40c0633e3/gitdb2-2.0.5-py2.py3-none-any.whl
Collecting websocket-client>=0.32.0 (from docker>=3.6.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/29/19/44753eab1fdb50770ac69605527e8859468f3c0fd7dc5a76dd9c4dbd7906/websocket_client-0.56.0-py2.py3-none-any.whl
Collecting pytz>=2011k (from pandas->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/3d/73/fe30c2daaaa0713420d0382b16fbb761409f532c56bdcc514bf7b6262bb6/pytz-2019.1-py2.py3-none-any.whl (510kB)
Collecting MarkupSafe>=0.9.2 (from Mako->alembic->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Downloading https://files.pythonhosted.org/packages/b2/5f/23e0023be6bb885d00ffbefad2942bc51a620328ee910f64abe5a8d18dd1/MarkupSafe-1.1.1-cp36-cp36m-manylinux1_x86_64.whl
Collecting smmap2>=2.0.0 (from gitdb2>=2.0.0->gitpython>=2.1.0->mlflow>=1.0->-r /root/mlflow/examples/sklearn_logistic_regression/condaenv._wh4m8j4.requirements.txt (line 1))
  Using cached https://files.pythonhosted.org/packages/55/d2/866d45e3a121ee15a1dc013824d58072fd5c7799c9c34d01378eb262ca8f/smmap2-2.0.5-py2.py3-none-any.whl
Building wheels for collected packages: simplejson, pyyaml, sqlalchemy, querystring-parser, tabulate, Mako
  Building wheel for simplejson (setup.py): started
  Building wheel for simplejson (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/5d/1a/1e/0350bb3df3e74215cd91325344cc86c2c691f5306eb4d22c77
  Building wheel for pyyaml (setup.py): started
  Building wheel for pyyaml (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/16/27/a1/775c62ddea7bfa62324fd1f65847ed31c55dadb6051481ba3f
  Building wheel for sqlalchemy (setup.py): started
  Building wheel for sqlalchemy (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/cc/b2/b8/54b71f2c27738fc6f9d1b68b6cf653c28a5fa0a9846d02be32
  Building wheel for querystring-parser (setup.py): started
  Building wheel for querystring-parser (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/ee/09/99/bf937e4f02788fa8b33dc5240842ba3977ba5c3c4ad4a115d7
  Building wheel for tabulate (setup.py): started
  Building wheel for tabulate (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/2b/67/89/414471314a2d15de625d184d8be6d38a03ae1e983dbda91e84
  Building wheel for Mako (setup.py): started
  Building wheel for Mako (setup.py): finished with status 'done'
  Stored in directory: /root/.cache/pip/wheels/b3/7b/ae/5addd138cd8175503b9782737bada30d0c88310d08c106f9bf
Successfully built simplejson pyyaml sqlalchemy querystring-parser tabulate Mako
Installing collected packages: idna, urllib3, chardet, requests, click, configparser, six, tabulate, databricks-cli, gunicorn, python-dateutil, simplejson, sqlparse, sqlalchemy, python-editor, MarkupSafe, Mako, alembic, Jinja2, itsdangerous, Werkzeug, Flask, protobuf, smmap2, gitdb2, gitpython, cloudpickle, pyyaml, websocket-client, docker, pytz, pandas, entrypoints, querystring-parser, mlflow
Successfully installed Flask-1.0.3 Jinja2-2.10.1 Mako-1.0.12 MarkupSafe-1.1.1 Werkzeug-0.15.4 alembic-1.0.10 chardet-3.0.4 click-7.0 cloudpickle-1.2.1 configparser-3.7.4 databricks-cli-0.8.7 docker-4.0.1 entrypoints-0.3 gitdb2-2.0.5 gitpython-2.1.11 gunicorn-19.9.0 idna-2.8 itsdangerous-1.1.0 mlflow-1.0.0 pandas-0.24.2 protobuf-3.8.0 python-dateutil-2.8.0 python-editor-1.0.4 pytz-2019.1 pyyaml-5.1.1 querystring-parser-1.2.3 requests-2.22.0 simplejson-3.16.0 six-1.12.0 smmap2-2.0.5 sqlalchemy-1.3.4 sqlparse-0.3.0 tabulate-0.8.3 urllib3-1.25.3 websocket-client-0.56.0

#
# To activate this environment, use:
# > conda activate mlflow-7fe9f24d100ade6523081ed2190877c5976a985c
#
# To deactivate an active environment, use:
# > conda deactivate
#

2019/06/16 13:45:49 INFO mlflow.projects: === Created directory /tmp/tmpmwa09cy0 for downloading remote URIs passed to arguments of type 'path' ===
2019/06/16 13:45:49 INFO mlflow.projects: === Running command 'source activate mlflow-7fe9f24d100ade6523081ed2190877c5976a985c && python train.py' in run with ID '6aa1bbdc5f484d309d77d1967bb2e1e4' ===
Score: 0.6666666666666666
Model saved in run 6aa1bbdc5f484d309d77d1967bb2e1e4
2019/06/16 13:45:52 INFO mlflow.projects: === Run (ID '6aa1bbdc5f484d309d77d1967bb2e1e4') succeeded ===

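With that, the run completed successfully. Next, let's look at the files that make up the sklearn_logistic_regression directory. They show the basic features well, such as declaring dependencies and logging scores.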

MLproject

name: sklearn_logistic_example

conda_env: conda.yaml

entry_points:
  main:
    command: "python train.py"
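This example's entry point runs a fixed command, but MLproject entry points can also declare parameters whose values are substituted into {name} placeholders in command, and which can be overridden at run time with mlflow run . -P name=value. The sketch below mimics only that substitution step with plain str.format; the alpha parameter, its default, and the render_command helper are hypothetical and are not taken from the example or from MLflow's source.

```python
from typing import Dict, Optional

# Hypothetical entry point, written as the dict an MLproject YAML file would parse to
entry_point = {
    "parameters": {"alpha": {"type": "float", "default": 0.1}},
    "command": "python train.py --alpha {alpha}",
}


def render_command(entry: dict, overrides: Optional[Dict[str, str]] = None) -> str:
    """Fill {name} placeholders with overrides, falling back to declared defaults."""
    overrides = overrides or {}
    values = {}
    for name, spec in entry.get("parameters", {}).items():
        if name in overrides:
            values[name] = overrides[name]
        elif "default" in spec:
            values[name] = spec["default"]
        else:
            raise ValueError(f"missing value for parameter {name!r}")
    return entry["command"].format(**values)


print(render_command(entry_point))                    # python train.py --alpha 0.1
print(render_command(entry_point, {"alpha": "0.5"}))  # python train.py --alpha 0.5
```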

conda.yaml

name: sklearn-example
channels:
  - defaults
  - anaconda
dependencies:
  - python==3.6
  - scikit-learn=0.19.1
  - pip:
    - mlflow>=1.0
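The warning near the top of the run output ("you have pip-installed dependencies in your environment file, but you do not list pip itself") is conda complaining about this file: dependencies contains a pip: sub-list but no pip entry. A minimal check for that condition could look like the sketch below, assuming the YAML has already been loaded into a Python dict (here it is transcribed by hand); the helper names are hypothetical. Adding "- pip" to dependencies should silence the warning.

```python
import re

# The example's conda.yaml, transcribed by hand as the dict a YAML loader would return
conda_env = {
    "name": "sklearn-example",
    "channels": ["defaults", "anaconda"],
    "dependencies": [
        "python==3.6",
        "scikit-learn=0.19.1",
        {"pip": ["mlflow>=1.0"]},
    ],
}


def package_name(spec: str) -> str:
    # "python==3.6" -> "python", "scikit-learn=0.19.1" -> "scikit-learn"
    return re.split(r"[=<>!\s]", spec.strip(), maxsplit=1)[0]


def uses_pip_section(deps: list) -> bool:
    # A nested {"pip": [...]} mapping means some packages are installed with pip
    return any(isinstance(d, dict) and "pip" in d for d in deps)


def lists_pip_explicitly(deps: list) -> bool:
    # True if pip itself appears as a conda package, with or without a version spec
    return any(isinstance(d, str) and package_name(d) == "pip" for d in deps)


deps = conda_env["dependencies"]
if uses_pip_section(deps) and not lists_pip_explicitly(deps):
    print("pip section present but pip is not listed; add '- pip' to dependencies")
```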

train.py

from __future__ import print_function

import numpy as np
from sklearn.linear_model import LogisticRegression

import mlflow
import mlflow.sklearn

if __name__ == "__main__":
    X = np.array([-2, -1, 0, 1, 2, 1]).reshape(-1, 1)
    y = np.array([0, 0, 1, 1, 1, 0])
    lr = LogisticRegression()
    lr.fit(X, y)
    score = lr.score(X, y)
    print("Score: %s" % score)
    mlflow.log_metric("score", score)
    mlflow.sklearn.log_model(lr, "model")
    print("Model saved in run %s" % mlflow.active_run().info.run_uuid)
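For a scikit-learn classifier, lr.score(X, y) returns mean accuracy, the fraction of samples predicted correctly, so "Score: 0.6666666666666666" in the output means 4 of the 6 training points were classified correctly. Note that X = 1 appears twice with different labels, so no model of this single feature can get all six right. The stdlib sketch below computes accuracy by hand; the prediction vector is hypothetical (just one example with two mistakes), not the actual output of the fitted model.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of positions where the prediction matches the label."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


y = [0, 0, 1, 1, 1, 0]       # labels from train.py
y_pred = [0, 0, 0, 1, 1, 1]  # hypothetical predictions with two mistakes
print(accuracy(y, y_pred))   # 0.6666666666666666
```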

Impressions

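  • Rather than letting data scientists do research however they like, having them follow this format seems likely to make everyone happier.
  • Since MLflow does not manage the data itself, full reproducibility does not seem possible. It will probably build on the direction of Delta Lake, the Spark-based storage layer from Databricks.
  • The concept is close to Amazon SageMaker. SageMaker can parallelize training, but MLflow apparently cannot.
  • There is no mechanism like Pipfile.lock, so I suspect environments cannot be kept identical unless you pin dependency versions yourself.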
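One way to act on the last point without a lock file is to check that every dependency is pinned to an exact version. In the run above, mlflow>=1.0 happened to resolve to mlflow 1.0.0, but a later run could pick a newer release. The sketch below is a hypothetical helper that flags any spec that is not exactly name==version; this is a deliberate simplification, and it also flags conda's single = (a fuzzy match, so scikit-learn=0.19.1 can accept 0.19.1.x builds).

```python
import re
from typing import List

# A spec counts as pinned only if it is exactly "name==version" (no ranges, no wildcards)
EXACT_PIN = re.compile(r"^[A-Za-z0-9_.\-]+==[A-Za-z0-9_.+\-]+$")


def unpinned(specs: List[str]) -> List[str]:
    """Return the specs that do not pin an exact version (hypothetical helper)."""
    return [s for s in specs if not EXACT_PIN.match(s.strip())]


# Dependencies from the example's conda.yaml, conda and pip entries flattened together
specs = ["python==3.6", "scikit-learn=0.19.1", "mlflow>=1.0"]
print(unpinned(specs))  # ['scikit-learn=0.19.1', 'mlflow>=1.0']
```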

References