[822] | 1 | package bigfft
|
---|
| 2 |
|
---|
| 3 | import (
|
---|
| 4 | "math/big"
|
---|
| 5 | )
|
---|
| 6 |
|
---|
| 7 | // Arithmetic modulo 2^n+1.
|
---|
| 8 |
|
---|
| 9 | // A fermat of length w+1 represents a number modulo 2^(w*_W) + 1. The last
|
---|
| 10 | // word is zero or one. A number has at most two representatives satisfying the
|
---|
| 11 | // 0-1 last word constraint.
|
---|
| 12 | type fermat nat
|
---|
| 13 |
|
---|
| 14 | func (n fermat) String() string { return nat(n).String() }
|
---|
| 15 |
|
---|
| 16 | func (z fermat) norm() {
|
---|
| 17 | n := len(z) - 1
|
---|
| 18 | c := z[n]
|
---|
| 19 | if c == 0 {
|
---|
| 20 | return
|
---|
| 21 | }
|
---|
| 22 | if z[0] >= c {
|
---|
| 23 | z[n] = 0
|
---|
| 24 | z[0] -= c
|
---|
| 25 | return
|
---|
| 26 | }
|
---|
| 27 | // z[0] < z[n].
|
---|
| 28 | subVW(z, z, c) // Substract c
|
---|
| 29 | if c > 1 {
|
---|
| 30 | z[n] -= c - 1
|
---|
| 31 | c = 1
|
---|
| 32 | }
|
---|
| 33 | // Add back c.
|
---|
| 34 | if z[n] == 1 {
|
---|
| 35 | z[n] = 0
|
---|
| 36 | return
|
---|
| 37 | } else {
|
---|
| 38 | addVW(z, z, 1)
|
---|
| 39 | }
|
---|
| 40 | }
|
---|
| 41 |
|
---|
// Shift computes (x << k) mod (2^n+1).
//
// The result is stored in z. The modulus is 2^(n*_W)+1 with n = len(x)-1.
// k may be any int, including a negative one. Because 2^(2*n*_W) ≡ 1, k is
// first reduced into [0, 2*n*_W).
// z and x must have the same length; Shift panics otherwise.
// z must not share storage with x. z[n] and z[:kw] are written before all
// of x has been read, so an aliased x would be corrupted.
func (z fermat) Shift(x fermat, k int) {
	if len(z) != len(x) {
		panic("len(z) != len(x) in Shift")
	}
	n := len(x) - 1
	// Shift by n*_W is taking the opposite.
	// Reduce k modulo 2*n*_W; shifting by a multiple of it is the identity.
	k %= 2 * n * _W
	if k < 0 {
		k += 2 * n * _W
	}
	// Multiplying by 2^(n*_W) ≡ -1 negates, so record that and shift by
	// the remaining k < n*_W bits.
	neg := false
	if k >= n*_W {
		k -= n * _W
		neg = true
	}

	// Split the remaining shift into whole words and leftover bits.
	kw, kb := k/_W, k%_W

	// Setting z[n] = 1 adds 2^(n*_W) ≡ -1. That top word absorbs the single
	// borrow produced by the word subtractions below; the -1 is undone by
	// the "Add back 1" step.
	z[n] = 1 // Add (-1)
	if !neg {
		for i := 0; i < kw; i++ {
			z[i] = 0
		}
		// Shift left by kw words.
		// x = a·2^((n-kw)·_W) + b, with b = x[:n-kw] and a = x[n-kw:]
		// x<<(kw·_W) = (b<<(kw·_W)) + a·2^(n·_W) ≡ (b<<(kw·_W)) - a
		copy(z[kw:], x[:n-kw])
		b := subVV(z[:kw+1], z[:kw+1], x[n-kw:])
		// Propagate the borrow: cheaply when z[kw+1] can absorb it,
		// otherwise through the higher words (up to z[n]).
		if z[kw+1] > 0 {
			z[kw+1] -= b
		} else {
			subVW(z[kw+1:], z[kw+1:], b)
		}
	} else {
		for i := kw + 1; i < n; i++ {
			z[i] = 0
		}
		// Shift left and negate, by kw words:
		// -(x<<(kw·_W)) ≡ a - (b<<(kw·_W)), with a and b as above.
		copy(z[:kw+1], x[n-kw:n+1])            // z_low = x_high
		b := subVV(z[kw:n], z[kw:n], x[:n-kw]) // z_high -= x_low
		z[n] -= b
	}
	// Add back 1.
	// Decrementing z[n] subtracts 2^(n*_W), which also adds 1 modulo
	// 2^(n*_W)+1. Otherwise add 1 to the low words, taking the single-word
	// fast path when z[0] cannot overflow.
	if z[n] > 0 {
		z[n]--
	} else if z[0] < ^big.Word(0) {
		z[0]++
	} else {
		addVW(z, z, 1)
	}
	// Shift left by kb bits
	// z[n] is 0 or 1 at this point, so no bits are shifted out of the top
	// word (the carry returned by shlVU is zero). norm then folds the
	// possibly larger top word back into the 0/1 range.
	shlVU(z, z, uint(kb))
	z.norm()
}
|
---|
| 97 |
|
---|
| 98 | // ShiftHalf shifts x by k/2 bits the left. Shifting by 1/2 bit
|
---|
| 99 | // is multiplication by sqrt(2) mod 2^n+1 which is 2^(3n/4) - 2^(n/4).
|
---|
| 100 | // A temporary buffer must be provided in tmp.
|
---|
| 101 | func (z fermat) ShiftHalf(x fermat, k int, tmp fermat) {
|
---|
| 102 | n := len(z) - 1
|
---|
| 103 | if k%2 == 0 {
|
---|
| 104 | z.Shift(x, k/2)
|
---|
| 105 | return
|
---|
| 106 | }
|
---|
| 107 | u := (k - 1) / 2
|
---|
| 108 | a := u + (3*_W/4)*n
|
---|
| 109 | b := u + (_W/4)*n
|
---|
| 110 | z.Shift(x, a)
|
---|
| 111 | tmp.Shift(x, b)
|
---|
| 112 | z.Sub(z, tmp)
|
---|
| 113 | }
|
---|
| 114 |
|
---|
| 115 | // Add computes addition mod 2^n+1.
|
---|
| 116 | func (z fermat) Add(x, y fermat) fermat {
|
---|
| 117 | if len(z) != len(x) {
|
---|
| 118 | panic("Add: len(z) != len(x)")
|
---|
| 119 | }
|
---|
| 120 | addVV(z, x, y) // there cannot be a carry here.
|
---|
| 121 | z.norm()
|
---|
| 122 | return z
|
---|
| 123 | }
|
---|
| 124 |
|
---|
| 125 | // Sub computes substraction mod 2^n+1.
|
---|
| 126 | func (z fermat) Sub(x, y fermat) fermat {
|
---|
| 127 | if len(z) != len(x) {
|
---|
| 128 | panic("Add: len(z) != len(x)")
|
---|
| 129 | }
|
---|
| 130 | n := len(y) - 1
|
---|
| 131 | b := subVV(z[:n], x[:n], y[:n])
|
---|
| 132 | b += y[n]
|
---|
| 133 | // If b > 0, we need to subtract b<<n, which is the same as adding b.
|
---|
| 134 | z[n] = x[n]
|
---|
| 135 | if z[0] <= ^big.Word(0)-b {
|
---|
| 136 | z[0] += b
|
---|
| 137 | } else {
|
---|
| 138 | addVW(z, z, b)
|
---|
| 139 | }
|
---|
| 140 | z.norm()
|
---|
| 141 | return z
|
---|
| 142 | }
|
---|
| 143 |
|
---|
| 144 | func (z fermat) Mul(x, y fermat) fermat {
|
---|
| 145 | if len(x) != len(y) {
|
---|
| 146 | panic("Mul: len(x) != len(y)")
|
---|
| 147 | }
|
---|
| 148 | n := len(x) - 1
|
---|
| 149 | if n < 30 {
|
---|
| 150 | z = z[:2*n+2]
|
---|
| 151 | basicMul(z, x, y)
|
---|
| 152 | z = z[:2*n+1]
|
---|
| 153 | } else {
|
---|
| 154 | var xi, yi, zi big.Int
|
---|
| 155 | xi.SetBits(x)
|
---|
| 156 | yi.SetBits(y)
|
---|
| 157 | zi.SetBits(z)
|
---|
| 158 | zb := zi.Mul(&xi, &yi).Bits()
|
---|
| 159 | if len(zb) <= n {
|
---|
| 160 | // Short product.
|
---|
| 161 | copy(z, zb)
|
---|
| 162 | for i := len(zb); i < len(z); i++ {
|
---|
| 163 | z[i] = 0
|
---|
| 164 | }
|
---|
| 165 | return z
|
---|
| 166 | }
|
---|
| 167 | z = zb
|
---|
| 168 | }
|
---|
| 169 | // len(z) is at most 2n+1.
|
---|
| 170 | if len(z) > 2*n+1 {
|
---|
| 171 | panic("len(z) > 2n+1")
|
---|
| 172 | }
|
---|
| 173 | // We now have
|
---|
| 174 | // z = z[:n] + 1<<(n*W) * z[n:2n+1]
|
---|
| 175 | // which normalizes to:
|
---|
| 176 | // z = z[:n] - z[n:2n] + z[2n]
|
---|
| 177 | c1 := big.Word(0)
|
---|
| 178 | if len(z) > 2*n {
|
---|
| 179 | c1 = addVW(z[:n], z[:n], z[2*n])
|
---|
| 180 | }
|
---|
| 181 | c2 := big.Word(0)
|
---|
| 182 | if len(z) >= 2*n {
|
---|
| 183 | c2 = subVV(z[:n], z[:n], z[n:2*n])
|
---|
| 184 | } else {
|
---|
| 185 | m := len(z) - n
|
---|
| 186 | c2 = subVV(z[:m], z[:m], z[n:])
|
---|
| 187 | c2 = subVW(z[m:n], z[m:n], c2)
|
---|
| 188 | }
|
---|
| 189 | // Restore carries.
|
---|
| 190 | // Substracting z[n] -= c2 is the same
|
---|
| 191 | // as z[0] += c2
|
---|
| 192 | z = z[:n+1]
|
---|
| 193 | z[n] = c1
|
---|
| 194 | c := addVW(z, z, c2)
|
---|
| 195 | if c != 0 {
|
---|
| 196 | panic("impossible")
|
---|
| 197 | }
|
---|
| 198 | z.norm()
|
---|
| 199 | return z
|
---|
| 200 | }
|
---|
| 201 |
|
---|
| 202 | // copied from math/big
|
---|
| 203 | //
|
---|
| 204 | // basicMul multiplies x and y and leaves the result in z.
|
---|
| 205 | // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
|
---|
| 206 | func basicMul(z, x, y fermat) {
|
---|
| 207 | // initialize z
|
---|
| 208 | for i := 0; i < len(z); i++ {
|
---|
| 209 | z[i] = 0
|
---|
| 210 | }
|
---|
| 211 | for i, d := range y {
|
---|
| 212 | if d != 0 {
|
---|
| 213 | z[len(x)+i] = addMulVVW(z[i:i+len(x)], x, d)
|
---|
| 214 | }
|
---|
| 215 | }
|
---|
| 216 | }
|
---|