1 | package bigfft
|
---|
2 |
|
---|
3 | import (
|
---|
4 | "math/big"
|
---|
5 | )
|
---|
6 |
|
---|
7 | // Arithmetic modulo 2^n+1.
|
---|
8 |
|
---|
// A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
// word is zero or one. A number has at most two representatives satisfying the
// 0-1 last word constraint.
//
// Words are stored least significant first: index 0 is the low word and the
// final index holds the overflow word. The arithmetic methods in this file
// (Shift, Add, Sub, Mul) finish by calling norm to restore the constraint on
// the last word.
type fermat nat
|
---|
13 |
|
---|
14 | func (n fermat) String() string { return nat(n).String() }
|
---|
15 |
|
---|
16 | func (z fermat) norm() {
|
---|
17 | n := len(z) - 1
|
---|
18 | c := z[n]
|
---|
19 | if c == 0 {
|
---|
20 | return
|
---|
21 | }
|
---|
22 | if z[0] >= c {
|
---|
23 | z[n] = 0
|
---|
24 | z[0] -= c
|
---|
25 | return
|
---|
26 | }
|
---|
27 | // z[0] < z[n].
|
---|
28 | subVW(z, z, c) // Substract c
|
---|
29 | if c > 1 {
|
---|
30 | z[n] -= c - 1
|
---|
31 | c = 1
|
---|
32 | }
|
---|
33 | // Add back c.
|
---|
34 | if z[n] == 1 {
|
---|
35 | z[n] = 0
|
---|
36 | return
|
---|
37 | } else {
|
---|
38 | addVW(z, z, 1)
|
---|
39 | }
|
---|
40 | }
|
---|
41 |
|
---|
42 | // Shift computes (x << k) mod (2^n+1).
|
---|
43 | func (z fermat) Shift(x fermat, k int) {
|
---|
44 | if len(z) != len(x) {
|
---|
45 | panic("len(z) != len(x) in Shift")
|
---|
46 | }
|
---|
47 | n := len(x) - 1
|
---|
48 | // Shift by n*_W is taking the opposite.
|
---|
49 | k %= 2 * n * _W
|
---|
50 | if k < 0 {
|
---|
51 | k += 2 * n * _W
|
---|
52 | }
|
---|
53 | neg := false
|
---|
54 | if k >= n*_W {
|
---|
55 | k -= n * _W
|
---|
56 | neg = true
|
---|
57 | }
|
---|
58 |
|
---|
59 | kw, kb := k/_W, k%_W
|
---|
60 |
|
---|
61 | z[n] = 1 // Add (-1)
|
---|
62 | if !neg {
|
---|
63 | for i := 0; i < kw; i++ {
|
---|
64 | z[i] = 0
|
---|
65 | }
|
---|
66 | // Shift left by kw words.
|
---|
67 | // x = a·2^(n-k) + b
|
---|
68 | // x<<k = (b<<k) - a
|
---|
69 | copy(z[kw:], x[:n-kw])
|
---|
70 | b := subVV(z[:kw+1], z[:kw+1], x[n-kw:])
|
---|
71 | if z[kw+1] > 0 {
|
---|
72 | z[kw+1] -= b
|
---|
73 | } else {
|
---|
74 | subVW(z[kw+1:], z[kw+1:], b)
|
---|
75 | }
|
---|
76 | } else {
|
---|
77 | for i := kw + 1; i < n; i++ {
|
---|
78 | z[i] = 0
|
---|
79 | }
|
---|
80 | // Shift left and negate, by kw words.
|
---|
81 | copy(z[:kw+1], x[n-kw:n+1]) // z_low = x_high
|
---|
82 | b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low
|
---|
83 | z[n] -= b
|
---|
84 | }
|
---|
85 | // Add back 1.
|
---|
86 | if z[n] > 0 {
|
---|
87 | z[n]--
|
---|
88 | } else if z[0] < ^big.Word(0) {
|
---|
89 | z[0]++
|
---|
90 | } else {
|
---|
91 | addVW(z, z, 1)
|
---|
92 | }
|
---|
93 | // Shift left by kb bits
|
---|
94 | shlVU(z, z, uint(kb))
|
---|
95 | z.norm()
|
---|
96 | }
|
---|
97 |
|
---|
98 | // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
|
---|
99 | // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
|
---|
100 | // A temporary buffer must be provided in tmp.
|
---|
101 | func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) {
|
---|
102 | n := len(z) - 1
|
---|
103 | if k%2 == 0 {
|
---|
104 | z.Shift(x, k/2)
|
---|
105 | return
|
---|
106 | }
|
---|
107 | u := (k - 1) / 2
|
---|
108 | a := u + (3*_W/4)*n
|
---|
109 | b := u + (_W/4)*n
|
---|
110 | z.Shift(x, a)
|
---|
111 | tmp.Shift(x, b)
|
---|
112 | z.Sub(z, tmp)
|
---|
113 | }
|
---|
114 |
|
---|
115 | // Add computes addition mod 2^n+1.
|
---|
116 | func (z fermat) Add(x, y fermat) fermat {
|
---|
117 | if len(z) != len(x) {
|
---|
118 | panic("Add: len(z) != len(x)")
|
---|
119 | }
|
---|
120 | addVV(z, x, y) // there cannot be a carry here.
|
---|
121 | z.norm()
|
---|
122 | return z
|
---|
123 | }
|
---|
124 |
|
---|
125 | // Sub computes substraction mod 2^n+1.
|
---|
126 | func (z fermat) Sub(x, y fermat) fermat {
|
---|
127 | if len(z) != len(x) {
|
---|
128 | panic("Add: len(z) != len(x)")
|
---|
129 | }
|
---|
130 | n := len(y) - 1
|
---|
131 | b := subVV(z[:n], x[:n], y[:n])
|
---|
132 | b += y[n]
|
---|
133 | // If b > 0, we need to subtract b<<n, which is the same as adding b.
|
---|
134 | z[n] = x[n]
|
---|
135 | if z[0] <= ^big.Word(0)-b {
|
---|
136 | z[0] += b
|
---|
137 | } else {
|
---|
138 | addVW(z, z, b)
|
---|
139 | }
|
---|
140 | z.norm()
|
---|
141 | return z
|
---|
142 | }
|
---|
143 |
|
---|
144 | func (z fermat) Mul(x, y fermat) fermat {
|
---|
145 | if len(x) != len(y) {
|
---|
146 | panic("Mul: len(x) != len(y)")
|
---|
147 | }
|
---|
148 | n := len(x) - 1
|
---|
149 | if n < 30 {
|
---|
150 | z = z[:2*n+2]
|
---|
151 | basicMul(z, x, y)
|
---|
152 | z = z[:2*n+1]
|
---|
153 | } else {
|
---|
154 | var xi, yi, zi big.Int
|
---|
155 | xi.SetBits(x)
|
---|
156 | yi.SetBits(y)
|
---|
157 | zi.SetBits(z)
|
---|
158 | zb := zi.Mul(&xi, &yi).Bits()
|
---|
159 | if len(zb) <= n {
|
---|
160 | // Short product.
|
---|
161 | copy(z, zb)
|
---|
162 | for i := len(zb); i < len(z); i++ {
|
---|
163 | z[i] = 0
|
---|
164 | }
|
---|
165 | return z
|
---|
166 | }
|
---|
167 | z = zb
|
---|
168 | }
|
---|
169 | // len(z) is at most 2n+1.
|
---|
170 | if len(z) > 2*n+1 {
|
---|
171 | panic("len(z) > 2n+1")
|
---|
172 | }
|
---|
173 | // We now have
|
---|
174 | // z = z[:n] + 1<<(n*W) * z[n:2n+1]
|
---|
175 | // which normalizes to:
|
---|
176 | // z = z[:n] - z[n:2n] + z[2n]
|
---|
177 | c1 := big.Word(0)
|
---|
178 | if len(z) > 2*n {
|
---|
179 | c1 = addVW(z[:n], z[:n], z[2*n])
|
---|
180 | }
|
---|
181 | c2 := big.Word(0)
|
---|
182 | if len(z) >= 2*n {
|
---|
183 | c2 = subVV(z[:n], z[:n], z[n:2*n])
|
---|
184 | } else {
|
---|
185 | m := len(z) - n
|
---|
186 | c2 = subVV(z[:m], z[:m], z[n:])
|
---|
187 | c2 = subVW(z[m:n], z[m:n], c2)
|
---|
188 | }
|
---|
189 | // Restore carries.
|
---|
190 | // Substracting z[n] -= c2 is the same
|
---|
191 | // as z[0] += c2
|
---|
192 | z = z[:n+1]
|
---|
193 | z[n] = c1
|
---|
194 | c := addVW(z, z, c2)
|
---|
195 | if c != 0 {
|
---|
196 | panic("impossible")
|
---|
197 | }
|
---|
198 | z.norm()
|
---|
199 | return z
|
---|
200 | }
|
---|
201 |
|
---|
202 | // copied from math/big
|
---|
203 | //
|
---|
204 | // basicMul multiplies x and y and leaves the result in z.
|
---|
205 | // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
|
---|
206 | func basicMul(z, x, y fermat) {
|
---|
207 | // initialize z
|
---|
208 | for i := 0; i < len(z); i++ {
|
---|
209 | z[i] = 0
|
---|
210 | }
|
---|
211 | for i, d := range y {
|
---|
212 | if d != 0 {
|
---|
213 | z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
|
---|
214 | }
|
---|
215 | }
|
---|
216 | }
|
---|