10分钟快速上手使用¶
本文档在一个小数据集上展示了如何通过PaddleX进行训练。本示例同步在AIStudio上,可直接在线体验模型训练。
本示例代码源于Github tutorials/train/classification/mobilenetv3_small_ssld.py,用户可自行下载至本地运行。
PaddleX中的所有模型训练跟随以下3个步骤,即可快速完成训练代码开发!
步骤 | 说明 | |
---|---|---|
第1步 | 定义transforms | 用于定义模型训练、验证、预测过程中, 输入图像的预处理和数据增强操作 |
第2步 | 定义datasets | 用于定义模型要加载的训练、验证数据集 |
第3步 | 定义模型开始训练 | 选择需要的模型,进行训练 |
注意:不同模型的transforms、datasets和训练参数都有较大差异,更多模型训练,可直接根据文档教程获取更多模型的训练代码。模型训练教程
PaddleX的其它用法
安装相关过程和问题可以参考PaddleX的安装文档。
pip install paddlex -i https://mirror.baidu.com/pypi/simple
wget https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz
tar xzvf vegetables_cls.tar.gz
因为训练时加入了数据增强操作,因此在训练和验证过程中,模型的数据处理流程需要分别进行定义。如下所示,代码在train_transforms
中加入了RandomCrop和RandomHorizontalFlip两种数据增强方式, 更多方法可以参考数据增强文档。
from paddlex.cls import transforms
train_transforms = transforms.Compose([
transforms.RandomCrop(crop_size=224),
transforms.RandomHorizontalFlip(),
transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByShort(short_size=256),
transforms.CenterCrop(crop_size=224),
transforms.Normalize()
])
定义数据集,pdx.datasets.ImageNet
表示读取ImageNet格式的分类数据集
train_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/train_list.txt',
label_list='vegetables_cls/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.ImageNet(
data_dir='vegetables_cls',
file_list='vegetables_cls/val_list.txt',
label_list='vegetables_cls/labels.txt',
transforms=eval_transforms)
5. 使用MobileNetV3_small_ssld模型开始训练
本文档中使用百度基于蒸馏方法得到的MobileNetV3预训练模型,模型结构与MobileNetV3一致,但精度更高。PaddleX内置了20多种分类模型,查阅PaddleX模型库了解更多分类模型。
num_classes = len(train_dataset.labels)
model = pdx.cls.MobileNetV3_small_ssld(num_classes=num_classes)
model.train(num_epochs=20,
train_dataset=train_dataset,
train_batch_size=32,
eval_dataset=eval_dataset,
lr_decay_epochs=[4, 6, 8],
save_dir='output/mobilenetv3_small_ssld',
use_vdl=True)
训练过程中,模型在训练集和验证集上的指标均会以标准输出流形式输出到命令终端。当用户设定use_vdl=True
时,也会使用VisualDL格式将指标打点到save_dir
目录下的vdl_log
文件夹,在终端运行如下命令启动visualdl并查看可视化的指标变化情况。
visualdl --logdir output/mobilenetv3_small_ssld --port 8001
服务启动后,通过浏览器打开https://0.0.0.0:8001或https://localhost:8001即可。
如果您使用的是AIStudio平台进行训练,不能通过此方式启动visualdl,请参考AIStudio VisualDL启动教程使用
模型在训练过程中,会每间隔一定轮数保存一次模型,在验证集上评估效果最好的一轮会保存在save_dir
目录下的best_model
文件夹。通过如下方式可加载模型,进行预测。
import paddlex as pdx
model = pdx.load_model('output/mobilenetv3_small_ssld/best_model')
result = model.predict('vegetables_cls/bocai/100.jpg')
print("Predict Result: ", result)
预测结果输出如下,
Predict Result: Predict Result: [{'score': 0.9999393, 'category': 'bocai', 'category_id': 0}]