背景

在路由器上,需要做一个域名分流功能:系统配置有两个DNS server,通过配置一个白名单,让白名单内的域名走 DNS server 1, 其余的的dns解析走DNS server 2。

最初想到的方案是,通过dnsmasq配置来完成。dnsmaq实际上是支持一种配置方式就指定某个域名用指定的dns来解析。但,要一条一条的配置,稍显繁琐;而且dnsmaq这种配置方式暗示,很大可能未对大量的这样的配置做优化,很可能是逐条匹配,那么性能估计难达要求(笔者赖,没有去验证这个事实),而且小设备上cpu资源很紧张,所以不太想在dnsmasq上去纠缠。

另一个方案,就是实现一个类似ipset 的 iptables模块,通过内核匹配直接指定目标server:

点击(此处)折叠或打开

  1. iptables -t nat -p udp -m udp --dport 53 -m domainset --set-name "srv2" -j DNAT --to-destionation [srvip:53]

至于domainset这个模块,netfilter框架有很好的支持,照猫画虎即可。问题在于这个模块在内核的数据路径上(虽然只包括dns的请求路径),对于性能要求比较高,而且需要兼顾cpu和内存资源。于是并有了本主题的验证:
 
小设备上(cpu 800Mhz,单核或双核,内存64M 或128M)如何快速高效的匹配10000数量级的域名白名单?

通过简单验证,排除了hash, hash 链表等算法,选择了trie。
原因:
1, 该项目中,白名单是事先准备好的,设备运行过程中变动很小,而且基本上只加不删;
2,由于数量不是很大,trie对内存的消耗基本可以接受;
3,trie查询相比hash,对cpu消耗极低,而且查询效率比hash高。

实现和验证

首先定义数据结构:
由于域名中包含的字母 0-9,a-z,A-Z, '.', '-', 由于域名是不区分大小写的,那么把‘a-z’与‘A-Z’合并;即有效字母共26+10+2=38个。
于是不难定义节点结构体为:

点击(此处)折叠或打开

  1. struct trie_node {
  2.         char key;

  3.          struct trie_node* psub[38];
  4. }
按该结构体,我用5660个url测试,得出节点数目为38664, 总内存消耗6031584(约6Mb)。这个内存消耗确实有点超出预期。必须优化了。

在上诉结构体中,每个字母的下一级包含了所有可能(这样做是为了高速的查询);但实际上可以预期,这部分有巨大的浪费。但如果psub不使用静态数组,而使用链表的话,需要遍历,那么查询速度会大打折扣。

一个折中的办法是,psub采用动态数组,记录当前节点实际有效的子节点信息;额为维护一个索引,用于快速索引当前节点。

点击(此处)折叠或打开

  1. struct trie_node{
  2.          char key;
  3.          unsigned char count;
  4.          unsigned char index[39]; // start from 1
  5.          struct trie_node **psub;

  6. }
比如域名中字母‘w’,map对应的索引序号是33,那么在节点记录中index[33]就记录了当前节点的子节点‘w’ 在psub数组中的位置。

由于index记录了字母信息,所以成员key可以省去。
另外,由于需求要求域名可以支持顶层域名泛型配置(比如,配置了abc.com为白名单,那么YYY.XXX.abc.com都与其匹配)。为此,我们在存储时,采用倒置存储,moc.cba, 与 moc.cba.*匹配。
但有种情况是,配置了moc.cba, 但搜索moc.cbaXYZ,这二者存在交叠,但又不是域名级数包含关系,实际上二者不能匹配。于是,我们需要在搜索路径上,标明moc.cbaXYZ这个路径上,‘a’到达一次完结点,‘Z’到达一次完结点。因此需要增加一个成员来记录该信息。

点击(此处)折叠或打开

  1. struct trie_node {
  2.         unsigned char count;
  3.         unsigned char index[39];
  4.         unsigned char flag;
  5.         struct trie_node **psub;
  6. }

最终

点击(此处)折叠或打开

  1. struct URL_TRIE_NODE {
  2.     /*
  3.     * udata[0], for counter
  4.     * udata[1-10], for index ??0-9??in @pnodes
  5.     * udata[11-36], for index 'a-z' in @pnodes
  6.     * udata[37], for index '.' in @pnodes        
  7.     * udata[38], for index '-' in @pnodes
  8.     * udata[39], flags
  9.     */
  10.     u8 udata[40];
  11.     struct URL_TRIE_NODE **pnodes;
  12. };
  13. typedef struct URL_TRIE_NODE node_t;


附完整demo代码:

点击(此处)折叠或打开

  1. #include <stdio.h>
  2. #include <stdlib.h>
  3. #include <string.h>
  4. #include <sys/time.h>
  5. #include <unistd.h>

  6. #define DEBUG 1

  7. #undef u8
  8. typedef unsigned char u8;

  9. #define IDX_VAL_DOT 37
  10. #define IDX_VAL_HYPLINE 38
  11. #define IDX_VAL_FLAGS 39

  12. #define DOMAIN_TAIL_MASK 0x01

  13. struct URL_TRIE_NODE {
  14.     /*
  15.     * udata[0], for counter
  16.     * udata[1-10], for index ??0-9??in @pnodes
  17.     * udata[11-36], for index 'a-z' in @pnodes
  18.     * udata[37], for index '.' in @pnodes        
  19.     * udata[38], for index '-' in @pnodes
  20.     * udata[39], flags
  21.     */
  22.     u8 udata[40];
  23.     struct URL_TRIE_NODE **pnodes;
  24. };
  25. typedef struct URL_TRIE_NODE node_t;

  26. #define IS_DOMAIN_TAIL(flag) ((flag) & DOMAIN_TAIL_MASK)

  27. struct URL_TRIE {
  28.     unsigned int ndnum;
  29.     unsigned int urlnum;
  30.     unsigned int size;
  31.     node_t root;
  32. };
  33. typedef struct URL_TRIE trie_t;

  34. static int trie_urlidx_find(trie_t *trie, const u8 *urlidx, int len)
  35. {
  36.     node_t *pnd = &trie->root;
  37.     int cursor = 0;

  38.     while (pnd && (cursor < len))
  39.     {
  40.         int ndidx = pnd->udata[urlidx[cursor]];
  41.         
  42.         if (IS_DOMAIN_TAIL(pnd->udata[IDX_VAL_FLAGS]))
  43.         {
  44.             /*
  45.             * domain top level overlap
  46.             * recorder: abc.com match rul<XXX.abc.com>
  47.             */
  48.             if (urlidx[cursor] == IDX_VAL_DOT)
  49.             {
  50.                 return 1;
  51.             }
  52.         }

  53.         if (ndidx)
  54.         {
  55.             pnd = pnd->pnodes[ndidx - 1];

  56.             // common mathed
  57.             if (cursor == len - 1 && IS_DOMAIN_TAIL(pnd->udata[IDX_VAL_FLAGS]))
  58.             {
  59.                 return 1;
  60.             }
  61.             cursor ++;
  62.         }
  63.         else
  64.         {
  65.             break;
  66.         }
  67.     
  68.     }
  69.     return 0;
  70. }

  71. static int trie_url_add(trie_t *trie, const u8 *urlidx, int len)
  72. {
  73.     node_t* pnd = &trie->root;
  74.     int cursor = 0;

  75.     while (cursor < len)
  76.     {
  77.         int ndidx = pnd->udata[urlidx[cursor]]; // start from 1
  78.         if (ndidx)
  79.         {
  80.             pnd = pnd->pnodes[ndidx - 1];
  81.             cursor++;
  82.         }
  83.         else
  84.         {
  85.             node_t **old = pnd->pnodes;
  86.             node_t **pnnd;
  87.             node_t *nd = NULL;
  88.             
  89.             pnnd = (node_t**)malloc(sizeof(node_t *) * (pnd->udata[0] + 1));
  90.             if (!pnnd)
  91.             {
  92.                 printf("malloc for new nodes for(%p) error\n", pnd);
  93.                 return -1;
  94.             }
  95.             nd = (node_t*)malloc(sizeof(node_t));
  96.             if (!nd)
  97.             {
  98.                 printf("malloc new node fail\n");
  99.                 free(pnnd);
  100.                 return -1;
  101.             }
  102.             memset(nd, 0, sizeof(node_t));
  103.             if (pnd->udata[0])
  104.             {
  105.                 memcpy(pnnd, old, sizeof(node_t*) * pnd->udata[0]);
  106.                 free(old);
  107.             }

  108.             pnd->pnodes= pnnd;
  109.             pnd->pnodes[pnd->udata[0]++] = nd;
  110.             pnd->udata[urlidx[cursor]] = pnd->udata[0];
  111.             trie->ndnum++;
  112.             trie->size += sizeof(node_t) + sizeof(node_t*);
  113.             pnd = nd;
  114.             cursor++;
  115.         }
  116.     }
  117.     pnd->udata[IDX_VAL_FLAGS] |= DOMAIN_TAIL_MASK; // add end flag
  118.     trie->urlnum++;
  119.     return 0;
  120. }

  121. static int url2idx(const u8 *url, u8* index)
  122. {
  123.     int len = strlen(url);
  124.     int cursor = 0;
  125.     
  126.     while( cursor < len)
  127.     {
  128.         int idx = 0;
  129.         if (url[cursor] == '.')
  130.         {
  131.             idx = IDX_VAL_DOT;
  132.         }
  133.         else if (url[cursor] == '-')
  134.         {
  135.             idx = IDX_VAL_HYPLINE;
  136.         }
  137.         else if (url[cursor] >= '0' && url[cursor] <= '9')
  138.         {
  139.             idx = 1 + (url[cursor] - '0');
  140.         }
  141.         else if (url[cursor] >= 'a' & url[cursor] <= 'z')
  142.         {
  143.             idx = 11 + (url[cursor] - 'a');
  144.         }
  145.         else if (url[cursor] >= 'A' & url[cursor] <= 'Z')
  146.         {
  147.             idx = 11 + (url[cursor] - 'A');
  148.         }
  149.         else
  150.         {
  151.             printf("bad url:%s, bad key:%c\n", url, url[cursor]);
  152.             return 0;
  153.         }
  154.         index[len -1 -cursor] = idx;
  155.         cursor++;
  156.     }
  157.     return len;
  158. }

  159. static int idx2url(u8 *index, int len, u8 *url)
  160. {
  161.     int cursor = 0;
  162.     while (cursor < len)
  163.     {
  164.         u8 ch = 0;
  165.         if (index[cursor] == IDX_VAL_DOT )
  166.             ch = '.';
  167.         else if (index[cursor] == IDX_VAL_HYPLINE )
  168.             ch = '-';
  169.         else if (index[cursor] >= 1 && index[cursor] <= 10)
  170.             ch = '0' + (index[cursor] -1);
  171.         else if (index[cursor] >= 11 && index[cursor] <= 36)
  172.             ch = 'a' + (index[cursor]- 11);
  173.         else
  174.         {
  175.             printf("fail to map %d\n", index[cursor]);
  176.             return -1;
  177.         }

  178.         url[len - 1 - cursor] = ch;
  179.         cursor++;
  180.     }
  181.     url[len] = 0;
  182.     return 0;
  183. }

  184. int URLTRIE_loadfile(trie_t *trie, char *path)
  185. {
  186.     char buf[130] = {0};
  187.     FILE *pFile;
  188.     char data[100] = {0};
  189.     
  190.     pFile =fopen(path, "r");
  191.     if (NULL == pFile )
  192.     {
  193.         printf("can't open %s\n", path);
  194.         return -1;
  195.     }
  196.     while (( fgets(data, 128, pFile) != NULL))
  197.     {
  198.         int len = strlen(data);
  199.         u8 idx[128] = {0};
  200.         if (data[len -1] == '\n')
  201.         {
  202.             data[len - 1] = 0;
  203.         }
  204.         len = url2idx(data, idx);

  205.         trie_url_add(trie, (const u8*)idx,len);
  206.     }    
  207. }

  208. trie_t* URLTRIE_new()
  209. {
  210.     trie_t *trie;

  211.     trie = (trie_t *)malloc(sizeof(trie_t));
  212.     if (trie)
  213.     {
  214.         memset(trie, 0, sizeof(trie_t));
  215.         return trie;
  216.     }
  217.     return NULL;
  218. }

  219. int URLTRIE_find(trie_t *trie, u8 *url)
  220. {
  221.     u8 idx[128]= {0};
  222.     int len = 0;

  223.     if (NULL == trie || NULL == url)
  224.         return 0;

  225.     len = url2idx(url, idx);
  226.     if (len <= 0)
  227.         return 0;
  228.     return trie_urlidx_find(trie, idx, len);
  229. }

  230. int URLTRIE_add(trie_t *trie, u8 *url)
  231. {
  232.     u8 idx[128]= {0};
  233.     int len = 0;

  234.     if (NULL == trie || NULL == url)
  235.         return 0;

  236.     len = url2idx(url, idx);
  237.     if (len <= 0)
  238.         return -1;

  239.     return trie_url_add(trie, idx, len);
  240. }


  241. #if DEBUG
  242. static void debug_print_trie_url(u8* index, int len)
  243. {
  244.     u8 curl[128] = {0};
  245.     idx2url(index, len, curl);
  246.     printf("%s\n", curl);
  247.     
  248. }
  249. static void debug_print_trie_path(node_t *pnd, u8* stack, int *head)
  250. {
  251.     int i;
  252.     if (pnd->udata[39] & 0x01)
  253.     {
  254.         debug_print_trie_url(stack, *head);
  255.     }

  256.     for (i = 1; i <= 38; i++)
  257.     {
  258.         if (pnd->udata[i])
  259.         {
  260.             stack[(*head)++] = i;
  261.             debug_print_trie_path(pnd->pnodes[pnd->udata[i] -1], stack, head);
  262.             stack[(*head)--] = 0;
  263.         }
  264.     }
  265. }
  266. static void debug_print_trie(trie_t *trie)
  267. {
  268.     u8 stack[128] = {0};
  269.     int head= 0;
  270.     node_t *pnd = &trie->root;

  271.     printf("trie:%p, rul number: %u, node number:%u, size:%u\n", trie, trie->urlnum, trie->ndnum, trie->size);
  272.     //debug_print_trie_path(pnd, stack, &head);
  273. }

  274. static void debug_search_file(trie_t *trie, u8 *path)
  275. {
  276.     char buf[130] = {0};
  277.     FILE *pFile;
  278.     char data[100] = {0};
  279.     
  280.     unsigned int total = 0;
  281.     unsigned int matched = 0;
  282.     
  283.     pFile =fopen(path, "r");

  284.     if (NULL == pFile )
  285.     {
  286.         printf("can't open %s\n", path);
  287.         return ;
  288.     }
  289.     while (( fgets(data, 128, pFile) != NULL))
  290.     {
  291.         int len = strlen(data);
  292.         u8 idx[128] = {0};
  293.         if (data[len -1] == '\n')
  294.         {
  295.             data[len - 1] = 0;
  296.         }
  297.         total++;
  298.         if (URLTRIE_find(trie, data))
  299.             matched++;
  300.     }
  301.     printf("search %s, total :%u, matched:%u\n", path, total, matched);
  302. }
  303. static void debug_test()
  304. {
  305.     trie_t *trie = URLTRIE_new();
  306.     struct timeval tv1;
  307.     struct timeval tv2;
  308.     
  309.     if (NULL == trie)
  310.     {
  311.         printf("new trie fail\n");
  312.         return;
  313.     }
  314.     gettimeofday(&tv1, NULL);
  315.     URLTRIE_loadfile(trie, "domain.dat");
  316.     gettimeofday(&tv2, NULL);

  317.     printf("load time: %ld(us)\n", tv2.tv_usec - tv1.tv_usec);
  318.     debug_print_trie(trie);

  319.     gettimeofday(&tv1, NULL);
  320.     debug_search_file(trie, "domain.dat");
  321.     gettimeofday(&tv2, NULL);
  322.     printf("search time: %ld(us)\n", tv2.tv_usec - tv1.tv_usec);
  323. }

  324. #endif


  325. int main(int argc, char **argv)
  326. {
  327.     #if DEBUG
  328.         debug_test();
  329.     #endif
  330.     return 0;
  331. }

最终测试结果:
load time: 26534(us)
trie:0x2101010, rul number: 5660, node number:38664, size:1855872
search domain.dat, total :5660, matched:5660
search time: 7816(us)


测试域名资源文件
tire 做域名白名单-LMLPHPdomain.txt

后续优化:
1, 域名中www开头的很多,如果把www压缩后可以省不少空间(domain.txt中数据可以节省%30左右);
2, 上面的代码可以避免完全重复的域名添加,但是没有检出域名级数包含。当然简化处理是外部先处理原数据,剔除重复包含关系的数据。

10-27 15:43