现在我们已经成功训练好了tensorflow object detection 模型,我们的数据集为5类缺陷目标,网络模型选用的是ssd inception V2。具体训练步骤见TensorRT成功测试自己的数据集SSD模型二,现在讲述在tensorrt中具体如何使用。
导出Trained Inference Graph
TensorFlow/models/research/object_detection/export_inference_graph.py复制到training_demo文件夹
找到training_demo/training下的model.ckpt-*的 第一个index , 比如model.ckpt-0
cd training_demo
导出pb
1 | python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_inception_v2_coco.config --trained_checkpoint_prefix training/model.ckpt-0 --output_directory trained-inference-graphs/output_inference_graph_v1.pb |
uff_ssd
首先,定位到TensorRT-7.0.0.11/samples/python/uff_ssd
打开 detect_object.py
1 | COCO_CLASSES_LIST = ['unlabeled', 'dirty', 'oil', 'pit', 'scratch', 'wire_drawing'] |
定义你自己的CLASSES_LIST,切记'unlabeled'不能少
1 | # Model used for inference |
接着注释掉model下的raise error和download_model
1 | if model_name != "ssd_inception_v2_coco_2017_11_17": |
1 | NMS = gs.create_plugin_node( |
切记这里的numClasses为classes+1