通过P/Invoke加速C#程序

任何语言都会提供FFI机制(Foreign Function Interface, 叫法不太一样), 大多数的FFI机制是和C API. C#提供了P/Invoke来和操作系统, 第三方扩展进行交互.

FFI通常用来和老的代码交互, 例如有大量的遗留代码, 重写成本太高, 可以导出C接口, 然后新系统和老系统交互; 还有一种用处就是优化, 将某一部分功能挪到C/C++(或者其他Native语言)里面, 通过特殊的优化, 对系统进行加速.

所有的FFI均存在额外的开销, 除了C++和C这种语言交互. 托管语言和非托管语言, 托管语言和托管语言交互的成本都不小. C#和C的交互, 主要的成本有两块:

  • 参数传递的成本

    C#里面的字符串是UTF-16编码的, 但是在C里面一般使用ASCII或者兼容的编码, 所以调用之前需要先做一次转换.

    内存布局不一样的参数, 会有额外的开销.

  • 调用的额外开销

    P/Invoke 的开销介于每个呼叫 10 到 30 x86 指令之间。 除了此固定成本外,封送还会产生额外的开销。 在托管代码和非托管代码中具有相同的表示形式的可声明类型之间没有封送成本。 例如,int 和 Int32 之间没有翻译费用。

    可以理解为10-30个时钟周期, 比虚函数调用成本要高一些.

以上是P/Invoke优化的基础知识. 只要调用的函数执行的时间较长, 参数的转换足够少, 那么进行P/Invoke优化就是有意义的.

某游戏服务器使用了AES-ECB加密算法进行通讯协议的加密. 算法一直没改, 实现修改了好几次, 因为整个编码过程中, 会产生多个临时byte[]对象, 所以一直想要优化掉.

下面这个版本是C# Slice的版本, 希望把加密后的内容放到我准备好的Slice里面(IByteBuffer). 但是其中有一个MemoryStream还是无法处理, 这个对象内部还是会产生byte[].

public static int AesEncrypt(byte[] src, int offset, int count, byte[] dest, int destOffset, byte[] Key0)
{
    using Rijndael rm = Rijndael.Create();
    rm.Key = Key0;
    rm.Mode = CipherMode.ECB;
    rm.Padding = PaddingMode.PKCS7;

    using ICryptoTransform cTransform = rm.CreateEncryptor();
    using var memoryStream = new MemoryStream(dest, destOffset, count + 32);
    using var writer = new CryptoStream(memoryStream, cTransform, CryptoStreamMode.Write);
    writer.Write(src, offset, count);
    writer.FlushFinalBlock();

    return (int)memoryStream.Position;
}

花了好长时间去研究.NET内部的实现, 没找到解决办法.

所以这时候就把眼睛转向了P/Invoke和C++. 好在可以先通过C#的版本生成一个输入输出样本, 然后C++尝试着去跑通整个输入输出.

下面是C++的版本:

aes_ech.h

#pragma once
#include <openssl/aes.h>
#include <assert.h>
#include <string.h>

#ifdef WIN32
#define __DLLIMPORT __declspec(dllimport)
#define __DLLEXPORT __declspec(dllexport)
#else
#define __DLLIMPORT
#define __DLLEXPORT 
#endif

extern "C"
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
		unsigned char* source, int source_length,
		unsigned char* dest);
}


static inline int pkcs7padding(unsigned char* data, int length) {
	int padding = AES_BLOCK_SIZE - length % AES_BLOCK_SIZE;
	int destSize = length + padding;
	for (int index = length; index < destSize; ++index) {
		data[index] = padding;
	}
	return destSize;
}

static inline int Encrypt(unsigned char* key, int keyLength,
			unsigned char* src, int srcLength,
			unsigned char* dest) {
	int paddingLength = pkcs7padding(src, srcLength);

	AES_KEY aes_key;
	AES_set_encrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	unsigned char* encrypted = dest;

	for (int block = 0; block < paddingLength; block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&src[block]),
			reinterpret_cast<unsigned char*>(&encrypted[block]),
			&aes_key, AES_ENCRYPT);
	}

	return paddingLength;
}

static inline int pkcs7unpadding(unsigned char* data, int dataLength) {
	int padding = data[dataLength - 1];
	return dataLength - padding;
}

static inline int Decrypt(unsigned char *key, int keyLength,
			unsigned char* encrypted, int encryptedLength,
			unsigned char* decrypted) {
	AES_KEY aes_key;
	AES_set_decrypt_key(reinterpret_cast<const unsigned char*>(&key[0]),
		keyLength * 8, &aes_key);

	int decrypted_length = encryptedLength;

	for (int block = 0; block < encryptedLength;
		block += AES_BLOCK_SIZE) {
		AES_ecb_encrypt(reinterpret_cast<const unsigned char*>(&encrypted[block]),
			reinterpret_cast<unsigned char*>(&decrypted[block]),
			&aes_key, AES_DECRYPT);
	}

	return pkcs7unpadding(decrypted, encryptedLength);
}

aes_ecb.cpp

#include "aes_ecb.h"

extern "C" 
{
__DLLEXPORT int AesEcbEncrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Encrypt(key, key_size, source, source_length, dest);
}

__DLLEXPORT int AesEcbDecrypt(unsigned char* key, int key_size,
    unsigned char* source, int source_length,
    unsigned char* dest) {
    return ::Decrypt(key, key_size, source, source_length, dest);
}
}

C#的P/Invoke封装, 以及测试代码:

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace AesPInvoke
{
    static class AesWin
    {
        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.dll", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }
    static class AesLinux 
    {
        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbEncrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);

        [DllImport("AESECB.so", CallingConvention = CallingConvention.Cdecl)]
        public static unsafe extern int AesEcbDecrypt(byte* key, int key_size, byte* source, int source_length, byte* dest);
    }

    static class Aes 
    {
        public unsafe delegate int AesFunc(byte* key, int key_size, byte* source, int source_length, byte* dest);
        static AesFunc encrypt;
        static AesFunc decrypt;
        public static AesFunc AesEncrpt => encrypt;
        public static AesFunc AesDecrypt => decrypt;
        static unsafe Aes() 
        {
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) 
            {
                encrypt = AesLinux.AesEcbEncrypt;
                decrypt = AesLinux.AesEcbDecrypt;
            }
            else 
            {
                encrypt = AesWin.AesEcbEncrypt;
                decrypt = AesWin.AesEcbDecrypt;
            }
        }
    }

    class Program
    {

        private static bool Compare(ArraySegment<byte> a, ArraySegment<byte> b) 
        {
            if (a.Count != b.Count)
            {
                return false;
            }
            for (int i = 0; i < a.Count; ++i) 
            {
                if (a[i] != b[i]) return false;
            }
            return true;
        }

        static  unsafe void Main(string[] args)
        {
            byte[] origin = new byte[] {
                                    0x06, 0x04, 0x34, 0x35, 0x32, 0x56, 0x0a, 0x10, 0x08, 0xf9, 0xeb, 0x06,
                                    0x10, 0x93, 0x12, 0x18, 0x85, 0x1a, 0x20, 0x89, 0xdf, 0xf6, 0xd3, 0x01
            };
            byte[] dest = new byte[] {0x0f, 0xd9, 0x52, 0x10, 0x11, 0x4b, 0xcc, 0xe5,
                              0x48, 0x9d, 0x47, 0x2a, 0x69, 0xa4, 0x19, 0xcc,
                              0x08, 0x6b, 0x7d, 0xe9, 0x65, 0x26, 0x53, 0x10,
                              0x5c, 0xc9, 0x2f, 0xa8, 0x02, 0x43, 0x32, 0x8f};

            var originSegment = new ArraySegment<byte>(origin);
            var destSegment = new ArraySegment<byte>(dest);

            byte[] key = Encoding.UTF8.GetBytes("12345678876543211234567887654abc");

            byte[] input = new byte[origin.Length + 32];
            Array.Copy(origin, input, origin.Length);

            byte[] output = new byte[origin.Length + 32];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesEncrpt(keyPointer, key.Length, inputPointer, origin.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(destSegment, data));
            }

            input = new byte[dest.Length];
            Array.Copy(dest, input, dest.Length);
            output = new byte[dest.Length];

            fixed(byte* keyPointer = key) 
            fixed(byte* inputPointer = input)
            fixed(byte* outputPointer = output)
            {
                var length = Aes.AesDecrypt(keyPointer, key.Length, inputPointer, dest.Length, outputPointer);
                var data = new ArraySegment<byte>(output, 0, length);
                Console.WriteLine("{0}", Compare(originSegment, data));
            }

            Console.WriteLine("Hello World!");
        }
    }
}

跑通测试之后, 就可以集成到系统里面去, 把托管实现给替换掉. 一次可以把多余的AllocArray, 和加速同时完成.

C++版本的AES ECB加密使用了OpenSSL库, 好处是工业级实现, 而且还有可能会有AES-NI加速, Windows上面只需要通过vcpkg就可以方便的移植过来, Linux上面本身就有这个库.

大部分C#代码都可以跑得非常快, 一般情况下是不需要进行这种极端优化. 但是某游戏服务器是一个比较特殊的服务器, 其服务器只有一个进程, 一个进程内需要跑IO密集, 计算密集(加解密,物理,战斗等), 还要承担GC的负担, 所以才采用了这种优化方式.

参考:

  1. P/Invoke
  2. P/Invoke开销
  3. OpenSSL AES
  4. AES-NI Performance
  5. vcpkg

通过P/Invoke加速C#程序

09-15 20:48