First, the error:
/home/orangepi/MNN-master/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S:158: Fatal error: macros nested too deeply
And the code it points at:
PostTreatLH8:
FMAX v9, v15, v16, v17, v18
FMAX v9, v19, v20, v21, v22
FMAX v9, v23, v24, v25, v26
FMAX v9, v27, v28, v29, v30
FMIN v10, v15, v16, v17, v18
FMIN v10, v19, v20, v21, v22
FMIN v10, v23, v24, v25, v26
FMIN v10, v27, v28, v29, v30
FMAX and FMIN here are not AArch64 instructions but assembler macros defined in the MNN sources. Each call takes five vector registers: the first holds a clamp bound, and the remaining four are matmul accumulators that get clamped in place. For example, FMAX v9, v15, v16, v17, v18 expands to four fmax instructions, each replacing one of v15-v18 with the lanewise maximum of that register and v9; since v9 holds the minimum value loaded from postParameters, this clamps the accumulators from below. FMIN works the same way against v10, which holds the maximum value, clamping from above. Taken together, the eight macro calls apply the usual ReLU-style min/max clamp to the sixteen output registers v15-v30.
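The macro definitions themselves are not in this snippet, but the expansion shown later pins down their shape; a plausible reconstruction (an assumption, not a verbatim copy of the MNN source) is:

.macro FMAX s, d0, d1, d2, d3
fmax \d0\().4s, \d0\().4s, \s\().4s // each operand is clamped in place against \s
fmax \d1\().4s, \d1\().4s, \s\().4s
fmax \d2\().4s, \d2\().4s, \s\().4s
fmax \d3\().4s, \d3\().4s, \s\().4s
.endm
// FMIN has the same shape with fmin and the upper bound in \s

This shape also suggests the cause of the error: GNU as matches macro names case-insensitively, so the fmax mnemonic inside the FMAX body is itself read as another invocation of the macro, and expansion recurses until the assembler hits its nesting limit and aborts with "macros nested too deeply". The same applies to FMIN/fmin.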
With BF16 enabled, then, the assembler never gets through these macro calls: expansion recurses until the nesting limit is hit. The fix is to replace every FMAX and FMIN macro invocation with the plain fmax/fmin instructions it stands for (an alternative that keeps the macros is sketched after the patch):
PostTreatLH8:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmax v19.4s, v19.4s, v9.4s
fmax v20.4s, v20.4s, v9.4s
fmax v21.4s, v21.4s, v9.4s
fmax v22.4s, v22.4s, v9.4s
fmax v23.4s, v23.4s, v9.4s
fmax v24.4s, v24.4s, v9.4s
fmax v25.4s, v25.4s, v9.4s
fmax v26.4s, v26.4s, v9.4s
fmax v27.4s, v27.4s, v9.4s
fmax v28.4s, v28.4s, v9.4s
fmax v29.4s, v29.4s, v9.4s
fmax v30.4s, v30.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
fmin v19.4s, v19.4s, v10.4s
fmin v20.4s, v20.4s, v10.4s
fmin v21.4s, v21.4s, v10.4s
fmin v22.4s, v22.4s, v10.4s
fmin v23.4s, v23.4s, v10.4s
fmin v24.4s, v24.4s, v10.4s
fmin v25.4s, v25.4s, v10.4s
fmin v26.4s, v26.4s, v10.4s
fmin v27.4s, v27.4s, v10.4s
fmin v28.4s, v28.4s, v10.4s
fmin v29.4s, v29.4s, v10.4s
fmin v30.4s, v30.4s, v10.4s
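Manual expansion is the most robust fix, since nothing is left for the assembler to mis-parse. If you would rather keep the macros, an alternative sketch (untested; it assumes the reconstruction above, and the name CLAMP_LO is invented here) is to rename them so they no longer shadow the fmax/fmin mnemonics:

.macro CLAMP_LO s, d0, d1, d2, d3 // renamed from FMAX; cannot collide with a mnemonic
fmax \d0\().4s, \d0\().4s, \s\().4s
fmax \d1\().4s, \d1\().4s, \s\().4s
fmax \d2\().4s, \d2\().4s, \s\().4s
fmax \d3\().4s, \d3\().4s, \s\().4s
.endm
// a CLAMP_HI with fmin replaces FMIN; call sites become e.g. CLAMP_LO v9, v15, v16, v17, v18

The rest of this write-up sticks with full expansion.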
Next, the same failure in the other file:
/home/orangepi/MNN-master/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S:174: Fatal error: macros nested too deeply
Again, the code at the reported location:
PostTreatLH8:
dup v5.4s, w17
dup v6.4s, w18
FMAX v5, v7, v8, v9, v10
FMAX v5, v11, v12, v13, v14
FMAX v5, v15, v16, v17, v18
FMAX v5, v19, v20, v21, v22
FMAX v5, v23, v24, v25, v26
FMAX v5, v27, v28, v29, v30
FMIN v6, v7, v8, v9, v10
FMIN v6, v11, v12, v13, v14
FMIN v6, v15, v16, v17, v18
FMIN v6, v19, v20, v21, v22
FMIN v6, v23, v24, v25, v26
Expand it the same way:
PostTreatLH8:
dup v5.4s, w17
dup v6.4s, w18
fmax v7.4s, v7.4s, v5.4s
fmax v8.4s, v8.4s, v5.4s
fmax v9.4s, v9.4s, v5.4s
fmax v10.4s, v10.4s, v5.4s
fmax v11.4s, v11.4s, v5.4s
fmax v12.4s, v12.4s, v5.4s
fmax v13.4s, v13.4s, v5.4s
fmax v14.4s, v14.4s, v5.4s
fmax v15.4s, v15.4s, v5.4s
fmax v16.4s, v16.4s, v5.4s
fmax v17.4s, v17.4s, v5.4s
fmax v18.4s, v18.4s, v5.4s
fmax v19.4s, v19.4s, v5.4s
fmax v20.4s, v20.4s, v5.4s
fmax v21.4s, v21.4s, v5.4s
fmax v22.4s, v22.4s, v5.4s
fmax v23.4s, v23.4s, v5.4s
fmax v24.4s, v24.4s, v5.4s
fmax v25.4s, v25.4s, v5.4s
fmax v26.4s, v26.4s, v5.4s
fmax v27.4s, v27.4s, v5.4s
fmax v28.4s, v28.4s, v5.4s
fmax v29.4s, v29.4s, v5.4s
fmax v30.4s, v30.4s, v5.4s
fmin v7.4s, v7.4s, v6.4s
fmin v8.4s, v8.4s, v6.4s
fmin v9.4s, v9.4s, v6.4s
fmin v10.4s, v10.4s, v6.4s
fmin v11.4s, v11.4s, v6.4s
fmin v12.4s, v12.4s, v6.4s
fmin v13.4s, v13.4s, v6.4s
fmin v14.4s, v14.4s, v6.4s
fmin v15.4s, v15.4s, v6.4s
fmin v16.4s, v16.4s, v6.4s
fmin v17.4s, v17.4s, v6.4s
fmin v18.4s, v18.4s, v6.4s
fmin v19.4s, v19.4s, v6.4s
fmin v20.4s, v20.4s, v6.4s
fmin v21.4s, v21.4s, v6.4s
fmin v22.4s, v22.4s, v6.4s
fmin v23.4s, v23.4s, v6.4s
fmin v24.4s, v24.4s, v6.4s
fmin v25.4s, v25.4s, v6.4s
fmin v26.4s, v26.4s, v6.4s
In the expanded block, dup broadcasts the two scalar clamp bounds into all four lanes of a vector register: w17 (the minimum) into v5.4s and w18 (the maximum) into v6.4s. Each fmax then clamps one accumulator from below against v5, and each fmin clamps it from above against v6; the results stay in the accumulators themselves (v7-v30 for fmax, v7-v26 for fmin, mirroring the six FMAX and five FMIN macro calls being replaced). Functionally nothing changes; the clamp is simply written out so the assembler has no macros to expand.
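As a quick sanity check on the semantics, here is the same clamp pattern in isolation; the values in the comments are illustrative only (they assume w17 and w18 carry the bit patterns of 0.0f and 6.0f, i.e. a ReLU6 clamp):

dup  v5.4s, w17          // v5 = {0.0, 0.0, 0.0, 0.0}, lower bound
dup  v6.4s, w18          // v6 = {6.0, 6.0, 6.0, 6.0}, upper bound
// with an accumulator v7 = {-1.0, 2.5, 7.0, 3.0}:
fmax v7.4s, v7.4s, v5.4s // v7 = { 0.0, 2.5, 7.0, 3.0}, lanewise max(x, min)
fmin v7.4s, v7.4s, v6.4s // v7 = { 0.0, 2.5, 6.0, 3.0}, lanewise min(x, max)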
For reference, here are the two complete patched files:
ARMV86_MNNPackedMatMul_BF16.S
//
// ARMV86_MNNPackedMatMul_BF16.S
// MNN
//
// Created by MNN on 2022/10/09.
// Copyright © 2018-2021 Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro SET_ZERO d0, d1, d2, d3
movi \d0\().4s, #0
movi \d1\().4s, #0
movi \d2\().4s, #0
movi \d3\().4s, #0
.endm
.macro Float32ToBf16 d0, d1, d2, d3
shrn \d0\().4h, \d0\().4s, #16
shrn \d1\().4h, \d1\().4s, #16
shrn \d2\().4h, \d2\().4s, #16
shrn \d3\().4h, \d3\().4s, #16
.endm
.macro SET_BIAS s, d0, d1, d2, d3
mov \d0\().16b, \s\().16b
mov \d1\().16b, \s\().16b
mov \d2\().16b, \s\().16b
mov \d3\().16b, \s\().16b
.endm
// 12 * 8 * 4 MatMul
asm_function ARMV86_MNNPackedMatMul_BF16
//void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias);
// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias
stp d14, d15, [sp, #-64]!
stp d12, d13, [sp, #16]
stp d10, d11, [sp, #32]
stp d8, d9, [sp, #48]
//ldr x8, [x3, #0] // deprecated
ldr x9, [x3, #8] // l
ldr x10, [x3, #16] // h
mov x11, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t)
ldr x13, [x3, #24] // cStride
ldr x7, [x3, #40] // bExtraStride
add x10, x10, #3
lsr x10, x10, #2
add x9, x9, #3
lsr x9, x9, #2
cbz x4, Start
ld1 {v5.4s}, [x4]
mov w17, v5.s[2] // min value
mov w18, v5.s[3] // max value
Start:
cmp x10, #2
blt LH4
LH8:
sub x14, x13, #96 // cStride - 96
LoopH:
mov x15, x1
mov x12, x9
cbz x5, NoBiasH8
ld1 {v0.4h, v1.4h}, [x5], #16 // 8 * sizeof(int16_t)
shll v0.4s, v0.4h, #16
shll v1.4s, v1.4h, #16
mov v2.16b, v0.16b
mov v3.16b, v1.16b
uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v19.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
uzp1 v30.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v31.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3
SET_BIAS v18, v8, v10, v12, v14
mov v16.16b, v18.16b
SET_BIAS v19, v9, v11, v13, v15
mov v17.16b, v19.16b
SET_BIAS v30, v20, v22, v24, v26
mov v28.16b, v30.16b
SET_BIAS v31, v21, v23, v25, v27
mov v29.16b, v31.16b
b LoopL
NoBiasH8:
SET_ZERO v8, v9, v10, v11
SET_ZERO v12, v13, v14, v15
SET_ZERO v16, v17, v18, v19
SET_ZERO v20, v21, v22, v23
SET_ZERO v24, v25, v26, v27
SET_ZERO v28, v29, v30, v31
LoopL:
// A [12, 4, bf16] : rn = 6 : v2 - v7
// B [ 8, 4, bf16] : rn = 2 : v0 - v1
// C [12, 8, fp32] : rn = 24 : v8 - v31
ld1 {v2.8h, v3.8h, v4.8h, v5.8h}, [x15], #64 // A: 8 * 4 * sizeof(int16_t)
ld1 {v6.8h, v7.8h}, [x15], #32 // A: 4 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x2], #32 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec48 // bfmmla v8.4s, v2.8h, v0.8h
.inst 0x6e41ec49 // bfmmla v9.4s, v2.8h, v1.8h
.inst 0x6e40ec6a // bfmmla v10.4s, v3.8h, v0.8h
.inst 0x6e41ec6b // bfmmla v11.4s, v3.8h, v1.8h
.inst 0x6e40ec8c // bfmmla v12.4s, v4.8h, v0.8h
.inst 0x6e41ec8d // bfmmla v13.4s, v4.8h, v1.8h
.inst 0x6e40ecae // bfmmla v14.4s, v5.8h, v0.8h
.inst 0x6e41ecaf // bfmmla v15.4s, v5.8h, v1.8h
.inst 0x6e40ecd0 // bfmmla v16.4s, v6.8h, v0.8h
.inst 0x6e41ecd1 // bfmmla v17.4s, v6.8h, v1.8h
.inst 0x6e40ecf2 // bfmmla v18.4s, v7.8h, v0.8h
.inst 0x6e41ecf3 // bfmmla v19.4s, v7.8h, v1.8h
ld1 {v0.8h, v1.8h}, [x2], #32 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec54 // bfmmla v20.4s, v2.8h, v0.8h
.inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h
.inst 0x6e40ec76 // bfmmla v22.4s, v3.8h, v0.8h
.inst 0x6e41ec77 // bfmmla v23.4s, v3.8h, v1.8h
.inst 0x6e40ec98 // bfmmla v24.4s, v4.8h, v0.8h
.inst 0x6e41ec99 // bfmmla v25.4s, v4.8h, v1.8h
.inst 0x6e40ecba // bfmmla v26.4s, v5.8h, v0.8h
.inst 0x6e41ecbb // bfmmla v27.4s, v5.8h, v1.8h
.inst 0x6e40ecdc // bfmmla v28.4s, v6.8h, v0.8h
.inst 0x6e41ecdd // bfmmla v29.4s, v6.8h, v1.8h
.inst 0x6e40ecfe // bfmmla v30.4s, v7.8h, v0.8h
.inst 0x6e41ecff // bfmmla v31.4s, v7.8h, v1.8h
subs x12, x12, #1
bgt LoopL
LoopLEnd:
uzp1 v7.2d, v8.2d, v9.2d
uzp2 v8.2d, v8.2d, v9.2d
uzp1 v9.2d, v10.2d, v11.2d
uzp2 v10.2d, v10.2d, v11.2d
uzp1 v11.2d, v12.2d, v13.2d
uzp2 v12.2d, v12.2d, v13.2d
uzp1 v13.2d, v14.2d, v15.2d
uzp2 v14.2d, v14.2d, v15.2d
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
uzp1 v19.2d, v20.2d, v21.2d
uzp2 v20.2d, v20.2d, v21.2d
uzp1 v21.2d, v22.2d, v23.2d
uzp2 v22.2d, v22.2d, v23.2d
uzp1 v23.2d, v24.2d, v25.2d
uzp2 v24.2d, v24.2d, v25.2d
uzp1 v25.2d, v26.2d, v27.2d
uzp2 v26.2d, v26.2d, v27.2d
uzp1 v27.2d, v28.2d, v29.2d
uzp2 v28.2d, v28.2d, v29.2d
uzp1 v29.2d, v30.2d, v31.2d
uzp2 v30.2d, v30.2d, v31.2d
cbz x4, StoreLH8
PostTreatLH8:
dup v5.4s, w17
dup v6.4s, w18
fmax v7.4s, v7.4s, v5.4s
fmax v8.4s, v8.4s, v5.4s
fmax v9.4s, v9.4s, v5.4s
fmax v10.4s, v10.4s, v5.4s
fmax v11.4s, v11.4s, v5.4s
fmax v12.4s, v12.4s, v5.4s
fmax v13.4s, v13.4s, v5.4s
fmax v14.4s, v14.4s, v5.4s
fmax v15.4s, v15.4s, v5.4s
fmax v16.4s, v16.4s, v5.4s
fmax v17.4s, v17.4s, v5.4s
fmax v18.4s, v18.4s, v5.4s
fmax v19.4s, v19.4s, v5.4s
fmax v20.4s, v20.4s, v5.4s
fmax v21.4s, v21.4s, v5.4s
fmax v22.4s, v22.4s, v5.4s
fmax v23.4s, v23.4s, v5.4s
fmax v24.4s, v24.4s, v5.4s
fmax v25.4s, v25.4s, v5.4s
fmax v26.4s, v26.4s, v5.4s
fmax v27.4s, v27.4s, v5.4s
fmax v28.4s, v28.4s, v5.4s
fmax v29.4s, v29.4s, v5.4s
fmax v30.4s, v30.4s, v5.4s
fmin v7.4s, v7.4s, v6.4s
fmin v8.4s, v8.4s, v6.4s
fmin v9.4s, v9.4s, v6.4s
fmin v10.4s, v10.4s, v6.4s
fmin v11.4s, v11.4s, v6.4s
fmin v12.4s, v12.4s, v6.4s
fmin v13.4s, v13.4s, v6.4s
fmin v14.4s, v14.4s, v6.4s
fmin v15.4s, v15.4s, v6.4s
fmin v16.4s, v16.4s, v6.4s
fmin v17.4s, v17.4s, v6.4s
fmin v18.4s, v18.4s, v6.4s
fmin v19.4s, v19.4s, v6.4s
fmin v20.4s, v20.4s, v6.4s
fmin v21.4s, v21.4s, v6.4s
fmin v22.4s, v22.4s, v6.4s
fmin v23.4s, v23.4s, v6.4s
fmin v24.4s, v24.4s, v6.4s
fmin v25.4s, v25.4s, v6.4s
fmin v26.4s, v26.4s, v6.4s
StoreLH8:
Float32ToBf16 v7, v8, v9, v10
Float32ToBf16 v11, v12, v13, v14
Float32ToBf16 v15, v16, v17, v18
Float32ToBf16 v19, v20, v21, v22
Float32ToBf16 v23, v24, v25, v26
Float32ToBf16 v27, v28, v29, v30
st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t)
add x0, x0, x14
st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t)
add x0, x0, x14
add x2, x2, x7 // weight stride
sub x10, x10, #2
cmp x10, #2
bge LoopH
LH4:
cbz x10, End
LoopHR:
mov x15, x1
mov x12, x9
cbz x5, NoBiasH4
ld1 {v0.4h}, [x5], #8 // 4 * sizeof(int16_t)
shll v0.4s, v0.4h, #16
mov v2.16b, v0.16b
uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v19.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
SET_BIAS v18, v8, v10, v12, v14
mov v16.16b, v18.16b
SET_BIAS v19, v9, v11, v13, v15
mov v17.16b, v19.16b
b LoopLR
NoBiasH4:
SET_ZERO v8, v9, v10, v11
SET_ZERO v12, v13, v14, v15
SET_ZERO v16, v17, v18, v19
LoopLR:
// A [12, 4, bf16] : rn = 6 : v2 - v7
// B [ 4, 4, bf16] : rn = 2 : v0 - v1
// C [12, 4, fp32] : rn = 12 : v8 - v19
ld1 {v2.8h, v3.8h, v4.8h, v5.8h}, [x15], #64 // A: 8 * 4 * sizeof(int16_t)
ld1 {v6.8h, v7.8h}, [x15], #32 // A: 4 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x2], x11 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec48 // bfmmla v8.4s, v2.8h, v0.8h
.inst 0x6e41ec49 // bfmmla v9.4s, v2.8h, v1.8h
.inst 0x6e40ec6a // bfmmla v10.4s, v3.8h, v0.8h
.inst 0x6e41ec6b // bfmmla v11.4s, v3.8h, v1.8h
.inst 0x6e40ec8c // bfmmla v12.4s, v4.8h, v0.8h
.inst 0x6e41ec8d // bfmmla v13.4s, v4.8h, v1.8h
.inst 0x6e40ecae // bfmmla v14.4s, v5.8h, v0.8h
.inst 0x6e41ecaf // bfmmla v15.4s, v5.8h, v1.8h
.inst 0x6e40ecd0 // bfmmla v16.4s, v6.8h, v0.8h
.inst 0x6e41ecd1 // bfmmla v17.4s, v6.8h, v1.8h
.inst 0x6e40ecf2 // bfmmla v18.4s, v7.8h, v0.8h
.inst 0x6e41ecf3 // bfmmla v19.4s, v7.8h, v1.8h
subs x12, x12, #1
bgt LoopLR
LoopLREnd:
add x2, x2, x7 // weight stride
uzp1 v7.2d, v8.2d, v9.2d
uzp2 v8.2d, v8.2d, v9.2d
uzp1 v9.2d, v10.2d, v11.2d
uzp2 v10.2d, v10.2d, v11.2d
uzp1 v11.2d, v12.2d, v13.2d
uzp2 v12.2d, v12.2d, v13.2d
uzp1 v13.2d, v14.2d, v15.2d
uzp2 v14.2d, v14.2d, v15.2d
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
cbz x4, StoreLH4
PostTreatLH4:
dup v5.4s, w17
dup v6.4s, w18
fmax v7.4s, v7.4s, v5.4s
fmax v8.4s, v8.4s, v5.4s
fmax v9.4s, v9.4s, v5.4s
fmax v10.4s, v10.4s, v5.4s
fmax v11.4s, v11.4s, v5.4s
fmax v12.4s, v12.4s, v5.4s
fmax v13.4s, v13.4s, v5.4s
fmax v14.4s, v14.4s, v5.4s
fmax v15.4s, v15.4s, v5.4s
fmax v16.4s, v16.4s, v5.4s
fmax v17.4s, v17.4s, v5.4s
fmax v18.4s, v18.4s, v5.4s
fmin v7.4s, v7.4s, v6.4s
fmin v8.4s, v8.4s, v6.4s
fmin v9.4s, v9.4s, v6.4s
fmin v10.4s, v10.4s, v6.4s
fmin v11.4s, v11.4s, v6.4s
fmin v12.4s, v12.4s, v6.4s
fmin v13.4s, v13.4s, v6.4s
fmin v14.4s, v14.4s, v6.4s
fmin v15.4s, v15.4s, v6.4s
fmin v16.4s, v16.4s, v6.4s
fmin v17.4s, v17.4s, v6.4s
fmin v18.4s, v18.4s, v6.4s
StoreLH4:
Float32ToBf16 v7, v8, v9, v10
Float32ToBf16 v11, v12, v13, v14
Float32ToBf16 v15, v16, v17, v18
st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t)
End:
ldp d8, d9, [sp, #48]
ldp d10, d11, [sp, #32]
ldp d12, d13, [sp, #16]
ldp d14, d15, [sp], #64
ret
#endif
ARMV86_MNNPackedMatMulRemain_BF16.S
//
// ARMV86_MNNPackedMatMulRemain_BF16.S
// MNN
//
// Created by MNN on 2022/10/09.
// Copyright © 2018-2021 Alibaba Group Holding Limited
//
#ifdef __aarch64__
#include "MNNAsmGlobal.h"
.text
.align 5
.macro SET_ZERO d0, d1, d2, d3
movi \d0\().4s, #0
movi \d1\().4s, #0
movi \d2\().4s, #0
movi \d3\().4s, #0
.endm
.macro Float32ToBf16 d0, d1, d2, d3
shrn \d0\().4h, \d0\().4s, #16
shrn \d1\().4h, \d1\().4s, #16
shrn \d2\().4h, \d2\().4s, #16
shrn \d3\().4h, \d3\().4s, #16
.endm
.macro SET_BIAS s, d0, d1, d2
mov \d0\().16b, \s\().16b
mov \d1\().16b, \s\().16b
mov \d2\().16b, \s\().16b
.endm
// 12 * 8 * 4 MatMul
asm_function ARMV86_MNNPackedMatMulRemain_BF16
//void ARMV86_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias);
//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias
sub sp, sp, #32
str x19, [sp, #0]
str x20, [sp, #8]
str x21, [sp, #16]
ldr x11, [x4, #0] // aStride
ldr x9, [x4, #8] // l
ldr x10, [x4, #16] // h
lsl x11, x11, #2 // aStride * 4
mov x16, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t)
ldr x7, [x4, #24] // cStride
ldr x19, [x4, #40] // bExtraStride
add x10, x10, #3
lsr x10, x10, #2
add x9, x9, #3
lsr x9, x9, #2
cbz x5, Start
ld1 {v5.4s}, [x5]
dup v9.4s, v5.s[2] // Min Value
dup v10.4s, v5.s[3] // Max Value
Start:
E8:
cmp x3, #8
blt E4
LoopE8: // e, TILE_BLOCK size is 8
mov x20, x6 // bias
mov x8, x10 // updiv(h, 4)
mov x21, x0 // dest, C
mov x13, x2 // weight, B
LH8:
cmp x8, #2 // h/4 > 2
blt LH4
sub x14, x7, #64 // cStride - 64
LoopH8x8:
mov x15, x1 // src, A
mov x12, x9 // l
cbz x5, NoBiasLH8
ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t)
shll v0.4s, v0.4h, #16
shll v1.4s, v1.4h, #16
mov v2.16b, v0.16b
mov v3.16b, v1.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
uzp1 v24.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v25.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3
SET_BIAS v16, v18, v20, v22
SET_BIAS v17, v19, v21, v23
SET_BIAS v24, v26, v28, v30
SET_BIAS v25, v27, v29, v31
b LoopL
NoBiasLH8:
SET_ZERO v16, v17, v18, v19
SET_ZERO v20, v21, v22, v23
SET_ZERO v24, v25, v26, v27
SET_ZERO v28, v29, v30, v31
LoopL:
// A [8, 4, bf16] : rn = 4 : v4 - v7
// B [8, 4, bf16] : rn = 4 : v0 - v3
// C [8, 8, fp32] : rn = 16 : v16 - v31
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x15], x11 // A: 8 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x16 // B: 8 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h
.inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h
.inst 0x6e40ecd4 // bfmmla v20.4s, v6.8h, v0.8h
.inst 0x6e41ecd5 // bfmmla v21.4s, v6.8h, v1.8h
.inst 0x6e40ecf6 // bfmmla v22.4s, v7.8h, v0.8h
.inst 0x6e41ecf7 // bfmmla v23.4s, v7.8h, v1.8h
.inst 0x6e42ec98 // bfmmla v24.4s, v4.8h, v2.8h
.inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h
.inst 0x6e42ecba // bfmmla v26.4s, v5.8h, v2.8h
.inst 0x6e43ecbb // bfmmla v27.4s, v5.8h, v3.8h
.inst 0x6e42ecdc // bfmmla v28.4s, v6.8h, v2.8h
.inst 0x6e43ecdd // bfmmla v29.4s, v6.8h, v3.8h
.inst 0x6e42ecfe // bfmmla v30.4s, v7.8h, v2.8h
.inst 0x6e43ecff // bfmmla v31.4s, v7.8h, v3.8h
subs x12, x12, #1
bgt LoopL
LoopLEnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
uzp1 v19.2d, v20.2d, v21.2d
uzp2 v20.2d, v20.2d, v21.2d
uzp1 v21.2d, v22.2d, v23.2d
uzp2 v22.2d, v22.2d, v23.2d
uzp1 v23.2d, v24.2d, v25.2d
uzp2 v24.2d, v24.2d, v25.2d
uzp1 v25.2d, v26.2d, v27.2d
uzp2 v26.2d, v26.2d, v27.2d
uzp1 v27.2d, v28.2d, v29.2d
uzp2 v28.2d, v28.2d, v29.2d
uzp1 v29.2d, v30.2d, v31.2d
uzp2 v30.2d, v30.2d, v31.2d
cbz x5, StoreLH8
PostTreatLH8:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmax v19.4s, v19.4s, v9.4s
fmax v20.4s, v20.4s, v9.4s
fmax v21.4s, v21.4s, v9.4s
fmax v22.4s, v22.4s, v9.4s
fmax v23.4s, v23.4s, v9.4s
fmax v24.4s, v24.4s, v9.4s
fmax v25.4s, v25.4s, v9.4s
fmax v26.4s, v26.4s, v9.4s
fmax v27.4s, v27.4s, v9.4s
fmax v28.4s, v28.4s, v9.4s
fmax v29.4s, v29.4s, v9.4s
fmax v30.4s, v30.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
fmin v19.4s, v19.4s, v10.4s
fmin v20.4s, v20.4s, v10.4s
fmin v21.4s, v21.4s, v10.4s
fmin v22.4s, v22.4s, v10.4s
fmin v23.4s, v23.4s, v10.4s
fmin v24.4s, v24.4s, v10.4s
fmin v25.4s, v25.4s, v10.4s
fmin v26.4s, v26.4s, v10.4s
fmin v27.4s, v27.4s, v10.4s
fmin v28.4s, v28.4s, v10.4s
fmin v29.4s, v29.4s, v10.4s
fmin v30.4s, v30.4s, v10.4s
StoreLH8:
Float32ToBf16 v15, v16, v17, v18
Float32ToBf16 v19, v20, v21, v22
Float32ToBf16 v23, v24, v25, v26
Float32ToBf16 v27, v28, v29, v30
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t)
add x0, x0, x14
st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t)
add x0, x0, x14
add x13, x13, x19 // weight stride
sub x8, x8, #2
cmp x8, #2
bge LoopH8x8
LH4:
cbz x8, E8End
LoopHRemain:
mov x15, x1
mov x12, x9
cbz x5, NoBiasHRemain
ld1 {v0.4h}, [x20]
shll v0.4s, v0.4h, #16
mov v2.16b, v0.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
SET_BIAS v16, v18, v20, v22
SET_BIAS v17, v19, v21, v23
b LoopLR
NoBiasHRemain:
SET_ZERO v16, v17, v18, v19
SET_ZERO v20, v21, v22, v23
LoopLR:
// A [8, 4, bf16] : rn = 4 : v4 - v7
// B [4, 4, bf16] : rn = 2 : v0 - v1
// C [8, 4, fp32] : rn = 8 : v16 - v23
ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x15], x11 // A: 8 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x13], x16 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h
.inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h
.inst 0x6e40ecd4 // bfmmla v20.4s, v6.8h, v0.8h
.inst 0x6e41ecd5 // bfmmla v21.4s, v6.8h, v1.8h
.inst 0x6e40ecf6 // bfmmla v22.4s, v7.8h, v0.8h
.inst 0x6e41ecf7 // bfmmla v23.4s, v7.8h, v1.8h
subs x12, x12, #1
bne LoopLR
LoopLREnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
uzp1 v19.2d, v20.2d, v21.2d
uzp2 v20.2d, v20.2d, v21.2d
uzp1 v21.2d, v22.2d, v23.2d
uzp2 v22.2d, v22.2d, v23.2d
cbz x5, StoreLH8x4
PostTreatLH8x4:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmax v19.4s, v19.4s, v9.4s
fmax v20.4s, v20.4s, v9.4s
fmax v21.4s, v21.4s, v9.4s
fmax v22.4s, v22.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
fmin v19.4s, v19.4s, v10.4s
fmin v20.4s, v20.4s, v10.4s
fmin v21.4s, v21.4s, v10.4s
fmin v22.4s, v22.4s, v10.4s
StoreLH8x4:
Float32ToBf16 v15, v16, v17, v18
Float32ToBf16 v19, v20, v21, v22
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t)
st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t)
E8End:
sub x3, x3, #8
cmp x3, #8
add x0, x21, #64 // move dest address of 8 * 4 * sizeof(int16_t)
add x1, x1, #64 // move A matrix address of 8 * 4 * sizeof(int16_t)
bge LoopE8
E4:
cmp x3, #4
mov x20, x6
blt E2
mov x8, x10
mov x21, x0
mov x13, x2
cmp x8, #2
blt E4LH4
E4LH8:
E4LoopH8:
mov x15, x1
mov x12, x9
cbz x5, NoBiasE4
ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t)
shll v0.4s, v0.4h, #16
shll v1.4s, v1.4h, #16
mov v2.16b, v0.16b
mov v3.16b, v1.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
uzp1 v20.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v21.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3
mov v18.16b, v16.16b
mov v19.16b, v17.16b
mov v22.16b, v20.16b
mov v23.16b, v21.16b
b E4LoopL
NoBiasE4:
SET_ZERO v16, v17, v18, v19
SET_ZERO v20, v21, v22, v23
E4LoopL:
// A [4, 4, bf16] : rn = 2 : v4 - v5
// B [8, 4, bf16] : rn = 4 : v0 - v3
// C [4, 8, fp32] : rn = 8 : v16 - v23
ld1 {v4.8h, v5.8h}, [x15], x11 // A: 4 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x16 // B: 8 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h
.inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h
.inst 0x6e42ec94 // bfmmla v20.4s, v4.8h, v2.8h
.inst 0x6e43ec95 // bfmmla v21.4s, v4.8h, v3.8h
.inst 0x6e42ecb6 // bfmmla v22.4s, v5.8h, v2.8h
.inst 0x6e43ecb7 // bfmmla v23.4s, v5.8h, v3.8h
subs x12, x12, #1
bgt E4LoopL
E4LoopLEnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
uzp1 v19.2d, v20.2d, v21.2d
uzp2 v20.2d, v20.2d, v21.2d
uzp1 v21.2d, v22.2d, v23.2d
uzp2 v22.2d, v22.2d, v23.2d
cbz x5, StoreLH4x8
PostTreatLH4x8:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmax v19.4s, v19.4s, v9.4s
fmax v20.4s, v20.4s, v9.4s
fmax v21.4s, v21.4s, v9.4s
fmax v22.4s, v22.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
fmin v19.4s, v19.4s, v10.4s
fmin v20.4s, v20.4s, v10.4s
fmin v21.4s, v21.4s, v10.4s
fmin v22.4s, v22.4s, v10.4s
StoreLH4x8:
Float32ToBf16 v15, v16, v17, v18
Float32ToBf16 v19, v20, v21, v22
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], x7 // 16 * sizeof(int16_t)
st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], x7 // 16 * sizeof(int16_t)
add x13, x13, x19 // weight stride
sub x8, x8, #2
cmp x8, #2
bge E4LoopH8
E4LH4:
cbz x8, E4End
mov x15, x1
mov x12, x9
cbz x5, NoBiasE4R
ld1 {v0.4h}, [x20]
shll v0.4s, v0.4h, #16
mov v2.16b, v0.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
mov v18.16b, v16.16b
mov v19.16b, v17.16b
b E4LoopLR
NoBiasE4R:
SET_ZERO v16, v17, v18, v19
E4LoopLR:
// A [4, 4, bf16] : rn = 2 : v4 - v5
// B [4, 4, bf16] : rn = 2 : v0 - v1
// C [4, 4, fp32] : rn = 4 : v16 - v19
ld1 {v4.8h, v5.8h}, [x15], x11 // A: 4 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x13], x16 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h
.inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h
subs x12, x12, #1
bgt E4LoopLR
E4LoopLREnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
cbz x5, StoreLH4x4
PostTreatLH4x4:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
StoreLH4x4:
Float32ToBf16 v15, v16, v17, v18
st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0] // 16 * sizeof(int16_t)
E4End:
sub x3, x3, #4
add x0, x21, #32 // move dest address of 4 * 4 * sizeof(int16_t)
add x1, x1, #32 // move A matrix address of 4 * 4 * sizeof(int16_t)
E2:
cmp x3, #2
mov x20, x6
blt E1
mov x8, x10
mov x21, x0
mov x13, x2
cmp x8, #2
blt E2LH4
E2LH8:
E2LoopH8:
mov x15, x1
mov x12, x9
cbz x5, NoBiasE2
ld1 {v0.4h, v1.4h}, [x20], #16
shll v0.4s, v0.4h, #16
shll v1.4s, v1.4h, #16
mov v2.16b, v0.16b
mov v3.16b, v1.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
uzp1 v18.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v19.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3
b E2LoopL
NoBiasE2:
SET_ZERO v16, v17, v18, v19
E2LoopL:
// A [2, 4, bf16] : rn = 1 : v4
// B [8, 4, bf16] : rn = 4 : v0 - v3
// C [2, 8, fp32] : rn = 4 : v16 - v19
ld1 {v4.8h}, [x15], x11 // A: 2 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x16 // B: 8 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e42ec92 // bfmmla v18.4s, v4.8h, v2.8h
.inst 0x6e43ec93 // bfmmla v19.4s, v4.8h, v3.8h
subs x12, x12, #1
bgt E2LoopL
E2LoopLEnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
uzp1 v17.2d, v18.2d, v19.2d
uzp2 v18.2d, v18.2d, v19.2d
cbz x5, StoreLH2x8
PostTreatLH2x8:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmax v17.4s, v17.4s, v9.4s
fmax v18.4s, v18.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
fmin v17.4s, v17.4s, v10.4s
fmin v18.4s, v18.4s, v10.4s
StoreLH2x8:
Float32ToBf16 v15, v16, v17, v18
st1 {v15.4h, v16.4h}, [x0], x7 // 8 * sizeof(int16_t)
st1 {v17.4h, v18.4h}, [x0], x7 // 8 * sizeof(int16_t)
add x13, x13, x19 // weight stride
sub x8, x8, #2
cmp x8, #2
bge E2LoopH8
E2LH4:
cbz x8, E2End
mov x15, x1
mov x12, x9
cbz x5, NoBiasE2R
ld1 {v0.4h}, [x20]
shll v0.4s, v0.4h, #16
mov v2.16b, v0.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
b E2LoopLR
NoBiasE2R:
movi v16.4s, #0
movi v17.4s, #0
E2LoopLR:
// A [2, 4, bf16] : rn = 1 : v4
// B [4, 4, bf16] : rn = 2 : v0 - v1
// C [2, 4, fp32] : rn = 2 : v16 - v17
ld1 {v4.8h}, [x15], x11 // A: 2 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x13], x16 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
subs x12, x12, #1
bgt E2LoopLR
E2LoopLREnd:
uzp1 v15.2d, v16.2d, v17.2d
uzp2 v16.2d, v16.2d, v17.2d
cbz x5, StoreLH2x4
PostTreatLH2x4:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
StoreLH2x4:
shrn v15.4h, v15.4s, #16
shrn v16.4h, v16.4s, #16
st1 {v15.4h, v16.4h}, [x0] // 8 * sizeof(int16_t)
E2End:
sub x3, x3, #2
add x0, x21, #16 // move dest address of 2 * 4 * sizeof(int16_t)
add x1, x1, #16 // move A matrix address of 2 * 4 * sizeof(int16_t)
E1:
cmp x3, #0
beq End
LoopE1:
mov x20, x6
mov x8, x10
mov x21, x0
mov x13, x2
cmp x8, #2
blt E1LH4
E1LH8:
E1LoopH8:
mov x15, x1
mov x12, x9
cbz x5, NoBiasE1
ld1 {v0.4h, v1.4h}, [x20], #16
shll v0.4s, v0.4h, #16
shll v1.4s, v1.4h, #16
mov v2.16b, v0.16b
mov v3.16b, v1.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
uzp1 v18.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v19.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3
b E1LoopL
NoBiasE1:
SET_ZERO v16, v17, v18, v19
E1LoopL:
// A [1, 4, bf16] : rn = 1 : v4
// B [8, 4, bf16] : rn = 4 : v0 - v3
// C [1, 8, fp32] : rn = 4 : v16 - v19
ld1 {v4.4h}, [x15], x11 // A: 1 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x16 // B: 8 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
.inst 0x6e42ec92 // bfmmla v18.4s, v4.8h, v2.8h
.inst 0x6e43ec93 // bfmmla v19.4s, v4.8h, v3.8h
subs x12, x12, #1
bgt E1LoopL
E1LoopLEnd:
// v16-v19: [r0, r1, 0, 0]
uzp1 v15.2d, v16.2d, v17.2d
uzp1 v16.2d, v18.2d, v19.2d
cbz x5, StoreLH1x8
PostTreatLH1x8:
fmax v15.4s, v15.4s, v9.4s
fmax v16.4s, v16.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
fmin v16.4s, v16.4s, v10.4s
StoreLH1x8:
shrn v15.4h, v15.4s, #16
shrn v16.4h, v16.4s, #16
st1 {v15.4h}, [x0], x7
st1 {v16.4h}, [x0], x7
add x13, x13, x19
sub x8, x8, #2
cmp x8, #2
bge E1LoopH8
E1LH4:
cbz x8, E1End
mov x15, x1
mov x12, x9
cbz x5, NoBiasE1R
ld1 {v0.4h}, [x20]
shll v0.4s, v0.4h, #16
mov v2.16b, v0.16b
uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1
uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3
b E1LoopLR
NoBiasE1R:
movi v16.4s, #0
movi v17.4s, #0
E1LoopLR:
// A [1, 4, bf16] : rn = 1 : v4
// B [4, 4, bf16] : rn = 2 : v0 - v1
// C [1, 4, fp32] : rn = 2 : v16 - v17
ld1 {v4.4h}, [x15], x11 // A: 1 * 4 * sizeof(int16_t)
ld1 {v0.8h, v1.8h}, [x13], x16 // B: 4 * 4 * sizeof(int16_t)
.inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h
.inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h
subs x12, x12, #1
bgt E1LoopLR
E1LoopLREnd:
uzp1 v15.2d, v16.2d, v17.2d
cbz x5, StoreLH1x4
PostTreatLH1x4:
fmax v15.4s, v15.4s, v9.4s
fmin v15.4s, v15.4s, v10.4s
StoreLH1x4:
shrn v15.4h, v15.4s, #16
st1 {v15.4h}, [x0]
E1End:
subs x3, x3, #1
add x0, x21, #8
add x1, x1, #8
bne LoopE1
End:
ldr x19, [sp, #0]
ldr x20, [sp, #8]
ldr x21, [sp, #16]
add sp, sp, #32
ret
#endif