在C语言中,宽度最大的无符号整数类型是unsigned long long, 占8个字节。那么,如果整数超过8个字节,如何进行大数乘法呢? 例如:

$ python
Python 2.7.6 (default, Oct 26 2016, 20:32:47)
...<snip>....
>>> a = 0x123456781234567812345678
>>> b = 0x876543211234567887654321
>>> print "a * b = 0x%x" % (a * b)
a * b = 0x9a0cd057ba4c159a33a669f0a522711984e32bd70b88d78

用C语言实现大数乘法,跟十进制的多位数乘法类似,基本思路是采用分而治之的策略,难点在于进位处理相对比较复杂。本文尝试给出C代码实现(基于小端),并使用Python脚本验证计算结果。

1. foo.c

 #include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef unsigned char      byte;  /* 1 byte  */
typedef unsigned short     word;  /* 2 bytes */
typedef unsigned int       dword; /* 4 bytes */
typedef unsigned long long qword; /* 8 bytes */

/*
 * Unsigned big number stored as an array of 32-bit limbs in little-endian
 * limb order: data[0] is the least significant limb.
 */
typedef struct big_number_s {
    dword *data; /* limb array, least significant limb first */
    dword size;  /* number of limbs in data[] */
} big_number_t;

/*
 * Print a big number on one line: the tag, the address of the descriptor
 * itself (the label in the format string says "data", but the value printed
 * is the struct address), the limb count, then every limb in storage order.
 * A NULL descriptor prints nothing.
 */
static void
dump(char *tag, big_number_t *p)
{
    if (p == NULL)
        return;

    /* %p requires a void pointer; size is unsigned, hence %u. */
    printf("%s : data=%p : size=%u:\t", tag, (void *)p, p->size);
    for (dword i = 0; i < p->size; i++)
        printf("0x%08x ", (p->data)[i]);
    printf("\n");
}

/*
 * Add the 32-bit value v into a[start] and ripple the carry into the
 * following limbs. Stops as soon as the carry becomes zero. A carry out of
 * the last limb a[n-1] is discarded.
 */
static void
add32_at(dword a[], dword n, dword start, dword v)
{
    dword carry = v;

    for (dword i = start; i < n && carry != 0x0; i++) {
        qword t = (qword)a[i] + (qword)carry;
        a[i] = (dword)(t & 0xFFFFFFFF);
        carry = (dword)(t >> 32); /* next carry */
    }
}

/*
 * Add 64-bit number (8 bytes) to a[] whose element is 32-bit int (4 bytes)
 *
 * The low 32 bits of n64 are added at a[0], the high 32 bits at a[1];
 * each addition propagates its carry towards a[n-1].
 *
 * e.g.
 *      a[] = {0x12345678,0x87654321,0x0}; n = 3;
 *      n64 = 0xffffffff12345678
 *
 * The whole process of add64() looks like:
 *
 *        0x12345678 0x87654321 0x00000000
 *      + 0x12345678                          (low half, at a[0])
 *      -----------------------------------
 *      = 0x2468acf0 0x87654321 0x00000000
 *      +            0xffffffff               (high half, at a[1])
 *      -----------------------------------
 *      = 0x2468acf0 0x87654320 0x00000001
 *
 * Finally,
 *      a[] = {0x2468acf0,0x87654320,0x00000001}
 */
static void
add64(dword a[], dword n, qword n64)
{
    add32_at(a, n, 0, (dword)(n64 & 0xFFFFFFFF)); /* low  32 bits of n64 */
    add32_at(a, n, 1, (dword)(n64 >> 32));        /* high 32 bits of n64 */
}

/*
 * Multiply two big numbers with the schoolbook method.
 *
 * Every 32x32 -> 64-bit partial product a[i] * b[j] is accumulated into the
 * result starting at limb (i + j). The result has a->size + b->size limbs.
 *
 * Returns a newly allocated big number that the caller must release with
 * free_big_number(), or NULL if an operand is NULL, the combined size does
 * not fit in a dword, or memory allocation fails.
 */
static big_number_t *
big_number_mul(big_number_t *a, big_number_t *b)
{
    if (a == NULL || b == NULL)
        return NULL;

    dword total = a->size + b->size;
    if (total < a->size) /* unsigned wrap-around */
        return NULL;

    big_number_t *c = malloc(sizeof *c);
    if (c == NULL) /* malloc error */
        return NULL;

    /*
     * calloc() zero-fills the limbs and checks the byte count for overflow.
     * Request at least one element so that an empty product is not mistaken
     * for an allocation failure.
     */
    c->size = total;
    c->data = calloc(total != 0 ? total : 1, sizeof *c->data);
    if (c->data == NULL) { /* calloc error */
        free(c);           /* do not leak the descriptor */
        return NULL;
    }

    dword *adp = a->data;
    dword *bdp = b->data;
    dword *cdp = c->data;

    for (dword i = 0; i < a->size; i++) {
        if (adp[i] == 0x0) /* zero limb contributes nothing */
            continue;

        for (dword j = 0; j < b->size; j++) {
            if (bdp[j] == 0x0)
                continue;

            qword n64 = (qword)adp[i] * (qword)bdp[j];
            dword *dst = cdp + i + j;
            add64(dst, c->size - (i + j), n64);
        }
    }

    return c;
}

/*
 * Release a big number returned by big_number_mul(). NULL is accepted.
 */
static void
free_big_number(big_number_t *p)
{
    if (p == NULL)
        return;

    free(p->data); /* free(NULL) is a no-op */
    free(p);
}

/*
 * Demo: multiply two fixed 5-limb numbers and dump operands and product.
 */
int
main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;

    dword a_data[] = {0x12345678, 0x9abcdef0, 0xffffffff, 0x9abcdefa, 0x0};
    dword b_data[] = {0xfedcba98, 0x76543210, 0x76543210, 0xfedcba98, 0x0};

    big_number_t a;
    a.data = (dword *)a_data;
    a.size = sizeof(a_data) / sizeof(dword);

    big_number_t b;
    b.data = (dword *)b_data;
    b.size = sizeof(b_data) / sizeof(dword);

    dump("BigNumber A", &a);
    dump("BigNumber B", &b);

    big_number_t *c = big_number_mul(&a, &b);
    if (c == NULL) {
        fprintf(stderr, "big_number_mul failed\n");
        return 1;
    }
    dump("  C = A * B", c);
    free_big_number(c);

    return 0;
}

2. bar.py

 #!/usr/bin/python
"""Verify a big-number product printed as space-separated 32-bit hex words.

Words are listed least significant first, e.g.
"0x00000002 0x00000001" stands for 0x0000000100000002.
"""

import sys


def str2hex(s):
    """Convert a string of little-endian 32-bit hex words to an integer.

    Each whitespace-separated word may carry an optional "0x" prefix. The
    word at position k is weighted by 2**(32*k), so words need not be
    zero-padded to 8 digits. An empty string yields 0.
    """
    n = 0
    for k, w in enumerate(s.split()):
        # int(..., 16) accepts the "0x" prefix; no eval() on user input.
        n += int(w, 16) << (32 * k)
    return n


def hex2str(n):
    """Format a non-negative integer as little-endian 32-bit hex words.

    The hex digits are left-padded with '0' to a multiple of 8 and emitted
    in 8-digit groups, least significant group first, each with a "0x"
    prefix and separated by single spaces.
    """
    s_hex = "%x" % n
    pad = -len(s_hex) % 8
    s_hex = '0' * pad + s_hex
    words = ['0x' + s_hex[i - 8:i] for i in range(len(s_hex), 0, -8)]
    return ' '.join(words)


def main(argc, argv):
    """Check that argv[1] * argv[2] == argv[3]; return 0 on PASS, 1 otherwise."""
    if argc != 4:
        sys.stderr.write("Usage: %s <a> <b> <c>\n" % argv[0])
        return 1

    a = argv[1]
    b = argv[2]
    c = argv[3]
    ax = str2hex(a)
    bx = str2hex(b)
    cx = str2hex(c)

    axbx = ax * bx
    if axbx != cx:
        print("0x%x * 0x%x = " % (ax, bx))
        print("got: 0x%x" % axbx)
        print("exp: 0x%x" % cx)
        print("res: FAIL")
        return 1

    print("got: %s" % hex2str(axbx))
    print("exp: %s" % c)
    print("res: PASS")
    return 0


if __name__ == '__main__':
    argv = sys.argv
    argc = len(argv)
    sys.exit(main(argc, argv))

3. Makefile

# Build the C demo (foo) and install the Python checker (bar).
CC      = gcc
CFLAGS  = -g -Wall -m32 -std=c99

TARGETS = foo bar

# These targets do not name files; always run their recipes.
.PHONY: all clean clobber cl

all: $(TARGETS)

foo: foo.c
	$(CC) $(CFLAGS) -o $@ $<

# bar is just an executable copy of the script.
bar: bar.py
	cp $< $@ && chmod +x $@

clean:
	rm -f *.o
clobber: clean
	rm -f $(TARGETS)
cl: clobber

4. 编译并测试

$ make
gcc -g -Wall -m32 -std=c99 -o foo foo.c
cp bar.py bar && chmod +x bar
$ ./foo
BigNumber A : data=0xbfc2a7c8 : size=5: 0x12345678 0x9abcdef0 0xffffffff 0x9abcdefa 0x00000000
BigNumber B : data=0xbfc2a7d0 : size=5: 0xfedcba98 0x76543210 0x76543210 0xfedcba98 0x00000000
C = A * B : data=0x8967008 : size=10: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000
$ A="0x12345678 0x9abcdef0 0xffffffff 0x9abcdefa 0x00000000"
$ B="0xfedcba98 0x76543210 0x76543210 0xfedcba98 0x00000000"
$ C="0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000"
$
$ ./bar "$A" "$B" "$C"
got: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056
exp: 0x35068740 0xee07360a 0x053bd8c9 0x2895f6cd 0xb973e57e 0x4e6cfe66 0x0b60b60b 0x9a0cd056 0x00000000 0x00000000
res: PASS
$

结束语:

本文给出的是串行化的大数乘法实现方法。 A * B = C设定如下:

  • 大数A对应的数组长度为M, A = {a0, a1, ..., a(M-1)};
  • 大数B对应的数组长度为N, B = {b0, b1, ..., b(N-1)};
  • A*B的结果C对应的数组长度为(M+N)。
A = { a0, a1, ..., a(M-1) };
B = { b0, b1, ..., b(N-1) };

C = A * B
  = a0 * b0 + a0 * b1 + ... + a0 * b(N-1)
  + a1 * b0 + a1 * b1 + ... + a1 * b(N-1)
  + ...
  + a(M-1) * b0 + a(M-1) * b1 + ... + a(M-1) * b(N-1)

a[i] * b[j] is a 64-bit value that is accumulated into memory starting at c[i+j]
i = 0, 1, ..., M-1;
j = 0, 1, ..., N-1
a[i] is unsigned int (4 bytes)
b[j] is unsigned int (4 bytes)

算法的时间复杂度为O(M*N), 空间复杂度为O(1)。 为了缩短运行时间,我们也可以采用并行化的实现方法。

  • 启动M个线程同时计算, T0 = a0 * B, T1 = a1 * B, ..., T(M-1) = a(M-1) * B;
  • 接下来,只要M个线程都把活干完了,主线程就可以对T0, T1, ..., T(M-1)进行合并。

不过,在并行化的实现方法中,对每一个线程来说,时间复杂度为O(N), 空间复杂度为O(N) (至少N+3个辅助存储空间)。因为有M个线程并行计算,于是总的空间复杂度为O(M*N)。

05-11 17:04