问题描述
我有一个对象检测 TFLite 模型保存为 model.tflite
文件.我可以运行它
I have an object detection TFLite model saved as model.tflite
file. I can run it as
interpreter = tf.lite.Interpreter("model.tflite")
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_image)
interpreter.invoke()
然后得到输出
detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
detection_scores = interpreter.get_tensor(output_details[2]['index'])
num_boxes = interpreter.get_tensor(output_details[3]['index'])
我想用图片中的给定类绘制检测到的框.最简单的解决方案似乎是使用工具 viz_utils.visualize_boxes_and_labels_on_image_array
as.
I would like to plot the detected boxes with given classes in a picture. The simplest solution to do this seems to be using the tool viz_utils.visualize_boxes_and_labels_on_image_array
as.
viz_utils.visualize_boxes_and_labels_on_image_array(
image_np_with_detections,
detection_boxes,
detection_classes,
detection_scores,
category_index,
use_normalized_coordinates=True,
max_boxes_to_draw=20,
min_score_thresh=.1,
agnostic_mode=False
然而,为此需要有 category_index
(将类索引转换为人类可读的标签).通常,您可以从包含标签的文件中加载它,在 .tflite 模型的情况下,如果我没记错的话,应该包含/打包在 .tflite 文件中.
However, for that one needs to have the category_index
(to convert the classes indices to human readable labels). Typically, you can load it from a file containing the labels, which, in case of .tflite model, should be included/packed in the .tflite file, if I am not wrong.
但是,我不知道该怎么做,或者我应该使用哪些函数(我还查看了 tflite_support
库,但不知道如何从关联文件).
However, I can't figure out how to do it, or which functions should I use (I looked also at tflite_support
library, but can't figure out how to extract the categories from the associated file).
使用 .tflite 文件将检测到的带有标签的框可视化的正确方法是什么?它不必使用 viz_utils
.任何帮助表示赞赏.谢谢.
What is the proper way to visualize the detected boxes with labels using a .tflite file? It doesn't have to be using viz_utils
. Any help is appreciated. Thanks.
推荐答案
# labels variable contains the list of the names of the category and
# it generates by reading the labels.txt
with open("labels.txt", "r") as f:
txt = f.read()
labels = txt.splitlines()
for idx, box in enumerate(detection_boxes[0]):
if detection_scores[0][idx] > threshold:
class_name = labels[int(detection_classes[0][idx])]
我根据 https://github.com 创建了此代码片段/tensorflow/models/issues/7458#issuecomment-523904465.
这篇关于如何从 TFLite 模型中可视化检测到的框(如何从 TFLite 模型中获取类别索引?)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!