0%

TensorRT成功测试自己的数据集SSD模型三

现在我们已经成功训练好了tensorflow object detection 模型,我们的数据集为5类缺陷目标,网络模型选用的是ssd inception V2。具体训练步骤见TensorRT成功测试自己的数据集SSD模型二,现在讲述在tensorrt中具体如何使用。

TensorRT成功测试自己的数据集SSD模型一

TensorRT成功测试自己的数据集SSD模型二

TensorRT成功测试自己的数据集SSD模型三

TensorRT成功测试自己的数据集SSD模型四

导出Trained Inference Graph

  • TensorFlow/models/research/object_detection/export_inference_graph.py复制到training_demo文件夹

  • 找到training_demo/training下的model.ckpt-*的 第一个index , 比如model.ckpt-0

  • cd training_demo

  • 导出pb

1
python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_inception_v2_coco.config --trained_checkpoint_prefix training/model.ckpt-0 --output_directory trained-inference-graphs/output_inference_graph_v1.pb

uff_ssd

首先,定位到TensorRT-7.0.0.11/samples/python/uff_ssd

打开 detect_object.py 

  • COCO label list

utils/coco.py 148行
1
COCO_CLASSES_LIST = ['unlabeled', 'dirty', 'oil', 'pit', 'scratch', 'wire_drawing']

定义你自己的CLASSES_LIST,切记'unlabeled'不能少

  • MODEL_NAME

detect_objects.py
1
2
# Model used for inference
MODEL_NAME = 'my_model'

接着注释掉model下的raise error和download_model

uff_ssd/utils/model.py
1
2
3
4
5
6
7
if model_name != "ssd_inception_v2_coco_2017_11_17":
# raise NotImplementedError(
# "Model {} is not supported yet".format(model_name))
#####test#####
print(model_name)
######test#####
#download_model(model_name, silent)
  • numClasses

utils/model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
NMS = gs.create_plugin_node(
name="NMS",
op="NMS_TRT",
shareLocation=1,
varianceEncodedInTarget=0,
backgroundLabelId=0,
confidenceThreshold=1e-8,
nmsThreshold=0.6,
topK=100,
keepTopK=100,
numClasses=6,
inputOrder=[0, 2, 1],
confSigmoid=1,
isNormalized=1
)

切记这里的numClasses为classes+1