文章目录
本节总览
这一小节需要了解Kuiperinfer项目使用的模型格式PNNX,以及其使用的计算图相关的结构定义。
然后对计算图进行封装。
PNNX计算图
KuiperInfer使用的模型格式为PNNX
,是Pytorch Neural Network Exchange
的缩写。这个格式是NCNN
框架的一个pr中引入的,该格式希望能够将Pytorch模型导出为高效、简洁的计算图。详细可参考https://zhuanlan.zhihu.com/p/427620428
回顾一下绪论中的内容,概念上的计算图主要包含以下部分:
Operator
:计算图中的计算节点Graph
:多个Operator串联得到的有向无环图,规定了计算节点的执行顺序Layer
:Operator中具体执行运算的部分,首先读入输入Tensor中的数据,然后计算,将结果存放在输出Tensor中Tensor
:存放多维数据
而PNNX计算图包括Graph, Operator
和Operand
三种结构。下面将逐个分析每个结构的类定义。
测试模型
Kuiperinfer提供了一对小模型文件linear.param
和linear.bin
,分别为网络的结构和权重文件。网络的pytorch定义如下:
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = nn.Linear(32, 128)
def forward(self, x):
x = self.linear(x)
x = F.sigmoid(x)
return x
可以使用Netron打开linear.param
文件观察网络的结构
Linear层的结构,包含两个操作数#0, #1
和两个权重属性@weight, @bias
。
Graph
Graph
的功能是管理计算图中的Operator
和Operand
;其结构定义如下
class Graph
{
// 算子
Operator* new_operator(const std::string& type, const std::string& name);
Operator* new_operator_before(const std::string& type, const std::string& name, const Operator* cur);
Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);
// 操作数
Operand* new_operand(const torch::jit::Value* v);
Operand* new_operand(const std::string& name);
Operand* get_operand(const std::string& name);
std::vector<Operator*> ops;
std::vector<Operand*> operands;
// 加载/保存
int load(const std::string& parampath, const std::string& binpath);
int save(const std::string& parampath, const std::string& binpath);
};
单元测试 - 输出算子
TEST(test_ir, pnnx_graph_ops) {
using namespace kuiper_infer;
/**
* 如果这里加载失败,请首先考虑相对路径的正确性问题
*/
// LOG(INFO) << std::filesystem::current_path(); // 需要# include<filesystem>, 加这一行检查路径
std::string bin_path("../../../src/model_file/test_linear.pnnx.bin"); // 这里要填远程主机上的模型相对路径
std::string param_path("../../../src/model_file/test_linear.pnnx.param");
std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
int load_result = graph->load(param_path, bin_path);
// 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
ASSERT_EQ(load_result, 0);
const auto &ops = graph->ops;
for (int i = 0; i < ops.size(); ++i) {
LOG(INFO) << ops.at(i)->name;
}
}
I20240301 08:55:21.320151 1057 test_ir.cpp:37] pnnx_input_0
I20240301 08:55:21.320165 1057 test_ir.cpp:37] linear
I20240301 08:55:21.320171 1057 test_ir.cpp:37] F.sigmoid_0
I20240301 08:55:21.320178 1057 test_ir.cpp:37] pnnx_output_0
可以看到,Graph
中存在四个算子:input_0, linear, sigmoid_0, output_0
。
单元测试 - 输出操作数
TEST(test_ir, pnnx_graph_operands) {
using namespace kuiper_infer;
/**
* 如果这里加载失败,请首先考虑相对路径的正确性问题
*/
std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
std::string param_path("../../../src/model_file/test_linear.pnnx.param");
std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
int load_result = graph->load(param_path, bin_path);
// 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
ASSERT_EQ(load_result, 0);
const auto &ops = graph->ops;
for (int i = 0; i < ops.size(); ++i) {
const auto &op = ops.at(i); // 取算子
LOG(INFO) << "OP Name: " << op->name;
LOG(INFO) << "OP Inputs";
for (int j = 0; j < op->inputs.size(); ++j) {
LOG(INFO) << "Input name: " << op->inputs.at(j)->name
<< " shape: " << ShapeStr(op->inputs.at(j)->shape);
}
LOG(INFO) << "OP Output";
for (int j = 0; j < op->outputs.size(); ++j) {
LOG(INFO) << "Output name: " << op->outputs.at(j)->name
<< " shape: " << ShapeStr(op->outputs.at(j)->shape);
}
LOG(INFO) << "---------------------------------------------";
}
}
I20240301 08:55:21.320255 1057 test_ir.cpp:56] OP Name: pnnx_input_0
I20240301 08:55:21.320262 1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320267 1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320274 1057 test_ir.cpp:65] Output name: 0 shape: 1 x 32
I20240301 08:55:21.320281 1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320287 1057 test_ir.cpp:56] OP Name: linear
I20240301 08:55:21.320292 1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320298 1057 test_ir.cpp:59] Input name: 0 shape: 1 x 32
I20240301 08:55:21.320305 1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320362 1057 test_ir.cpp:65] Output name: 1 shape: 1 x 128
I20240301 08:55:21.320374 1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320380 1057 test_ir.cpp:56] OP Name: F.sigmoid_0
I20240301 08:55:21.320386 1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320392 1057 test_ir.cpp:59] Input name: 1 shape: 1 x 128
I20240301 08:55:21.320405 1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320420 1057 test_ir.cpp:65] Output name: 2 shape: 1 x 128
I20240301 08:55:21.320432 1057 test_ir.cpp:68] ---------------------------------------------
I20240301 08:55:21.320443 1057 test_ir.cpp:56] OP Name: pnnx_output_0
I20240301 08:55:21.320453 1057 test_ir.cpp:57] OP Inputs
I20240301 08:55:21.320461 1057 test_ir.cpp:59] Input name: 2 shape: 1 x 128
I20240301 08:55:21.320467 1057 test_ir.cpp:63] OP Output
I20240301 08:55:21.320478 1057 test_ir.cpp:68] ---------------------------------------------
Operator
Operator
的结构定义如下:
class Operator
{
public:
// 输入输出操作数
std::vector<Operand*> inputs;
std::vector<Operand*> outputs;
// Operator的类型和名称
std::string type;
std::string name;
std::vector<std::string> inputnames;
std::map<std::string, Parameter> params;
std::map<std::string, Attribute> attrs;
}
单元测试
尝试输出Linear层的信息,并与可视化结果进行对照
TEST(test_ir, pnnx_graph_operands_and_params) {
using namespace kuiper_infer;
/**
* 如果这里加载失败,请首先考虑相对路径的正确性问题
*/
std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
std::string param_path("../../../src/model_file/test_linear.pnnx.param");
std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
int load_result = graph->load(param_path, bin_path);
// 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
ASSERT_EQ(load_result, 0);
const auto &ops = graph->ops;
for (int i = 0; i < ops.size(); ++i) {
const auto &op = ops.at(i);
if (op->name != "linear") { // 只打印Linear
continue;
}
LOG(INFO) << "OP Name: " << op->name;
LOG(INFO) << "OP Inputs";
for (int j = 0; j < op->inputs.size(); ++j) {
LOG(INFO) << "Input name: " << op->inputs.at(j)->name
<< " shape: " << ShapeStr(op->inputs.at(j)->shape);
}
LOG(INFO) << "OP Output";
for (int j = 0; j < op->outputs.size(); ++j) {
LOG(INFO) << "Output name: " << op->outputs.at(j)->name
<< " shape: " << ShapeStr(op->outputs.at(j)->shape);
}
LOG(INFO) << "Params";
for (const auto &attr : op->params) {
LOG(INFO) << attr.first << " type " << attr.second.type;
}
LOG(INFO) << "Weight: ";
for (const auto &weight : op->attrs) {
LOG(INFO) << weight.first << " : " << ShapeStr(weight.second.shape)
<< " type " << weight.second.type;
}
LOG(INFO) << "---------------------------------------------";
}
}
I20240302 12:06:46.516649 2026 test_ir.cpp:90] OP Name: linear
I20240302 12:06:46.516660 2026 test_ir.cpp:91] OP Inputs
I20240302 12:06:46.516680 2026 test_ir.cpp:93] Input name: 0 shape: 1 x 32
I20240302 12:06:46.516691 2026 test_ir.cpp:97] OP Output
I20240302 12:06:46.516698 2026 test_ir.cpp:99] Output name: 1 shape: 1 x 128
I20240302 12:06:46.516705 2026 test_ir.cpp:103] Params
I20240302 12:06:46.516714 2026 test_ir.cpp:105] bias type 1
I20240302 12:06:46.516724 2026 test_ir.cpp:105] in_features type 2
I20240302 12:06:46.516733 2026 test_ir.cpp:105] out_features type 2
I20240302 12:06:46.516741 2026 test_ir.cpp:108] Weight:
I20240302 12:06:46.516749 2026 test_ir.cpp:110] bias : 128 type 1
I20240302 12:06:46.516758 2026 test_ir.cpp:110] weight : 128 x 32 type 1
I20240302 12:06:46.516765 2026 test_ir.cpp:113] ---------------------------------------------
Parameter 和 Attribute
权重数据结构和参数数据结构的定义如下,可参考Operator中单元测试的结果进行对照;
class Parameter
{
public:
Parameter()
: type(0)
{
}
...
#if BUILD_PNNX
Parameter(const torch::jit::Node* value_node);
Parameter(const torch::jit::Value* value);
#endif // BUILD_PNNX
static Parameter parse_from_string(const std::string& value);
// 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others
int type;
// value
bool b;
int i;
float f;
std::vector<int> ai;
std::vector<float> af;
// keep std::string typed member the last for cross cxxabi compatibility
std::string s;
std::vector<std::string> as;
};
class Attribute
{
public:
Attribute()
: type(0)
{
}
#if BUILD_PNNX
Attribute(const at::Tensor& t);
#endif // BUILD_PNNX
Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);
// 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool
int type;
std::vector<int> shape;
std::vector<char> data;
};
Operand
Operand
的定义如下:
class Operand
{
public:
void remove_consumer(const Operator* c);
Operator* producer; // 产生该操作数的算子,只能有一个
std::vector<Operator*> consumers; // 使用该操作数的算子,可以有多个
// 0=null 1=f32 2=f64 3=f16 4=i32 5=i64 6=i16 7=i8 8=u8 9=bool 10=cp64 11=cp128 12=cp32
int type;
std::vector<int> shape;
// keep std::string typed member the last for cross cxxabi compatibility
std::string name;
std::map<std::string, Parameter> params;
};
单元测试
TEST(test_ir, pnnx_graph_operands_customer_producer) {
using namespace kuiper_infer;
/**
* 如果这里加载失败,请首先考虑相对路径的正确性问题
*/
std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
std::string param_path("../../../src/model_file/test_linear.pnnx.param");
std::unique_ptr<pnnx::Graph> graph = std::make_unique<pnnx::Graph>();
int load_result = graph->load(param_path, bin_path);
// 如果这里加载失败,请首先考虑相对路径(bin_path和param_path)的正确性问题
ASSERT_EQ(load_result, 0);
const auto &operands = graph->operands;
for (int i = 0; i < operands.size(); ++i) {
const auto &operand = operands.at(i);
LOG(INFO) << "Operand Name: #" << operand->name;
LOG(INFO) << "Customers: ";
for (const auto &customer : operand->consumers) {
LOG(INFO) << customer->name;
}
LOG(INFO) << "Producer: " << operand->producer->name;
}
}
I20240302 12:06:46.516847 2026 test_ir.cpp:131] Operand Name: #0
I20240302 12:06:46.516860 2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516870 2026 test_ir.cpp:134] linear
I20240302 12:06:46.516880 2026 test_ir.cpp:137] Producer: pnnx_input_0
I20240302 12:06:46.516891 2026 test_ir.cpp:131] Operand Name: #1
I20240302 12:06:46.516901 2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516911 2026 test_ir.cpp:134] F.sigmoid_0
I20240302 12:06:46.516922 2026 test_ir.cpp:137] Producer: linear
I20240302 12:06:46.516932 2026 test_ir.cpp:131] Operand Name: #2
I20240302 12:06:46.516942 2026 test_ir.cpp:132] Customers:
I20240302 12:06:46.516953 2026 test_ir.cpp:134] pnnx_output_0
I20240302 12:06:46.516964 2026 test_ir.cpp:137] Producer: F.sigmoid_0
可以结合模型的可视化进行观察:操作数#0
由pnnx_input_0
产生,在linear
层使用;操作数#1
由linear
产生,在F.sigmoid_0
层使用,操作数#2
同理。
Kuiperinfer封装
首先给出UML结构图
可知Kuiperinfer的计算图核心结构是RuntimeOperator
,该结构是对PNNX::Operator
的封装。
定义如下:
struct RuntimeOperator {
virtual ~RuntimeOperator();
bool has_forward = false;
std::string name; /// 计算节点的名称
std::string type; /// 计算节点的类型
std::shared_ptr<Layer> layer; /// 节点对应的计算Layer
std::vector<std::string> output_names; /// 节点的输出节点名称
std::shared_ptr<RuntimeOperand> output_operands; /// 节点的输出操作数
std::map<std::string, std::shared_ptr<RuntimeOperand>>
input_operands; /// 节点的输入操作数
std::vector<std::shared_ptr<RuntimeOperand>>
input_operands_seq; /// 节点的输入操作数,顺序排列
std::map<std::string, std::shared_ptr<RuntimeOperator>>
output_operators; /// 输出节点的名字和节点对应
std::map<std::string, RuntimeParameter*> params; /// 算子的参数信息
std::map<std::string, std::shared_ptr<RuntimeAttribute>>
attribute; /// 算子的属性信息,内含权重信息
};
-
name
:节点的名称,用来区分唯一节点,例如Conv_1, Conv_2
-
type
:节点类型,例如Convolution, Relu
-
layer
:完成具体计算的组件,例如在卷积算子中的卷积计算 -
input_operands, output_operands
:运算符的输入输出操作数,若输入为4, 3, 224, 224
,则input_operands
的datas
数组长度为4,数组中每个元素张量的大小为3, 224, 224
struct RuntimeOperand { // 假设输入为BCHW = (4, 3, 224, 224) std::string name; /// 操作数的名称 std::vector<int32_t> shapes; /// 操作数的形状 std::vector<std::shared_ptr<Tensor<float>>> datas; /// 存储操作数,长度为4,每个Tensor形状为(3, 224, 224) RuntimeDataType type = RuntimeDataType::kTypeUnknown; /// 操作数的类型,一般是float };
-
params
:算子的参数信息,例如卷积层的核大小、步长 -
attribute
:算子的weights,bias
从PNNX::Operator到Kuiper::RuntimeOperator
从PNNX::Operator
中提取信息,填入Kuiperinfer
对应的数据结构。
实现如下,首先给出计算图的类定义:
class RuntimeGraph {
public:
/**
* 初始化计算图
* @param param_path 计算图的结构文件
* @param bin_path 计算图中的权重文件
*/
RuntimeGraph(std::string param_path, std::string bin_path);
/**
* 设置权重文件
* @param bin_path 权重文件路径
*/
void set_bin_path(const std::string &bin_path);
/**
* 设置结构文件
* @param param_path 结构文件路径
*/
void set_param_path(const std::string ¶m_path);
/**
* 返回结构文件
* @return 返回结构文件
*/
const std::string ¶m_path() const;
/**
* 返回权重文件
* @return 返回权重文件
*/
const std::string &bin_path() const;
/**
* 计算图的初始化
* @return 是否初始化成功
*/
bool Init();
const std::vector<std::shared_ptr<RuntimeOperator>> &operators() const;
private:
/**
* 初始化kuiper infer计算图节点中的输入操作数
* @param inputs pnnx中的输入操作数
* @param runtime_operator 计算图节点
*/
static void InitGraphOperatorsInput(
const std::vector<pnnx::Operand *> &inputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图节点中的输出操作数
* @param outputs pnnx中的输出操作数
* @param runtime_operator 计算图节点
*/
static void InitGraphOperatorsOutput(
const std::vector<pnnx::Operand *> &outputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图中的节点属性
* @param attrs pnnx中的节点属性
* @param runtime_operator 计算图节点
*/
static void
InitGraphAttrs(const std::map<std::string, pnnx::Attribute> &attrs,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
/**
* 初始化kuiper infer计算图中的节点参数
* @param params pnnx中的参数属性
* @param runtime_operator 计算图节点
*/
static void
InitGraphParams(const std::map<std::string, pnnx::Parameter> ¶ms,
const std::shared_ptr<RuntimeOperator> &runtime_operator);
public:
private:
std::string input_name_; /// 计算图输入节点的名称
std::string output_name_; /// 计算图输出节点的名称
std::string param_path_; /// 计算图的结构文件
std::string bin_path_; /// 计算图的权重文件
std::vector<std::shared_ptr<RuntimeOperator>> operators_;
std::map<std::string, std::shared_ptr<RuntimeOperator>> operators_maps_;
std::unique_ptr<pnnx::Graph> graph_; /// pnnx的graph
};
计算图的初始化
bool RuntimeGraph::Init() {
if (this->bin_path_.empty() || this->param_path_.empty()) {
LOG(ERROR) << "The bin path or param path is empty";
return false;
} // 检查权重文件
this->graph_ = std::make_unique<pnnx::Graph>(); // make函数,用来将参数转发给动态分配对象,然后返回对象的智能指针
int load_result = this->graph_->load(param_path_, bin_path_); // 调用PNNX的load读取参数
if (load_result != 0) {
LOG(ERROR) << "Can not find the param path or bin path: " << param_path_
<< " " << bin_path_;
return false;
} // 检查读取结果
std::vector<pnnx::Operator *> operators = this->graph_->ops; // 获取PNNX的算子列表
if (operators.empty()) {
LOG(ERROR) << "Can not read the layers' define";
return false;
}
this->operators_.clear();
this->operators_maps_.clear(); // 主动释放vector内存
for (const pnnx::Operator *op: operators) { // 处理每个算子
if (!op) {
LOG(ERROR) << "Meet the empty node";
continue;
} else {
std::shared_ptr<RuntimeOperator> runtime_operator =
std::make_shared<RuntimeOperator>();
// 初始化算子的名称
runtime_operator->name = op->name;
runtime_operator->type = op->type;
// 初始化算子中的input
const std::vector<pnnx::Operand *> &inputs = op->inputs;
if (!inputs.empty()) {
InitGraphOperatorsInput(inputs, runtime_operator);
}
// 记录输出operand中的名称
const std::vector<pnnx::Operand *> &outputs = op->outputs;
if (!outputs.empty()) {
InitGraphOperatorsOutput(outputs, runtime_operator);
}
// 初始化算子中的attribute(权重)
const std::map<std::string, pnnx::Attribute> &attrs = op->attrs;
if (!attrs.empty()) {
InitGraphAttrs(attrs, runtime_operator);
}
// 初始化算子中的parameter
const std::map<std::string, pnnx::Parameter> ¶ms = op->params;
if (!params.empty()) {
InitGraphParams(params, runtime_operator);
}
this->operators_.push_back(runtime_operator); // 完成所有初始化后,插入到operators数组中
this->operators_maps_.insert({runtime_operator->name, runtime_operator}); // 插入到map中
}
}
return true;
}
实现中Operand,Attribute和Param的初始化方法见下文。
初始化RuntimeOperator中的RuntimeOperand
提取PNNX中的操作数到RuntimeOperand中,对应实现为InitGraphOperatorsInput
和InitGraphOperatorsOutput
for (const pnnx::Operator *op : operators){
inputs = op->inputs;
InitGraphOperatorsInput(inputs, runtime_operator);
...
}
void RuntimeGraph::InitGraphOperatorsInput(
const std::vector<pnnx::Operand *> &inputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
for (const pnnx::Operand *input: inputs) { // 根据pnnx的每个操作数,初始化RuntimeOperator中的每个操作数
if (!input) {
continue;
}
const pnnx::Operator *producer = input->producer;
std::shared_ptr<RuntimeOperand> runtime_operand =
std::make_shared<RuntimeOperand>();
runtime_operand->name = producer->name;
runtime_operand->shapes = input->shape; // 设置name,shape
switch (input->type) { // 设置type
case 1: {
runtime_operand->type = RuntimeDataType::kTypeFloat32;
break;
}
case 0: {
runtime_operand->type = RuntimeDataType::kTypeUnknown;
break;
}
default: {
LOG(FATAL) << "Unknown input operand type: " << input->type;
}
}
runtime_operator->input_operands.insert({producer->name, runtime_operand}); // 记录producer,即产生这个操作数的运算符
runtime_operator->input_operands_seq.push_back(runtime_operand); // 顺序排列
}
}
const std::vector<pnnx::Operand*>& outputs = op->outputs;
InitGraphOperatorsOutput(outputs, runtime_operator);
void RuntimeGraph::InitGraphOperatorsOutput(
const std::vector<pnnx::Operand *> &outputs,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
for (const pnnx::Operand *output: outputs) {
if (!output) {
continue;
}
const auto &consumers = output->consumers; // 获取cusumer,即输出的操作数
for (const auto &c: consumers) {
runtime_operator->output_names.push_back(c->name); // 记录cusumer的名字,输出操作数在这里不进行构建
}
}
}
初始化RuntimeAttribute
const std::map<std::string, pnnx::Attribute>& attrs = op->attrs;
InitGraphAttrs(attrs, runtime_operator);
void RuntimeGraph::InitGraphAttrs(
const std::map<std::string, pnnx::Attribute>& attrs,
const std::shared_ptr<RuntimeOperator>& runtime_operator) {
for (const auto& [name, attr] : attrs) {
switch (attr.type) {
case 1: {
std::shared_ptr<RuntimeAttribute> runtime_attribute =
std::make_shared<RuntimeAttribute>();
runtime_attribute->type = RuntimeDataType::kTypeFloat32;
runtime_attribute->weight_data = attr.data;
runtime_attribute->shape = attr.shape; // 记录信息
runtime_operator->attribute.insert({name, runtime_attribute}); // 插入名字和信息,在Linear中,name就是weight或bias
break;
}
default: {
LOG(FATAL) << "Unknown attribute type: " << attr.type;
}
}
}
}
初始化RuntimeParam
void RuntimeGraph::InitGraphParams(
const std::map<std::string, pnnx::Parameter> ¶ms,
const std::shared_ptr<RuntimeOperator> &runtime_operator) {
for (const auto &[name, parameter]: params) { // 分别处理每个param
const int type = parameter.type;
switch (type) { // 按不同类别处理
// 在这里写出实现
// 0=null 1=b 2=i 3=f 4=s 5=ai 6=af 7=as 8=others,对应分case即可
case int(RuntimeParameterType::kParameterUnknown) : {
RuntimeParameter* runtime_parameter = new RuntimeParameter;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterBool) : { // bool,需要赋值
RuntimeParameterBool* runtime_parameter = new RuntimeParameterBool;
runtime_parameter->value = parameter.b;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterInt) : {
RuntimeParameterInt* runtime_parameter = new RuntimeParameterInt;
runtime_parameter->value = parameter.i;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterFloat) : {
RuntimeParameterFloat* runtime_parameter = new RuntimeParameterFloat;
runtime_parameter->value = parameter.f;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterString) : {
RuntimeParameterString* runtime_parameter = new RuntimeParameterString;
runtime_parameter->value = parameter.s;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterIntArray) : {
RuntimeParameterIntArray* runtime_parameter = new RuntimeParameterIntArray;
runtime_parameter->value = parameter.ai;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterFloatArray) : {
RuntimeParameterFloatArray* runtime_parameter = new RuntimeParameterFloatArray;
runtime_parameter->value = parameter.af;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
case int(RuntimeParameterType::kParameterStringArray) : {
RuntimeParameterStringArray* runtime_parameter = new RuntimeParameterStringArray;
runtime_parameter->value = parameter.as;
runtime_operator->params.insert({ name, runtime_parameter });
break;
}
default: {
LOG(FATAL) << "Unknown parameter type: " << type;
}
}
}
}
单元测试
TEST(test_ir, pnnx_graph_all_homework) {
using namespace kuiper_infer;
/**
* 如果这里加载失败,请首先考虑相对路径的正确性问题
*/
std::string bin_path("../../../src/model_file/test_linear.pnnx.bin");
std::string param_path("../../../src/model_file/test_linear.pnnx.param");
RuntimeGraph graph(param_path, bin_path);
const bool init_success = graph.Init();
ASSERT_EQ(init_success, true);
const auto &operators = graph.operators();
for (const auto &operator_ : operators) { // 检查operator
if (operator_->name == "linear") {
const auto ¶ms = operator_->params;
ASSERT_EQ(params.size(), 3); // 三个参数,分别为bias,in_features,out_features
/
ASSERT_EQ(params.count("bias"), 1);
RuntimeParameter *parameter_bool = params.at("bias");
ASSERT_NE(parameter_bool, nullptr);
ASSERT_EQ((dynamic_cast<RuntimeParameterBool *>(parameter_bool)->value),
true);
/
ASSERT_EQ(params.count("in_features"), 1);
RuntimeParameter *parameter_in_features = params.at("in_features");
ASSERT_NE(parameter_in_features, nullptr);
ASSERT_EQ(
(dynamic_cast<RuntimeParameterInt *>(parameter_in_features)->value),
32);
/
ASSERT_EQ(params.count("out_features"), 1);
RuntimeParameter *parameter_out_features = params.at("out_features");
ASSERT_NE(parameter_out_features, nullptr);
ASSERT_EQ(
(dynamic_cast<RuntimeParameterInt *>(parameter_out_features)->value),
128);
}
}
}
参考
- 【Kuiperinfer】:https://github.com/zjhellofss/kuiperdatawhale
- 作者B站主页:https://space.bilibili.com/1822828582?spm_id_from=333.337.search-card.all.click
- 【Armadillo Docs】:https://arma.sourceforge.net/docs.html
- 【PNNX】:https://github.com/pnnx/pnnx