我正在制作一个线性回归模型(3个输入参数,类型为float),可以使其在基于用户输入进行预测的Android应用中在设备上运行。
为此,我使用了TensorFlow估计器tf.estimator.LinearRegressor。我还使用以下代码制作了一个SavedModel:
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(tf.feature_column.make_parse_example_spec([crim, indus, tax]))
export_path = model_est.export_saved_model("saved_model", serving_input_fn)
特征列在代码中之前的定义如下:
tax = tf.feature_column.numeric_column('tax')
indus = tf.feature_column.numeric_column('indus')
crim = tf.feature_column.numeric_column('crim')
整个模型构建代码如下:
import tensorflow as tf
tf.compat.v1.disable_v2_behavior()
TRAIN_CSV_PATH = './data/BostonHousing_subset.csv'
TEST_CSV_PATH = './data/boston_test_subset.csv'
PREDICT_CSV_PATH = './data/boston_predict_subset.csv'
# target variable to predict:
LABEL_PR = "medv"
def get_batch(file_path, batch_size, num_epochs=None, **args):
with open(file_path) as file:
num_rows = len(file.readlines())
dataset = tf.data.experimental.make_csv_dataset(
file_path, batch_size, label_name=LABEL_PR, num_epochs=num_epochs, header=True, **args)
# repeat and shuffle and batch separately instead of the previous line
# for clarity purposes
# dataset = dataset.repeat(num_epochs)
# dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
elem = iterator.get_next()
return elem
# Now to define the feature columns
tax = tf.feature_column.numeric_column('tax')
indus = tf.feature_column.numeric_column('indus')
crim = tf.feature_column.numeric_column('crim')
# Building the model
model_est = tf.estimator.LinearRegressor(feature_columns=[crim, indus, tax], model_dir='model_dir')
# Train it now
model_est.train(steps=2300, input_fn=lambda: get_batch(TRAIN_CSV_PATH, batch_size=256))
results = model_est.evaluate(steps=1000, input_fn=lambda: get_batch(TEST_CSV_PATH, batch_size=128))
for key in results:
print(" {}, was: {}".format(key, results[key]))
to_pred = {
'crim': [0.03359, 5.09017, 0.12650, 0.05515, 8.15174, 0.24522],
'indus': [2.95, 18.10, 5.13, 2.18, 18.10, 9.90],
'tax': [252, 666, 284, 222, 666, 304],
}
def test_get_inp():
dataset = tf.data.Dataset.from_tensors(to_pred)
return dataset
# Predict
for pred_results in model_est.predict(input_fn=test_get_inp):
print(pred_results['predictions'][0])
# Now to export as SavedModel
print(tf.feature_column.make_parse_example_spec([crim, indus, tax]))
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(tf.feature_column.make_parse_example_spec([crim, indus, tax]))
export_path = model_est.export_saved_model("saved_model", serving_input_fn)
我用来将此SavedModel转换为tflite格式的代码是:
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model/1576168761')
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
输出一个.tflite文件。
但是,当我尝试使用以下代码加载此tflite文件时:
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()
我收到此错误:
Traceback (most recent call last):
File "D:/Documents/Projects/boston_house_pricing/get_model_details.py", line 5, in <module>
interpreter.allocate_tensors()
File "D:\Anaconda3\envs\boston_housing\lib\site-packages\tensorflow_core\lite\python\interpreter.py", line 244, in allocate_tensors
return self._interpreter.AllocateTensors()
File "D:\Anaconda3\envs\boston_housing\lib\site-packages\tensorflow_core\lite\python\interpreter_wrapper\tensorflow_wrap_interpreter_wrapper.py", line 106, in AllocateTensors
return _tensorflow_wrap_interpreter_wrapper.InterpreterWrapper_AllocateTensors(self)
RuntimeError: Regular TensorFlow ops are not supported by this interpreter. Make sure you invoke the Flex delegate before inference.Node number 0 (FlexParseExample) failed to prepare.
我无法理解如何解决此错误。另外,当我尝试使用tflite在Java中使用此文件(Android)初始化解释器时,会引发相同消息的错误。
帮助将不胜感激与此有关。
最佳答案
似乎该错误很好地说明了问题,当转换为tflite时,您已指定tf.lite.OpsSet.SELECT_TF_OPS标志,这将导致转换器包括tflite本机不支持的操作,并且期望您将使用使用tflite中的flex模块编译并在解释器中包含这些操作。
有关flex的更多信息:https://www.tensorflow.org/lite/guide/ops_select
无论如何,您都有两个主要选项,要么使用flex并编译所需的操作,要么仅使用tflite本身支持的操作并忽略tf.lite.OpsSet.SELECT_TF_OPS标志。
有关本机支持的tensorflow操作的信息,请参见:https://www.tensorflow.org/lite/guide/ops_compatibility
关于python - 在tflite解释器中分配张量时出错,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/59315059/