Skip to content

Latest commit

 

History

History
74 lines (64 loc) · 4.23 KB

训练自己的物体检测模型.md

File metadata and controls

74 lines (64 loc) · 4.23 KB

TensorFlow使用object detection训练自己的模型用于物体识别

简介

使用tensorflow的object_detection模块进行物体检测训练,最终实现识别自己的数据集

机器配置与工作环境:
  • 硬件:
    • cpu:i56300hq
    • gpu:gtx960m*2g
  • 软件
    • ubuntu18.04
    • python3.6
    • tensorflow*1.14.0
    • 其他需要的扩展包按照官方文档安装完成即可 enter description here
训练数据准备

object_detection推荐的训练数据格式是TFRecord格式的文件,我们需要通过对图片进行标注生成xml文件、将xml文件生成csv文件,再将csv文件转换为TFRecord文件以生成训练数据。

  1. 使用labelImg工具标注图片生成xml文件: 在这里下载并安装labelImg,官方教程很详细。安装之后使用此工具可以很容易对图片进行标注
  2. 使用Python脚本转换xml数据为对应的csv数据
  3. 使用Python脚本根据csv数据信息生成TFRecord数据
  4. 在进行上述操作时候,我们还需要构建我们的标签图,后缀名为pbtxt,如:
    item{
       id:1
       name: 'dog'
       }
    item{
       id:2
       name: 'cat'
       }
    item{
       id:3
       name: 'else'
       }  
准备预训练模型

这里下载一个预训练模型,解压之后得到一个模型文件目录,注意,由于后续代码文件更新,所以需要将模型目录里面的.config文件由object_detection目录下的samples目录中对应的模型配置文件替换,在替换之后,还需要根据自己的文件目录配置修改.config文件中的PATH_TO_BE_CONFIGURED字段:

  • num_classes:你自己训练的模型分类类别数量
  • fine_tune_checkpoint:解压后的模型文件中的model.ckpt文件路径
  • train_input_reader、eval_input_reader下的input_path和label_map_path:对应的训练数据集、验证数据集的.TFRecord输入输入和.pbtxt标签图文件
  • 注意:我在使用ssd模型的时候发生了错误,未找到正确的解决方法,建议使用faster_rcnn系列模型
进行训练

注意,由于近期的代码更新,位于object_detection文件train.py移动到了legacy文件夹下,我们需要使用这个文件进行模型训练。 按照此文件中的用法示例使用即可: enter description here 如:

python train.py --logtostderr --pipeline_config_path=./models/faster_rcnn_resnet50_coco_2018_01_28/pipeline.config --train_dir=./models/main/pre/
导出模型

使用object_detection目录下的export_inference_graph.py进行模型导出即可 此文件中也有其用法描述: enter description here 如:

python export_inference_graph.py --input_type image_tensor --pipeline_config_path=./models/faster_rcnn_resnet50_coco_2018_01_28/pipeline.config --trained_checkpoint_prefix=./models/main/pre/model.ckpt-70 --output_directory=./models/main/out/

注意:trained_checkpoint_prefix指的是模型文件前缀

使用训练好的文件进行图像识别
  1. 使用jupyter notebook打开object_detection目录下的object_detection_tutorial.ipynb文件
  2. 修改代码中的PATH_TO_CKPT、PATH_TO_LABELS、NUM_CLASSES为自己训练好并导出到文件夹下的.pb文件路径、自己创建的.pbtxt标签图文件、自己训练的模型识别的类别数
  3. 修改PATH_TO_TEST_IMAGES_DIR为自己待识别图片的文件夹路径,并修改TEST_IMAGES_PATHS中的图片文件扩展名为自己的图片扩展名,并修改后续range数以匹配自己的图片。注,默认识别是图片文件夹下的1.bmp、2.bmp,自己看看代码改改就行。如: enter description here
收功

end