MMClassification Practice Notes

2022-12-27

1. Set up the environment

Reference documentation: https://mmclassification.readthedocs.io/zh_CN/dev-1.x/get_started.html

git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
## This step fails easily; if it does, see https://blog.csdn.net/good_good_xiu/article/details/118567249
cd mmclassification
conda install pytorch torchvision -c pytorch
pip install -U openmim && mim install -e .  -i https://pypi.tuna.tsinghua.edu.cn/simple

2. Verify the installation

Check whether the GPU is available:

python

import torch
torch.cuda.is_available()  # True if a usable CUDA GPU is found
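A slightly fuller check can be scripted. This is a minimal sketch (the helper name `env_report` is hypothetical): it reports the PyTorch version, whether CUDA is usable, the first GPU's name, and whether `mmcls` imports, using `importlib` so it does not crash when a package is missing.

```python
import importlib
import importlib.util


def env_report():
    """Summarize torch / CUDA / mmcls availability without crashing if one is missing."""
    report = {"torch": None, "cuda": False, "gpu_name": None, "mmcls": None}

    if importlib.util.find_spec("torch") is not None:
        torch = importlib.import_module("torch")
        report["torch"] = torch.__version__
        report["cuda"] = bool(torch.cuda.is_available())
        if report["cuda"]:
            # Name of the first visible GPU, to confirm the expected card is used
            report["gpu_name"] = torch.cuda.get_device_name(0)

    if importlib.util.find_spec("mmcls") is not None:
        try:
            mmcls = importlib.import_module("mmcls")
            report["mmcls"] = getattr(mmcls, "__version__", "unknown")
        except Exception as exc:  # e.g. an incompatible mmcv/mmengine install
            report["mmcls"] = f"import failed: {exc}"

    return report


if __name__ == "__main__":
    for key, value in env_report().items():
        print(f"{key:>8}: {value}")
```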

Step 1: download the config file and the model weights.

mim download mmcls --config resnet50_8xb32_in1k --dest .

Step 2: verify the inference pipeline on the demo image.

python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu
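The same inference can be run from Python. This is a minimal sketch under assumptions: it relies on `register_all_modules` from `mmcls.utils` and `init_model` / `inference_model` from `mmcls.apis` as provided by the 1.x releases, and `classify` is a hypothetical helper name. The structure of the returned result (typically a dict with the predicted label, score and class name) may differ between 1.x releases, so it is printed as-is.

```python
def classify(img_path, config_path, checkpoint_path, device="cpu"):
    """Single-image inference; mmcls is imported lazily so this file loads without it."""
    # Assumptions: mmcls 1.x provides register_all_modules in mmcls.utils and
    # init_model / inference_model in mmcls.apis
    from mmcls.apis import inference_model, init_model
    from mmcls.utils import register_all_modules

    register_all_modules()  # register mmcls components and set the default scope
    model = init_model(config_path, checkpoint_path, device=device)
    return inference_model(model, img_path)


if __name__ == "__main__":
    # Same files as the demo command above
    result = classify(
        "demo/demo.JPEG",
        "resnet50_8xb32_in1k.py",
        "resnet50_8xb32_in1k_20210831-ea4938fc.pth",
        device="cpu",  # use "cuda:0" if the GPU check returned True
    )
    print(result)
```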

3. Train on a custom dataset

Data layout: the dataset must be organized exactly like this, one subfolder per class.

.
└── xx_data
    ├── train
    │   ├── class1 (xx images)
    │   ├── class2 (xx images)
    │   └── ...
    └── val
        ├── class1 (xx images)
        ├── class2 (xx images)
        └── ...
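Before training, it can save time to check the layout programmatically. A minimal sketch (the helper `check_layout` is hypothetical, and the extension list is my assumption about what `CustomDataset` picks up): it confirms that `train` and `val` exist, that both contain the same class folders, and that no class folder is empty, then returns per-class image counts.

```python
from pathlib import Path

# Assumed image extensions; check the list used by your mmcls version
IMG_EXTENSIONS = {".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"}


def check_layout(root):
    """Validate the train/val class-folder layout and return per-split image counts.

    Raises ValueError if a split is missing, train and val disagree on class
    names, or a class folder contains no images.
    """
    root = Path(root)
    counts = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise ValueError(f"missing split directory: {split_dir}")
        # Count image files anywhere under each class folder
        counts[split] = {
            cls_dir.name: sum(
                1 for f in cls_dir.rglob("*")
                if f.is_file() and f.suffix.lower() in IMG_EXTENSIONS
            )
            for cls_dir in sorted(split_dir.iterdir())
            if cls_dir.is_dir()
        }
    if set(counts["train"]) != set(counts["val"]):
        raise ValueError(
            f"class folders differ: train={sorted(counts['train'])}, "
            f"val={sorted(counts['val'])}"
        )
    empty = [f"{s}/{c}" for s, per_cls in counts.items() for c, n in per_cls.items() if n == 0]
    if empty:
        raise ValueError(f"class folders with no images: {empty}")
    return counts


if __name__ == "__main__":
    import sys

    for split, per_cls in check_layout(sys.argv[1]).items():
        print(split, per_cls)
```

Run it as `python check_layout.py /path/to/xx_data`; a heavily unbalanced count per class is also worth noticing before training.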

Config file:

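Note: change num_classes=7 and data_prefix='/data/xxx/datasets/age_classify/age_data/train' (and the val/test data_prefix) to match your own dataset. num_classes must equal the number of class folders.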
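To get num_classes right and to interpret predicted label indices later, the class list can be read from the train folder. A minimal sketch, assuming (as I understand `CustomDataset` to do) that label indices follow the sorted order of the class folder names; `class_mapping` is a hypothetical helper, and the ordering assumption is worth verifying against your mmcls version.

```python
from pathlib import Path


def class_mapping(train_dir):
    """Return (num_classes, {index: class_name}) for a folder-per-class train split."""
    # Assumption: label indices follow the sorted order of the class folder names
    names = sorted(p.name for p in Path(train_dir).iterdir() if p.is_dir())
    return len(names), dict(enumerate(names))


if __name__ == "__main__":
    num_classes, idx_to_class = class_mapping(
        "/data/xxx/datasets/age_classify/age_data/train"
    )
    print("num_classes =", num_classes)  # must match model.head.num_classes
    print(idx_to_class)
```

Python's `sorted` is lexicographic, so `class10` sorts before `class2`; zero-padded folder names (`class01`, `class02`, ...) keep the index order intuitive.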

model = dict(
	type='ImageClassifier',
	backbone=dict(type='MobileNetV2', widen_factor=1.0),
	neck=dict(type='GlobalAveragePooling'),
	head=dict(
		type='LinearClsHead',
		num_classes=7,
		in_channels=1280,
		loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
		topk=(1, )))

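# ImageNet-pretrained MobileNetV2 checkpoint; download it first. Its 1000-class head does not
# match num_classes=7, so a shape-mismatch warning for the head weights is expected.
load_from = 'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'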

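# Note: with data_preprocessor commented out, inputs are likely not normalized with the ImageNet
# mean/std that the pretrained weights expect; consider uncommenting it with num_classes=7.
# dataset_type = 'ImageNet'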
# data_preprocessor = dict(
#     num_classes=1000,
#     mean=[123.675, 116.28, 103.53],
#     std=[58.395, 57.12, 57.375],
#     to_rgb=True)
# train_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='RandomResizedCrop', scale=224, backend='pillow'),
#     dict(type='RandomFlip', prob=0.5, direction='horizontal'),
#     dict(type='PackClsInputs')
# ]
# test_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
#     dict(type='CenterCrop', crop_size=224),
#     dict(type='PackClsInputs')
# ]

train_dataloader = dict(
	batch_size=64,
	num_workers=5,
	dataset=dict(
		type='CustomDataset',
		# data_root='data/imagenet',
		# ann_file='meta/train.txt',
		data_prefix='/data/xxx/datasets/age_classify/age_data/train',
		pipeline=[
			dict(type='LoadImageFromFile'),
			dict(type='RandomResizedCrop', scale=168, backend='pillow'),
			dict(type='RandomFlip', prob=0.5, direction='horizontal'),
			dict(type='PackClsInputs')
		]),
	sampler=dict(type='DefaultSampler', shuffle=True))
val_dataloader = dict(
	batch_size=64,
	num_workers=5,
	dataset=dict(
		type='CustomDataset',
		# data_root='data/imagenet',
		# ann_file='meta/val.txt',
		data_prefix='/data/xxx/datasets/age_classify/age_data/val',
		pipeline=[
			dict(type='LoadImageFromFile'),
			dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
			dict(type='CenterCrop', crop_size=168),
			dict(type='PackClsInputs')
		]),
	sampler=dict(type='DefaultSampler', shuffle=False))
val_evaluator = dict(type='Accuracy', topk=(1, ))
test_dataloader = dict(
	batch_size=32,
	num_workers=5,
	dataset=dict(
		type='CustomDataset',
		# data_root='data/imagenet',
		# ann_file='meta/val.txt',
		data_prefix='/data/xxx/datasets/age_classify/age_data/val',
		pipeline=[
			dict(type='LoadImageFromFile'),
			dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
			dict(type='CenterCrop', crop_size=168),
			dict(type='PackClsInputs')
		]),
	sampler=dict(type='DefaultSampler', shuffle=False))
test_evaluator = dict(type='Accuracy', topk=(1,))
optim_wrapper = dict(
	optimizer=dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=4e-05))
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=5)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=256)
default_scope = 'mmcls'
default_hooks = dict(
	timer=dict(type='IterTimerHook'),
	logger=dict(type='LoggerHook', interval=10),
	param_scheduler=dict(type='ParamSchedulerHook'),
	checkpoint=dict(type='CheckpointHook', interval=5),
	sampler_seed=dict(type='DistSamplerSeedHook'),
	visualization=dict(type='VisualizationHook', enable=False))
env_cfg = dict(
	cudnn_benchmark=False,
	mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
	dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
	type='ClsVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'

resume = False
randomness = dict(seed=None, deterministic=False)
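A few numbers implied by this config can be sanity-checked with plain arithmetic. This is a sketch in plain Python, not mmengine code, and the function names are hypothetical. It assumes that `StepLR(step_size=1, gamma=0.98)` multiplies the LR by 0.98 after every epoch; that `auto_scale_lr` is only applied when `tools/train.py` is run with `--auto-scale-lr`, scaling the LR linearly by the real total batch size over `base_batch_size=256`; that validation (`val_interval=5`) and checkpointing (`CheckpointHook` `interval=5`) run after every 5th epoch; and, for the resize/crop geometry, it ignores the exact rounding used by `ResizeEdge`.

```python
def step_lr(epoch, base_lr=0.005, gamma=0.98, step_size=1):
    """LR used during a 0-indexed epoch under an epoch-based StepLR."""
    return base_lr * gamma ** (epoch // step_size)


def linear_scaled_lr(base_lr, batch_per_gpu, num_gpus, base_batch_size=256):
    """Linear scaling rule: lr * (actual total batch size / base batch size)."""
    return base_lr * (batch_per_gpu * num_gpus) / base_batch_size


def eval_epochs(max_epochs=100, interval=5):
    """1-indexed epochs after which validation and checkpoint saving happen."""
    return [e for e in range(1, max_epochs + 1) if e % interval == 0]


def resize_then_center_crop(w, h, short_edge=224, crop=168):
    """Size after ResizeEdge(edge='short') and the CenterCrop box (left, top, right, bottom)."""
    short = min(w, h)
    new_w, new_h = round(w * short_edge / short), round(h * short_edge / short)
    left, top = (new_w - crop) // 2, (new_h - crop) // 2
    return (new_w, new_h), (left, top, left + crop, top + crop)


if __name__ == "__main__":
    print(f"epoch 0 lr = {step_lr(0):.6f}, last epoch (99) lr = {step_lr(99):.6f}")
    print("scaled lr, 1 GPU x 64 images:", linear_scaled_lr(0.005, 64, 1))
    print("validation/checkpoint epochs:", eval_epochs())
    print("600x300 image:", resize_then_center_crop(600, 300))
```

With these settings the LR decays from 0.005 to about 0.00068 in the final epoch (0.98^99 is roughly 0.135), validation and checkpointing each run 20 times, and at eval time the crop keeps 75% of the resized short edge (168/224) versus 87.5% (224/256) in the commented-out ImageNet pipeline. Twenty checkpoints can take noticeable disk space; `CheckpointHook` accepts `max_keep_ckpts` to limit how many are kept.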

Training command (run from the mmclassification repository root):

python tools/train.py <path to the config file above>
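When training is launched from a script, building the argument list explicitly avoids quoting mistakes. A minimal sketch assuming the 1.x `tools/train.py` options `--work-dir`, `--auto-scale-lr` and `--cfg-options` (confirm with `python tools/train.py -h` on your checkout); `build_train_cmd` and the config/work-dir paths are hypothetical.

```python
import subprocess
import sys


def build_train_cmd(config, work_dir=None, auto_scale_lr=False, cfg_options=None):
    """Assemble the tools/train.py command line as an argument list."""
    cmd = [sys.executable, "tools/train.py", config]  # the config path is positional
    if work_dir:
        cmd += ["--work-dir", work_dir]
    if auto_scale_lr:
        cmd.append("--auto-scale-lr")  # turns on the auto_scale_lr section of the config
    if cfg_options:
        # Override config keys without editing the file, e.g. {"train_cfg.max_epochs": 50}
        cmd.append("--cfg-options")
        cmd += [f"{key}={value}" for key, value in cfg_options.items()]
    return cmd


if __name__ == "__main__":
    cmd = build_train_cmd(
        "configs/age_mobilenet_v2.py",  # hypothetical path to the config above
        work_dir="work_dirs/age_mobilenet_v2",
        cfg_options={"randomness.seed": 42},  # fixed seed for repeatable runs
    )
    print(" ".join(cmd))
    subprocess.run(cmd, check=True)  # run from the mmclassification repo root
```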

Reference: https://github.com/wangruohui/sjtu-openmmlab-tutorial/blob/main/cls-2-train.ipynb
