Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions YatimaStdLib/AddChain.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import YatimaStdLib.Nat
/-!
# AddChain

This module implements an efficient representation of numbers in terms of the double-and-add
This module implements an efficient representation of numbers in terms of the double-and-add
algorithm.

The finding an `AddChain` with `minChain` can be used to pre-calculate the function which
implements a double-and-add or square-and-multiply represention of a natural number.

The types are left polymorphic to allow for more efficient implementations of `double` in the
The types are left polymorphic to allow for more efficient implementations of `double` in the
case of `Chainable`, or `square` for `Square`.

## References
Expand All @@ -22,13 +22,13 @@ case of `Chainable`, or `square` for `Square`.

/-- The typeclass which implements efficient `add`, `mul`, `double`, and `doubleAdd` methods -/
class Chainable (α : Type _) extends BEq α, OfNat α (nat_lit 1) where
add : α → α → α
mul : α → α → α
add : α → α → α
mul : α → α → α
double : α → α := fun x => add x x
doubleAdd : α → α → α := fun x y => add (double x) y
doubleAdd : α → α → α := fun x y => add (double x) y

/-
In this section we include the basic instances for `Chainable` that exist in the core
In this section we include the basic instances for `Chainable` that exist in the core
numerical libraries of Lean
-/
section instances
Expand All @@ -53,7 +53,7 @@ inductive ChainStep | add (idx₁ idx₂ : Nat) | double (idx : Nat)
deriving Repr

/--
The chain of operations that can be used to represent a `Chainable` n in terms of the
The chain of operations that can be used to represent a `Chainable` n in terms of the
double-and-add algorithm
-/
def AddChain (α : Type _) [Chainable α] := Array α
Expand All @@ -62,23 +62,23 @@ instance [Chainable α] : Inhabited (AddChain α) where
default := #[1]

instance [Chainable α] : HAdd (AddChain α) α (AddChain α) where
hAdd ch n := ch.push (n + ch.back)
hAdd ch n := ch.push (n + ch.back!)

instance [Chainable α] : Mul (AddChain α) where
mul ch₁ ch₂ :=
let last := ch₁.back
let ch₂' := ch₂.last.map (fun x => x * last)
mul ch₁ ch₂ :=
let last := ch₁.back!
let ch₂' := ch₂.last.map (fun x => x * last)
ch₁.append ch₂'

/-
In this section we implement an efficient algorithm to calculate the minimal AddChain for a natural
/-
In this section we implement an efficient algorithm to calculate the minimal AddChain for a natural
number
-/
namespace Nat

mutual
mutual

private partial def addChain (n k : Nat) : AddChain Nat :=
private partial def addChain (n k : Nat) : AddChain Nat :=
let (q, r) := n.quotRem k
if r == 0 || r == 1 then
minChain k * minChain q + r else
Expand All @@ -88,7 +88,7 @@ private partial def addChain (n k : Nat) : AddChain Nat :=
partial def minChain (n : Nat) : AddChain Nat :=
let logN := n.log2
if n == (1 <<< logN) then
Array.iota logN |>.map fun k => 2^k
Array.iota logN |>.map fun k => 2^k
else if n == 3 then
#[1, 2, 3] else
let k := n / (1 <<< (logN/2))
Expand Down Expand Up @@ -124,8 +124,8 @@ def buildSteps [Chainable α] (ch : AddChain α) : Array ChainStep := Id.run do

end AddChain

/--
The function which returns the `AddChain` and the `Array ChainStep` to represent the natural number
/--
The function which returns the `AddChain` and the `Array ChainStep` to represent the natural number
`n` in terms of the double-and-add representation
-/
def Nat.buildAddChain (n : Nat) : AddChain Nat × Array ChainStep :=
Expand All @@ -137,7 +137,7 @@ class Square (α : Type _) extends OfNat α (nat_lit 1) where
mul : α → α → α
square : α → α

namespace Square
namespace Square

instance [Square α] : Inhabited α where
default := 1
Expand All @@ -155,27 +155,27 @@ instance [Square α] : Mul α where

for step in steps.toList do
match step with
| .add left right =>
| .add left right =>
answer := answer.push (answer[left]! * answer[right]!)
| .double idx =>
| .double idx =>
answer := answer.push (square answer[idx]!)
answer.back

answer.back!

end Square

namespace Exp

open Square in
/-- Returns the function `n ↦ n ^ exp` by pre-calculating the `AddChain` for `exp`. -/
@[specialize] def fastExpFunc [Square α] (exp : Nat) : α → α :=
@[specialize] def fastExpFunc [Square α] (exp : Nat) : α → α :=
let (_ , steps) := exp.buildAddChain
chainExp steps

/-- A fast implementation of exponentation based off an `AddChain` representation of `exp`. -/
@[specialize] def fastExp [Square α] (n : α) (exp : Nat) : α := fastExpFunc exp n

instance (priority := low) [Square α] : HPow α Nat α where
hPow n pow := fastExp n pow
hPow n pow := fastExp n pow

end Exp
end Exp
10 changes: 5 additions & 5 deletions YatimaStdLib/Arithmetic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open Nat (powMod)

/--
An implementation of the probabilistic Miller-Rabin primality test. Returns `false` if it can be
verified that `n` is not prime, and returns `true` if `n` is probably prime after `k` loops, which
verified that `n` is not prime, and returns `true` if `n` is probably prime after `k` loops, which
may return a false positive with probability `1 / 4^k` assuming the ψRNG `gen` doesn't conspire
against us in some unexpected way
-/
Expand All @@ -14,7 +14,7 @@ def millerRabinTest (n k : Nat) : Bool :=
-- let exp : Nat → Nat := Exp.fastExpFunc d -- TODO: Use AddChains once we have an efficient Zmod
Id.run do
let mut a := 0
let mut gen := mkStdGen (n + k) -- Using Lean's built in ψRNG
let mut gen := mkStdGen (n + k) -- Using Lean's built in ψRNG
for _ in [:k] do
(a, gen) := randNat gen 2 (n - 2)
let mut x := powMod n a d
Expand All @@ -29,9 +29,9 @@ def millerRabinTest (n k : Nat) : Bool :=
return true

open Std in
/--
/--
Calculates the discrete logarithm of using the Babystep-Giantstep algorithm, should have `O(√n)`
runtime
runtime
-/
def dLog (base result mod : Nat) : Option Nat := do
let mut basePowers : HashMap Nat Nat := .empty
Expand All @@ -47,7 +47,7 @@ def dLog (base result mod : Nat) : Option Nat := do
let mut target := result

for quot in [:lim] do
match basePowers.find? target with
match basePowers.get? target with
| some rem => return quot * lim + rem
| none => target := target * basePowInv % mod

Expand Down
10 changes: 5 additions & 5 deletions YatimaStdLib/Array.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Std.Data.Array.Basic
import Batteries.Data.Array.Basic
import YatimaStdLib.List

namespace Array
Expand All @@ -10,7 +10,7 @@ def iota (n : Nat) : Array Nat :=
instance : Monad Array where
map := Array.map
pure x := #[x]
bind l f := Array.join $ Array.map f l
bind l f := Array.flatten $ Array.map f l

def shuffle (ar : Array α) (seed : Option Nat := none) [Inhabited α] :
IO $ Array α := do
Expand All @@ -29,17 +29,17 @@ def pad (ar : Array α) (a : α) (n : Nat) : Array α :=
ar ++ (.mkArray diff a)

instance [Ord α] : Ord (Array α) where
compare x y := compare x.data y.data
compare x y := compare x.toList y.toList

def last (ar : Array α) : Array α := ar.toSubarray.popFront.toArray

theorem append_size (arr₁ arr₂ : Array α) (h1 : arr₁.size = n) (h2 : arr₂.size = m)
theorem append_size (arr₁ arr₂ : Array α) (h1 : arr₁.size = n) (h2 : arr₂.size = m)
: (arr₁ ++ arr₂).size = n + m := by
unfold Array.size at *
simp [h1, h2]

def stdSizes (maxSize : Nat) := Array.iota maxSize |>.map (2 ^ ·)

def average (arr : Array Nat) : Nat :=
def average (arr : Array Nat) : Nat :=
let sum := arr.foldl (init := 0) fun acc a => acc + a
sum / arr.size
Loading
Loading