背景
在路由器上,需要做一个域名分流功能:系统配置有两个DNS server,通过配置一个白名单,让白名单内的域名走 DNS server 1, 其余的的dns解析走DNS server 2。最初想到的方案是,通过dnsmasq配置来完成。dnsmaq实际上是支持一种配置方式就指定某个域名用指定的dns来解析。但,要一条一条的配置,稍显繁琐;而且dnsmaq这种配置方式暗示,很大可能未对大量的这样的配置做优化,很可能是逐条匹配,那么性能估计难达要求(笔者赖,没有去验证这个事实),而且小设备上cpu资源很紧张,所以不太想在dnsmasq上去纠缠。
另一个方案,就是实现一个类似ipset 的 iptables模块,通过内核匹配直接指定目标server:
点击(此处)折叠或打开
- iptables -t nat -p udp -m udp --dport 53 -m domainset --set-name "srv2" -j DNAT --to-destionation [srvip:53]
至于domainset这个模块,netfilter框架有很好的支持,照猫画虎即可。问题在于这个模块在内核的数据路径上(虽然只包括dns的请求路径),对于性能要求比较高,而且需要兼顾cpu和内存资源。于是并有了本主题的验证:
小设备上(cpu 800Mhz,单核或双核,内存64M 或128M)如何快速高效的匹配10000数量级的域名白名单?
通过简单验证,排除了hash, hash 链表等算法,选择了trie。
原因:
1, 该项目中,白名单是事先准备好的,设备运行过程中变动很小,而且基本上只加不删;
2,由于数量不是很大,trie对内存的消耗基本可以接受;
3,trie查询相比hash,对cpu消耗极低,而且查询效率比hash高。
实现和验证
首先定义数据结构:由于域名中包含的字母 0-9,a-z,A-Z, '.', '-', 由于域名是不区分大小写的,那么把‘a-z’与‘A-Z’合并;即有效字母共26+10+2=38个。
于是不难定义节点结构体为:
点击(此处)折叠或打开
- struct trie_node {
- char key;
- struct trie_node* psub[38];
- }
在上诉结构体中,每个字母的下一级包含了所有可能(这样做是为了高速的查询);但实际上可以预期,这部分有巨大的浪费。但如果psub不使用静态数组,而使用链表的话,需要遍历,那么查询速度会大打折扣。
一个折中的办法是,psub采用动态数组,记录当前节点实际有效的子节点信息;额为维护一个索引,用于快速索引当前节点。
点击(此处)折叠或打开
- struct trie_node{
- char key;
- unsigned char count;
- unsigned char index[39]; // start from 1
- struct trie_node **psub;
- }
由于index记录了字母信息,所以成员key可以省去。
另外,由于需求要求域名可以支持顶层域名泛型配置(比如,配置了abc.com为白名单,那么YYY.XXX.abc.com都与其匹配)。为此,我们在存储时,采用倒置存储,moc.cba, 与 moc.cba.*匹配。
但有种情况是,配置了moc.cba, 但搜索moc.cbaXYZ,这二者存在交叠,但又不是域名级数包含关系,实际上二者不能匹配。于是,我们需要在搜索路径上,标明moc.cbaXYZ这个路径上,‘a’到达一次完结点,‘Z’到达一次完结点。因此需要增加一个成员来记录该信息。
点击(此处)折叠或打开
- struct trie_node {
- unsigned char count;
- unsigned char index[39];
- unsigned char flag;
- struct trie_node **psub;
- }
最终
点击(此处)折叠或打开
- struct URL_TRIE_NODE {
- /*
- * udata[0], for counter
- * udata[1-10], for index ??0-9??in @pnodes
- * udata[11-36], for index 'a-z' in @pnodes
- * udata[37], for index '.' in @pnodes
- * udata[38], for index '-' in @pnodes
- * udata[39], flags
- */
- u8 udata[40];
- struct URL_TRIE_NODE **pnodes;
- };
- typedef struct URL_TRIE_NODE node_t;
附完整demo代码:
点击(此处)折叠或打开
- #include <stdio.h>
- #include <stdlib.h>
- #include <string.h>
- #include <sys/time.h>
- #include <unistd.h>
- #define DEBUG 1
- #undef u8
- typedef unsigned char u8;
- #define IDX_VAL_DOT 37
- #define IDX_VAL_HYPLINE 38
- #define IDX_VAL_FLAGS 39
- #define DOMAIN_TAIL_MASK 0x01
- struct URL_TRIE_NODE {
- /*
- * udata[0], for counter
- * udata[1-10], for index ??0-9??in @pnodes
- * udata[11-36], for index 'a-z' in @pnodes
- * udata[37], for index '.' in @pnodes
- * udata[38], for index '-' in @pnodes
- * udata[39], flags
- */
- u8 udata[40];
- struct URL_TRIE_NODE **pnodes;
- };
- typedef struct URL_TRIE_NODE node_t;
- #define IS_DOMAIN_TAIL(flag) ((flag) & DOMAIN_TAIL_MASK)
- struct URL_TRIE {
- unsigned int ndnum;
- unsigned int urlnum;
- unsigned int size;
- node_t root;
- };
- typedef struct URL_TRIE trie_t;
- static int trie_urlidx_find(trie_t *trie, const u8 *urlidx, int len)
- {
- node_t *pnd = &trie->root;
- int cursor = 0;
- while (pnd && (cursor < len))
- {
- int ndidx = pnd->udata[urlidx[cursor]];
-
- if (IS_DOMAIN_TAIL(pnd->udata[IDX_VAL_FLAGS]))
- {
- /*
- * domain top level overlap
- * recorder: abc.com match rul<XXX.abc.com>
- */
- if (urlidx[cursor] == IDX_VAL_DOT)
- {
- return 1;
- }
- }
- if (ndidx)
- {
- pnd = pnd->pnodes[ndidx - 1];
- // common mathed
- if (cursor == len - 1 && IS_DOMAIN_TAIL(pnd->udata[IDX_VAL_FLAGS]))
- {
- return 1;
- }
- cursor ++;
- }
- else
- {
- break;
- }
-
- }
- return 0;
- }
- static int trie_url_add(trie_t *trie, const u8 *urlidx, int len)
- {
- node_t* pnd = &trie->root;
- int cursor = 0;
- while (cursor < len)
- {
- int ndidx = pnd->udata[urlidx[cursor]]; // start from 1
- if (ndidx)
- {
- pnd = pnd->pnodes[ndidx - 1];
- cursor++;
- }
- else
- {
- node_t **old = pnd->pnodes;
- node_t **pnnd;
- node_t *nd = NULL;
-
- pnnd = (node_t**)malloc(sizeof(node_t *) * (pnd->udata[0] + 1));
- if (!pnnd)
- {
- printf("malloc for new nodes for(%p) error\n", pnd);
- return -1;
- }
- nd = (node_t*)malloc(sizeof(node_t));
- if (!nd)
- {
- printf("malloc new node fail\n");
- free(pnnd);
- return -1;
- }
- memset(nd, 0, sizeof(node_t));
- if (pnd->udata[0])
- {
- memcpy(pnnd, old, sizeof(node_t*) * pnd->udata[0]);
- free(old);
- }
- pnd->pnodes= pnnd;
- pnd->pnodes[pnd->udata[0]++] = nd;
- pnd->udata[urlidx[cursor]] = pnd->udata[0];
- trie->ndnum++;
- trie->size += sizeof(node_t) + sizeof(node_t*);
- pnd = nd;
- cursor++;
- }
- }
- pnd->udata[IDX_VAL_FLAGS] |= DOMAIN_TAIL_MASK; // add end flag
- trie->urlnum++;
- return 0;
- }
- static int url2idx(const u8 *url, u8* index)
- {
- int len = strlen(url);
- int cursor = 0;
-
- while( cursor < len)
- {
- int idx = 0;
- if (url[cursor] == '.')
- {
- idx = IDX_VAL_DOT;
- }
- else if (url[cursor] == '-')
- {
- idx = IDX_VAL_HYPLINE;
- }
- else if (url[cursor] >= '0' && url[cursor] <= '9')
- {
- idx = 1 + (url[cursor] - '0');
- }
- else if (url[cursor] >= 'a' & url[cursor] <= 'z')
- {
- idx = 11 + (url[cursor] - 'a');
- }
- else if (url[cursor] >= 'A' & url[cursor] <= 'Z')
- {
- idx = 11 + (url[cursor] - 'A');
- }
- else
- {
- printf("bad url:%s, bad key:%c\n", url, url[cursor]);
- return 0;
- }
- index[len -1 -cursor] = idx;
- cursor++;
- }
- return len;
- }
- static int idx2url(u8 *index, int len, u8 *url)
- {
- int cursor = 0;
- while (cursor < len)
- {
- u8 ch = 0;
- if (index[cursor] == IDX_VAL_DOT )
- ch = '.';
- else if (index[cursor] == IDX_VAL_HYPLINE )
- ch = '-';
- else if (index[cursor] >= 1 && index[cursor] <= 10)
- ch = '0' + (index[cursor] -1);
- else if (index[cursor] >= 11 && index[cursor] <= 36)
- ch = 'a' + (index[cursor]- 11);
- else
- {
- printf("fail to map %d\n", index[cursor]);
- return -1;
- }
- url[len - 1 - cursor] = ch;
- cursor++;
- }
- url[len] = 0;
- return 0;
- }
- int URLTRIE_loadfile(trie_t *trie, char *path)
- {
- char buf[130] = {0};
- FILE *pFile;
- char data[100] = {0};
-
- pFile =fopen(path, "r");
- if (NULL == pFile )
- {
- printf("can't open %s\n", path);
- return -1;
- }
- while (( fgets(data, 128, pFile) != NULL))
- {
- int len = strlen(data);
- u8 idx[128] = {0};
- if (data[len -1] == '\n')
- {
- data[len - 1] = 0;
- }
- len = url2idx(data, idx);
- trie_url_add(trie, (const u8*)idx,len);
- }
- }
- trie_t* URLTRIE_new()
- {
- trie_t *trie;
- trie = (trie_t *)malloc(sizeof(trie_t));
- if (trie)
- {
- memset(trie, 0, sizeof(trie_t));
- return trie;
- }
- return NULL;
- }
- int URLTRIE_find(trie_t *trie, u8 *url)
- {
- u8 idx[128]= {0};
- int len = 0;
- if (NULL == trie || NULL == url)
- return 0;
- len = url2idx(url, idx);
- if (len <= 0)
- return 0;
- return trie_urlidx_find(trie, idx, len);
- }
- int URLTRIE_add(trie_t *trie, u8 *url)
- {
- u8 idx[128]= {0};
- int len = 0;
- if (NULL == trie || NULL == url)
- return 0;
- len = url2idx(url, idx);
- if (len <= 0)
- return -1;
- return trie_url_add(trie, idx, len);
- }
- #if DEBUG
- static void debug_print_trie_url(u8* index, int len)
- {
- u8 curl[128] = {0};
- idx2url(index, len, curl);
- printf("%s\n", curl);
-
- }
- static void debug_print_trie_path(node_t *pnd, u8* stack, int *head)
- {
- int i;
- if (pnd->udata[39] & 0x01)
- {
- debug_print_trie_url(stack, *head);
- }
- for (i = 1; i <= 38; i++)
- {
- if (pnd->udata[i])
- {
- stack[(*head)++] = i;
- debug_print_trie_path(pnd->pnodes[pnd->udata[i] -1], stack, head);
- stack[(*head)--] = 0;
- }
- }
- }
- static void debug_print_trie(trie_t *trie)
- {
- u8 stack[128] = {0};
- int head= 0;
- node_t *pnd = &trie->root;
- printf("trie:%p, rul number: %u, node number:%u, size:%u\n", trie, trie->urlnum, trie->ndnum, trie->size);
- //debug_print_trie_path(pnd, stack, &head);
- }
- static void debug_search_file(trie_t *trie, u8 *path)
- {
- char buf[130] = {0};
- FILE *pFile;
- char data[100] = {0};
-
- unsigned int total = 0;
- unsigned int matched = 0;
-
- pFile =fopen(path, "r");
- if (NULL == pFile )
- {
- printf("can't open %s\n", path);
- return ;
- }
- while (( fgets(data, 128, pFile) != NULL))
- {
- int len = strlen(data);
- u8 idx[128] = {0};
- if (data[len -1] == '\n')
- {
- data[len - 1] = 0;
- }
- total++;
- if (URLTRIE_find(trie, data))
- matched++;
- }
- printf("search %s, total :%u, matched:%u\n", path, total, matched);
- }
- static void debug_test()
- {
- trie_t *trie = URLTRIE_new();
- struct timeval tv1;
- struct timeval tv2;
-
- if (NULL == trie)
- {
- printf("new trie fail\n");
- return;
- }
- gettimeofday(&tv1, NULL);
- URLTRIE_loadfile(trie, "domain.dat");
- gettimeofday(&tv2, NULL);
- printf("load time: %ld(us)\n", tv2.tv_usec - tv1.tv_usec);
- debug_print_trie(trie);
- gettimeofday(&tv1, NULL);
- debug_search_file(trie, "domain.dat");
- gettimeofday(&tv2, NULL);
- printf("search time: %ld(us)\n", tv2.tv_usec - tv1.tv_usec);
- }
- #endif
- int main(int argc, char **argv)
- {
- #if DEBUG
- debug_test();
- #endif
- return 0;
- }
最终测试结果:
load time: 26534(us)
trie:0x2101010, rul number: 5660, node number:38664, size:1855872
search domain.dat, total :5660, matched:5660
search time: 7816(us)
测试域名资源文件
domain.txt
后续优化:
1, 域名中www开头的很多,如果把www压缩后可以省不少空间(domain.txt中数据可以节省%30左右);
2, 上面的代码可以避免完全重复的域名添加,但是没有检出域名级数包含。当然简化处理是外部先处理原数据,剔除重复包含关系的数据。