In the figure above, the object captured by the phone's camera is recognized and its specific name is shown; this requires a model trained with TensorFlow to be referenced in the project. The detailed steps for integrating TensorFlow Lite are listed below; please follow them in order:
In the build.gradle file in the root directory of the project, add a Maven repository that hosts the TensorFlow Lite artifacts. Add the following lines to the repositories section:
allprojects {
    repositories {
        // other repositories...
        // The former TensorFlow Bintray repository (https://google.bintray.com/tensorflow)
        // has been shut down; the TensorFlow Lite artifacts are published on Maven Central.
        mavenCentral()
    }
}
In the app module's build.gradle file, add the TensorFlow Lite dependency. Add the following line to the dependencies section:
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
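Because the model file will later be memory-mapped straight out of the APK's assets (see the loadModelFile() method below), it must not be compressed by the Android build tools. A common companion setting in the same app-level build.gradle, shown here as a sketch, is:
android {
    aaptOptions {
        // Keep .tflite files uncompressed so they can be memory-mapped from assets at runtime.
        noCompress "tflite"
    }
}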
Add the TensorFlow Lite model file to your Android project: copy the model file (.tflite) into the app/src/main/assets directory. If the assets directory does not exist, create it manually. Then create a TFLiteObjectDetectionAPIModel class that loads and runs the TensorFlow Lite model. Here is an example:
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.graphics.RectF;
import android.os.Trace;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.tensorflow.lite.Interpreter;
// Classifier (which defines the nested Recognition result type) and Logger are
// project-local helper classes, not part of the TensorFlow Lite library.
public class TFLiteObjectDetectionAPIModel implements Classifier {
private static final Logger LOGGER = new Logger();
// Only return this many results.
private static final int NUM_DETECTIONS = 10;
// Float model
private static final float IMAGE_MEAN = 128.0f;
private static final float IMAGE_STD = 128.0f;
// Number of threads in the java app
private static final int NUM_THREADS = 4;
private boolean isModelQuantized;
// Config values.
private int inputSize;
// Pre-allocated buffers.
private Vector<String> labels = new Vector<String>();
private int[] intValues;
// outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
// contains the location of detected boxes
private float[][][] outputLocations;
// outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
// contains the classes of detected boxes
private float[][] outputClasses;
// outputScores: array of shape [Batchsize, NUM_DETECTIONS]
// contains the scores of detected boxes
private float[][] outputScores;
// numDetections: array of shape [Batchsize]
// contains the number of detected boxes
private float[] numDetections;
private ByteBuffer imgData;
private Interpreter tfLite;
private TFLiteObjectDetectionAPIModel() {}
/** Memory-map the model file in Assets. */
private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
throws IOException {
AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
/**
* Initializes a TensorFlow Lite interpreter for detecting objects in images.
*
* @param assetManager The asset manager to be used to load assets.
* @param modelFilename The filepath of the TensorFlow Lite model (.tflite) file.
* @param labelFilename The filepath of label file for classes.
* @param inputSize The size of image input
* @param isQuantized Boolean representing model is quantized or not
*/
public static Classifier create(
final AssetManager assetManager,
final String modelFilename,
final String labelFilename,
final int inputSize,
final boolean isQuantized)
throws IOException {
final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
InputStream labelsInput = null;
String actualFilename = labelFilename.split("file:///android_asset/")[1];
labelsInput = assetManager.open(actualFilename);
BufferedReader br = null;
br = new BufferedReader(new InputStreamReader(labelsInput));
String line;
while ((line = br.readLine()) != null) {
LOGGER.w(line);
d.labels.add(line);
}
br.close();
d.inputSize = inputSize;
try {
d.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
} catch (Exception e) {
throw new RuntimeException(e);
}
d.isModelQuantized = isQuantized;
// Pre-allocate buffers.
int numBytesPerChannel;
if (isQuantized) {
numBytesPerChannel = 1; // Quantized
} else {
numBytesPerChannel = 4; // Floating point
}
d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
d.imgData.order(ByteOrder.nativeOrder());
d.intValues = new int[d.inputSize * d.inputSize];
d.tfLite.setNumThreads(NUM_THREADS);
d.outputLocations = new float[1][NUM_DETECTIONS][4];
d.outputClasses = new float[1][NUM_DETECTIONS];
d.outputScores = new float[1][NUM_DETECTIONS];
d.numDetections = new float[1];
return d;
}
@Override
public List<Recognition> recognizeImage(final Bitmap bitmap) {
// Log this method so that it can be analyzed with systrace.
Trace.beginSection("recognizeImage");
Trace.beginSection("preprocessBitmap");
// Preprocess the image data from 0-255 int to normalized float based
// on the provided parameters.
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
for (int j = 0; j < inputSize; ++j) {
int pixelValue = intValues[i * inputSize + j];
if (isModelQuantized) {
// Quantized model
imgData.put((byte) ((pixelValue >> 16) & 0xFF));
imgData.put((byte) ((pixelValue >> 8) & 0xFF));
imgData.put((byte) (pixelValue & 0xFF));
} else { // Float model
imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
}
}
}
Trace.endSection(); // preprocessBitmap
// Copy the input data into TensorFlow.
Trace.beginSection("feed");
outputLocations = new float[1][NUM_DETECTIONS][4];
outputClasses = new float[1][NUM_DETECTIONS];
outputScores = new float[1][NUM_DETECTIONS];
numDetections = new float[1];
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
outputMap.put(2, outputScores);
outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
Trace.beginSection("run");
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
Trace.endSection();
// Show the best detections.
// after scaling them back to the input size.
final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
for (int i = 0; i < NUM_DETECTIONS; ++i) {
final RectF detection =
new RectF(
outputLocations[0][i][1] * inputSize,
outputLocations[0][i][0] * inputSize,
outputLocations[0][i][3] * inputSize,
outputLocations[0][i][2] * inputSize);
// SSD Mobilenet V1 Model assumes class 0 is background class
// in label file and class labels start from 1 to number_of_classes+1,
// while outputClasses correspond to class index from 0 to number_of_classes
int labelOffset = 1;
recognitions.add(
new Recognition(
"" + i,
labels.get((int) outputClasses[0][i] + labelOffset),
outputScores[0][i],
detection));
}
Trace.endSection(); // "recognizeImage"
return recognitions;
}
@Override
public void enableStatLogging(final boolean logStats) {}
@Override
public String getStatString() {
return "";
}
@Override
public void close() {}
public void setNumThreads(int num_threads) {
if (tfLite != null) tfLite.setNumThreads(num_threads);
}
@Override
public void setUseNNAPI(boolean isChecked) {
if (tfLite != null) tfLite.setUseNNAPI(isChecked);
}
}
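The create() method above reads labelFilename line by line into the labels list, and recognizeImage() adds labelOffset = 1 because the first entry of an SSD MobileNet label file is reserved for the background class. A hypothetical labelmap.txt placed under assets (the class names below are only examples) might therefore start like this:
???
person
bicycle
car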
Make sure the modelFilename argument passed to create() is the path of your model file inside the assets directory.
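The example below references several TF_OD_API_* constants defined in the calling activity. A minimal sketch of what they might look like is shown here; the file names, input size and quantization flag are assumptions that must match the model you actually placed under assets:
// Hypothetical detector configuration; adjust to your own model and label file.
private static final int TF_OD_API_INPUT_SIZE = 300;
private static final boolean TF_OD_API_IS_QUANTIZED = true;
private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
private static final String TF_OD_API_LABELS_FILE = "file:///android_asset/labelmap.txt";
// Minimum score a detection must reach before it is shown.
private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.5f;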
Use the TFLiteObjectDetectionAPIModel class in your application to run inference. Here is a simple example:
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
final float textSizePx =
TypedValue.applyDimension(
TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
borderedText = new BorderedText(textSizePx);
borderedText.setTypeface(Typeface.MONOSPACE);
tracker = new MultiBoxTracker(this);
int cropSize = TF_OD_API_INPUT_SIZE;
try {
detector =
TFLiteObjectDetectionAPIModel.create(
getAssets(),
TF_OD_API_MODEL_FILE,
TF_OD_API_LABELS_FILE,
TF_OD_API_INPUT_SIZE,
TF_OD_API_IS_QUANTIZED);
cropSize = TF_OD_API_INPUT_SIZE;
} catch (final IOException e) {
e.printStackTrace();
LOGGER.e(e, "Exception initializing classifier!");
Toast toast =
Toast.makeText(
getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
toast.show();
finish();
}
// Parse the detector output here (see recognizeImage above)
// ...
}
Depending on your model and task, you may need to parse the output data according to the model's specification and documentation.
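As a minimal sketch of such parsing, assuming the project's Classifier.Recognition helper exposes getLocation() and getConfidence() as in the demo code, that croppedBitmap is a Bitmap already scaled to TF_OD_API_INPUT_SIZE, and using the hypothetical MINIMUM_CONFIDENCE_TF_OD_API threshold defined earlier:
// Run inference and keep only the detections above the confidence threshold.
final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
final List<Classifier.Recognition> filtered = new ArrayList<>();
for (final Classifier.Recognition result : results) {
    final RectF location = result.getLocation();
    if (location != null && result.getConfidence() >= MINIMUM_CONFIDENCE_TF_OD_API) {
        filtered.add(result);
    }
}
// 'filtered' now holds the boxes, labels and scores to draw on an overlay or show as text.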
The parsed results can then be displayed as text in the UI. If you need the complete project source code, please contact me directly.