From 5f1a98c9f5145e8c2244a23d585581a1ebdca8d1 Mon Sep 17 00:00:00 2001 From: Josh Ellithorpe Date: Thu, 25 Oct 2018 00:43:07 -0700 Subject: [PATCH] Optimize field inversion --- bchec/field.go | 208 +++++++++++++++++++++++++------------------------ 1 file changed, 107 insertions(+), 101 deletions(-) diff --git a/bchec/field.go b/bchec/field.go index 6242d16e9..e51f154e8 100644 --- a/bchec/field.go +++ b/bchec/field.go @@ -1119,105 +1119,111 @@ func (f *fieldVal) SquareVal(val *fieldVal) *fieldVal { // The field value is returned to support chaining. This enables syntax like: // f.Inverse().Mul(f2) so that f = f^-1 * f2. func (f *fieldVal) Inverse() *fieldVal { - // Fermat's little theorem states that for a nonzero number a and prime - // prime p, a^(p-1) = 1 (mod p). Since the multipliciative inverse is - // a*b = 1 (mod p), it follows that b = a*a^(p-2) = a^(p-1) = 1 (mod p). - // Thus, a^(p-2) is the multiplicative inverse. - // - // In order to efficiently compute a^(p-2), p-2 needs to be split into - // a sequence of squares and multipications that minimizes the number of - // multiplications needed (since they are more costly than squarings). - // Intermediate results are saved and reused as well. - // - // The secp256k1 prime - 2 is 2^256 - 4294968275. - // - // This has a cost of 258 field squarings and 33 field multiplications. 
- var a2, a3, a4, a10, a11, a21, a42, a45, a63, a1019, a1023 fieldVal - a2.SquareVal(f) - a3.Mul2(&a2, f) - a4.SquareVal(&a2) - a10.SquareVal(&a4).Mul(&a2) - a11.Mul2(&a10, f) - a21.Mul2(&a10, &a11) - a42.SquareVal(&a21) - a45.Mul2(&a42, &a3) - a63.Mul2(&a42, &a21) - a1019.SquareVal(&a63).Square().Square().Square().Mul(&a11) - a1023.Mul2(&a1019, &a4) - f.Set(&a63) // f = a^(2^6 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^11 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^16 - 1024) - f.Mul(&a1023) // f = a^(2^16 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^21 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^26 - 1024) - f.Mul(&a1023) // f = a^(2^26 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^31 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^36 - 1024) - f.Mul(&a1023) // f = a^(2^36 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^41 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^46 - 1024) - f.Mul(&a1023) // f = a^(2^46 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^51 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^56 - 1024) - f.Mul(&a1023) // f = a^(2^56 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^61 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^66 - 1024) - f.Mul(&a1023) // f = a^(2^66 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^71 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^76 - 1024) - f.Mul(&a1023) // f = a^(2^76 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^81 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^86 - 1024) - f.Mul(&a1023) // f = a^(2^86 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^91 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^96 - 1024) - f.Mul(&a1023) // f = a^(2^96 - 1) - 
f.Square().Square().Square().Square().Square() // f = a^(2^101 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^106 - 1024) - f.Mul(&a1023) // f = a^(2^106 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^111 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^116 - 1024) - f.Mul(&a1023) // f = a^(2^116 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^121 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^126 - 1024) - f.Mul(&a1023) // f = a^(2^126 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^131 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^136 - 1024) - f.Mul(&a1023) // f = a^(2^136 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^141 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^146 - 1024) - f.Mul(&a1023) // f = a^(2^146 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^151 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^156 - 1024) - f.Mul(&a1023) // f = a^(2^156 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^161 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^166 - 1024) - f.Mul(&a1023) // f = a^(2^166 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^171 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^176 - 1024) - f.Mul(&a1023) // f = a^(2^176 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^181 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^186 - 1024) - f.Mul(&a1023) // f = a^(2^186 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^191 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^196 - 1024) - f.Mul(&a1023) // f = a^(2^196 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^201 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^206 - 1024) - f.Mul(&a1023) // f = a^(2^206 - 1) - 
f.Square().Square().Square().Square().Square() // f = a^(2^211 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^216 - 1024) - f.Mul(&a1023) // f = a^(2^216 - 1) - f.Square().Square().Square().Square().Square() // f = a^(2^221 - 32) - f.Square().Square().Square().Square().Square() // f = a^(2^226 - 1024) - f.Mul(&a1019) // f = a^(2^226 - 5) - f.Square().Square().Square().Square().Square() // f = a^(2^231 - 160) - f.Square().Square().Square().Square().Square() // f = a^(2^236 - 5120) - f.Mul(&a1023) // f = a^(2^236 - 4097) - f.Square().Square().Square().Square().Square() // f = a^(2^241 - 131104) - f.Square().Square().Square().Square().Square() // f = a^(2^246 - 4195328) - f.Mul(&a1023) // f = a^(2^246 - 4194305) - f.Square().Square().Square().Square().Square() // f = a^(2^251 - 134217760) - f.Square().Square().Square().Square().Square() // f = a^(2^256 - 4294968320) - return f.Mul(&a45) // f = a^(2^256 - 4294968275) = a^(p-2) + /* + Fermat's little theorem states that for a nonzero number a and a + prime p, a^(p-1) = 1 (mod p). Since the multiplicative inverse b of a + satisfies a*b = 1 (mod p) and a*a^(p-2) = a^(p-1) = 1 (mod p), it + follows that b = a^(p-2) is the multiplicative inverse. + In order to efficiently compute a^(p-2), p-2 needs to be split into + a sequence of squares and multiplications that minimizes the number of + multiplications needed (since they are more costly than squarings). + Intermediate results are saved and reused as well. + This algorithm came from Brian Smith's site. It follows the analogy of + recreating the binary representation of p-2 where we are counting the + number of factors of f included and squaring effects doubling the number + of factors of f (bitshift by 1) and multiplication adds a number of + factors of f. + https://briansmith.org/ecc-inversion-addition-chains-01 + The secp256k1 prime - 2 is 2^256 - 4294968275.
+ Binary representation of secp256k1 prime - 2: + 11111111111111111111111111111111111111111111111111111111111111111111111111 + 11111111111111111111111111111111111111111111111111111111111111111111111111 + 11111111111111111111111111111111111111111111111111111111111111111111111111 + 1011111111111111111111110000101101 + 223 ones; 1 zero; 22 ones; 4 zeros; 1 one; 1 zero; 2 ones; 1 zero; 1 one + The algorithm is easier to write in Haskell + secp256k1FieldInverseSquaredExponent double add one = + let + x1 = one + x2 = (nth double 1 `andThen` add x1 ) x1 + x3 = (nth double 1 `andThen` add x1 ) x2 + x11 = (nth double 3 `andThen` add x3 `andThen` + nth double 3 `andThen` add x3 `andThen` + nth double 2 `andThen` add x2 ) x3 + x22 = (nth double 11 `andThen` add x11 ) x11 + x44 = (nth double 22 `andThen` add x22 ) x22 + x88 = (nth double 44 `andThen` add x44 ) x44 + in (nth double 88 `andThen` add x88 `andThen` + nth double 44 `andThen` add x44 `andThen` + nth double 3 `andThen` add x3 `andThen` + nth double (1 + 22) `andThen` add x22 `andThen` + nth double (4 + 1) `andThen` add x1 `andThen` + nth double (1 + 2) `andThen` add x2 `andThen` + nth double 2 `andThen` add x1 ) x88 + --------------------------------------------------------- + -- Total length: 270 = 255 doubles + 15 adds + */ + + // This has a cost of 255 field squarings and 15 field multiplications. + var x1, x2, x3, x11, x22, x44, x88 fieldVal + x1 = *f + x2.SquareVal(&x1).Mul(&x1) + x3.SquareVal(&x2).Mul(&x1) + x11.SquareVal(&x3).Square().Square().Mul(&x3). + Square().Square().Square().Mul(&x3). + Square().Square().Mul(&x2) + x22.SquareVal(&x11).Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Mul(&x11) + x44.SquareVal(&x22).Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square().
+ Square().Square().Mul(&x22) + x88.SquareVal(&x44).Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Mul(&x44) + return f.SquareVal(&x88).Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Mul(&x88). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Mul(&x44). + Square().Square().Square().Mul(&x3). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). + Square().Square().Square().Square().Square(). 
+ Square().Square().Square().Square().Square(). + Square().Square().Square().Mul(&x22). + Square().Square().Square().Square().Square(). + Mul(&x1). + Square().Square().Square().Mul(&x2). + Square().Square().Mul(&x1) }