本文共 1290 字,大约阅读时间需要 4 分钟。
下载工程
git clone https://github.com/SpikeKing/DL-Project-Template
创建和激活虚拟环境
source venv/bin/activate
安装Python依赖库
pip install -r requirements.txt
开发流程
● 定义自己的数据加载类,继承DataLoaderBase; ● 定义自己的网络结构类,继承ModelBase; ● 定义自己的模型训练类,继承TrainerBase; ● 定义自己的样本预测类,继承InferBase; ● 定义自己的配置文件,写入实验的相关参数;执行训练模型和预测样本操作。
示例工程
识别MNIST库中手写数字,工程simple_mnist
训练:
python main_train.py -c configs/simple_mnist_config.json
预测:
nist.weights.10-0.24.hdf5
TensorBoard
操作步骤:
● 创建自己的加载数据类,继承DataLoaderBase基类; ● 覆写get_train_data()
和 get_test_data()
,返回训练和测试数据; 操作步骤:
● 创建自己的网络结构类,继承ModelBase基类; ● 覆写build_model()
,创建网络结构; ● 在构造器中,调用 build_model()
; 注意:plot_model()
支持绘制网络结构;
Trainer
操作步骤:
● 创建自己的训练类,继承TrainerBase基类; ● 参数:网络结构model、训练数据data; ● 覆写train()
,fit数据,训练网络结构; 注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。
操作步骤:
● 创建自己的预测类,继承InferBase基类; ● 覆写load_model()
,提供模型加载功能; ● 覆写predict()
,提供样本预测功能;
Config
定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。
Main
训练:
● 创建配置文件config; ● 创建数据加载类dataloader; ● 创建网络结构类model; ● 创建训练类trainer,参数是训练和测试数据、模型; ● 执行训练类trainer的train();预测:
● 创建配置文件config; ● 处理预测样本test; ● 创建预测类infer;● 执行预测类infer的predict();
原文发布时间为:2018-10-24
本文来自云栖社区合作伙伴“大数据挖掘DT机器学习”,了解相关信息可以关注“”。
转载地址:http://jball.baihongyu.com/