Skip to content

Commit 1fd7206

Browse files
authored
feat: match-expressions with congruence equation theorems (leanprover#8506)
This PR implements `match`-expressions in `grind` using `match` congruence equations. The goal is to minimize the number of `cast` operations that need to be inserted, and avoid `cast` over functions. The new approach support `match`-expressions of the form `match h : ... with ...`.
1 parent a6e76b4 commit 1fd7206

File tree

13 files changed

+349
-71
lines changed

13 files changed

+349
-71
lines changed

src/Init/Grind/Util.lean

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,55 @@ def offset (a b : Nat) : Nat := a + b
3232
/-- Gadget for representing `a = b` in patterns for backward propagation. -/
3333
def eqBwdPattern (a b : α) : Prop := a = b
3434

35+
/--
36+
Gadget for representing generalization steps `h : x = val` in patterns
37+
This gadget is used to represent patterns in theorems that have been generalized to reduce the
38+
number of casts introduced during E-matching based instantiation.
39+
40+
For example, consider the theorem
41+
```
42+
Option.pbind_some {α1 : Type u_1} {a : α1} {α2 : Type u_2}
43+
{f : (a_1 : α1) → some a = some a_1 → Option α2}
44+
: (some a).pbind f = f a rfl
45+
```
46+
Now, suppose we have a goal containing the term `c.pbind g` and the equivalence class
47+
`{c, some b}`. The E-matching module generates the instance
48+
```
49+
(some b).pbind (cast ⋯ g)
50+
```
51+
The `cast` is necessary because `g`'s type contains `c` instead of `some b.
52+
This `cast` problematic because we don't have a systematic way of pushing casts over functions
53+
to its arguments. Moreover, heterogeneous equality is not effective because the following theorem
54+
is not provable in DTT:
55+
```
56+
theorem hcongr (h₁ : f ≍ g) (h₂ : a ≍ b) : f a ≍ g b := ...
57+
```
58+
The standard solution is to generalize the theorem above and write it as
59+
```
60+
theorem Option.pbind_some'
61+
{α1 : Type u_1} {a : α1} {α2 : Type u_2}
62+
{x : Option α1}
63+
{f : (a_1 : α1) → x = some a_1 → Option α2}
64+
(h : x = some a)
65+
: x.pbind f = f a h := by
66+
subst h
67+
apply Option.pbind_some
68+
```
69+
Internally, we use this gadget to mark the E-matching pattern as
70+
```
71+
(genPattern h x (some a)).pbind f
72+
```
73+
This pattern is matched in the same way we match `(some a).pbind f`, but it saves the proof
74+
for the actual term to the `some`-application in `f`, and the actual term in `x`.
75+
76+
In the example above, `c.pbind g` also matches the pattern `(genPattern h x (some a)).pbind f`,
77+
and stores `c` in `x`, `b` in `a`, and the proof that `c = some b` in `h`.
78+
-/
79+
def genPattern {α : Sort u} (_h : Prop) (x : α) (_val : α) : α := x
80+
81+
/-- Similar to `genPattern` but for the heterogenous case -/
82+
def genHEqPattern {α β : Sort u} (_h : Prop) (x : α) (_val : β) : α := x
83+
3584
/--
3685
Gadget for annotating the equalities in `match`-equations conclusions.
3786
`_origin` is the term used to instantiate the `match`-equation using E-matching.

src/Lean/Elab/Tactic/Grind.lean

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -214,27 +214,26 @@ def mkGrindOnly
214214
let mut foundFns : NameSet := {}
215215
for { origin, kind } in trace.thms.toList do
216216
if let .decl declName := origin then
217-
unless Match.isMatchEqnTheorem (← getEnv) declName do
218-
if let some declName ← isEqnThm? declName then
219-
unless foundFns.contains declName do
220-
foundFns := foundFns.insert declName
221-
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
222-
let param ← `(Parser.Tactic.grindParam| $decl:ident)
223-
params := params.push param
224-
else
217+
if let some declName ← isEqnThm? declName then
218+
unless foundFns.contains declName do
219+
foundFns := foundFns.insert declName
225220
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
226-
let param ← match kind with
227-
| .eqLhs => `(Parser.Tactic.grindParam| = $decl)
228-
| .eqRhs => `(Parser.Tactic.grindParam| =_ $decl)
229-
| .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl)
230-
| .eqBwd => `(Parser.Tactic.grindParam| ←= $decl)
231-
| .bwd => `(Parser.Tactic.grindParam| ← $decl)
232-
| .fwd => `(Parser.Tactic.grindParam| → $decl)
233-
| .leftRight => `(Parser.Tactic.grindParam| => $decl)
234-
| .rightLeft => `(Parser.Tactic.grindParam| <= $decl)
235-
| .user => `(Parser.Tactic.grindParam| usr $decl)
236-
| .default => `(Parser.Tactic.grindParam| $decl:ident)
221+
let param ← `(Parser.Tactic.grindParam| $decl:ident)
237222
params := params.push param
223+
else
224+
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)
225+
let param ← match kind with
226+
| .eqLhs => `(Parser.Tactic.grindParam| = $decl)
227+
| .eqRhs => `(Parser.Tactic.grindParam| =_ $decl)
228+
| .eqBoth => `(Parser.Tactic.grindParam| _=_ $decl)
229+
| .eqBwd => `(Parser.Tactic.grindParam| ←= $decl)
230+
| .bwd => `(Parser.Tactic.grindParam| ← $decl)
231+
| .fwd => `(Parser.Tactic.grindParam| → $decl)
232+
| .leftRight => `(Parser.Tactic.grindParam| => $decl)
233+
| .rightLeft => `(Parser.Tactic.grindParam| <= $decl)
234+
| .user => `(Parser.Tactic.grindParam| usr $decl)
235+
| .default => `(Parser.Tactic.grindParam| $decl:ident)
236+
params := params.push param
238237
for declName in trace.eagerCases.toList do
239238
unless Grind.isBuiltinEagerCases declName do
240239
let decl : Ident := mkIdent (← unresolveNameGlobalAvoidingLocals declName)

src/Lean/Meta/Match/MatchEqs.lean

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def congrEqn1ThmSuffix := congrEqnThmSuffixBasePrefix ++ "1"
833833
example : congrEqn1ThmSuffix = "congr_eq_1" := rfl
834834

835835
/-- Returns `true` if `s` is of the form `congr_eq_<idx>` -/
836-
def iscongrEqnReservedNameSuffix (s : String) : Bool :=
836+
def isCongrEqnReservedNameSuffix (s : String) : Bool :=
837837
congrEqnThmSuffixBasePrefix.isPrefixOf s && (s.drop congrEqnThmSuffixBasePrefix.length).isNat
838838

839839
/- We generate the equations and splitter on demand, and do not save them on .olean files. -/
@@ -928,10 +928,10 @@ builtin_initialize registerTraceClass `Meta.Match.matchEqs
928928

929929
private def isMatchEqName? (env : Environment) (n : Name) : Option (Name × Bool) := do
930930
let .str p s := n | failure
931-
guard <| isEqnReservedNameSuffix s || s == "splitter" || iscongrEqnReservedNameSuffix s
931+
guard <| isEqnReservedNameSuffix s || s == "splitter" || isCongrEqnReservedNameSuffix s
932932
let p ← privateToUserName? p
933933
guard <| isMatcherCore env p
934-
return (p, iscongrEqnReservedNameSuffix s)
934+
return (p, isCongrEqnReservedNameSuffix s)
935935

936936
builtin_initialize registerReservedNamePredicate (isMatchEqName? · · |>.isSome)
937937

src/Lean/Meta/Tactic/Grind/EMatch.lean

Lines changed: 113 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ namespace EMatch
1717
/-- We represent an `E-matching` problem as a list of constraints. -/
1818
inductive Cnstr where
1919
| /-- Matches pattern `pat` with term `e` -/
20-
«match» (pat : Expr) (e : Expr)
20+
«match» (gen? : Option GenPatternInfo) (pat : Expr) (e : Expr)
2121
| /-- Matches offset pattern `pat+k` with term `e` -/
22-
offset (pat : Expr) (k : Nat) (e : Expr)
22+
offset (gen? : Option GenPatternInfo) (pat : Expr) (k : Nat) (e : Expr)
2323
| /-- This constraint is used to encode multi-patterns. -/
2424
«continue» (pat : Expr)
2525
deriving Inhabited
@@ -30,6 +30,12 @@ This is a small hack to avoid one extra level of indirection by using `Option Ex
3030
-/
3131
private def unassigned : Expr := mkConst (Name.mkSimple "[grind_unassigned]")
3232

33+
/--
34+
Internal "marker" for representing equality proofs for generalized patterns.
35+
They must be synthesized after we have the partial assignment.
36+
-/
37+
private def delayedEqProof : Expr := mkConst (Name.mkSimple "[grind_delayed_eq_proof]")
38+
3339
private def assignmentToMessageData (assignment : Array Expr) : Array MessageData :=
3440
assignment.reverse.map fun e =>
3541
if isSameExpr e unassigned then m!"_" else m!"{e}"
@@ -89,6 +95,22 @@ private def assign? (c : Choice) (bidx : Nat) (e : Expr) : OptionT GoalM Choice
8995
-- `Choice` was not properly initialized
9096
unreachable!
9197

98+
/--
99+
Assigns `bidx` with the marker for a delayed equality proof for generalized patterns.
100+
The proof is assigned after we have the complete assignment.
101+
-/
102+
private def assignDelayedEqProof? (c : Choice) (bidx : Nat) : OptionT GoalM Choice := do
103+
if h : bidx < c.assignment.size then
104+
let v := c.assignment[bidx]
105+
if isSameExpr v unassigned then
106+
return { c with assignment := c.assignment.set bidx delayedEqProof }
107+
else
108+
return c
109+
else
110+
-- `Choice` was not properly initialized
111+
unreachable!
112+
113+
92114
private def unassign (c : Choice) (bidx : Nat) : Choice :=
93115
{ c with assignment := c.assignment.set! bidx unassigned }
94116

@@ -100,6 +122,11 @@ private def eqvFunctions (pFn eFn : Expr) : Bool :=
100122
(pFn.isFVar && pFn == eFn)
101123
|| (pFn.isConst && eFn.isConstOf pFn.constName!)
102124

125+
protected def _root_.Lean.Meta.Grind.GenPatternInfo.assign? (genInfo : GenPatternInfo) (c : Choice) (x : Expr) : OptionT GoalM Choice := do
126+
let c ← assign? c genInfo.xIdx x
127+
let c ← assignDelayedEqProof? c genInfo.hIdx
128+
return c
129+
103130
/-- Matches a pattern argument. See `matchArgs?`. -/
104131
private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM Choice := do
105132
if isPatternDontCare pArg then
@@ -128,9 +155,20 @@ private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM C
128155
else if let some (pArg, k) := isOffsetPattern? pArg then
129156
assert! Option.isNone <| isOffsetPattern? pArg
130157
assert! !isPatternDontCare pArg
131-
return { c with cnstrs := .offset pArg k eArg :: c.cnstrs }
158+
return { c with cnstrs := .offset none pArg k eArg :: c.cnstrs }
159+
else if let some (genInfo, pArg) := isGenPattern? pArg then
160+
if pArg.isBVar then
161+
let c ← assign? c pArg.bvarIdx! eArg
162+
genInfo.assign? c eArg
163+
else if let some pArg := groundPattern? pArg then
164+
guard (← isEqv pArg eArg <||> withReducibleAndInstances (isDefEq pArg eArg))
165+
genInfo.assign? c eArg
166+
else if let some (pArg, k) := isOffsetPattern? pArg then
167+
return { c with cnstrs := .offset (some genInfo) pArg k eArg :: c.cnstrs }
168+
else
169+
return { c with cnstrs := .match (some genInfo) pArg eArg :: c.cnstrs }
132170
else
133-
return { c with cnstrs := .match pArg eArg :: c.cnstrs }
171+
return { c with cnstrs := .match none pArg eArg :: c.cnstrs }
134172

135173
private def Choice.updateGen (c : Choice) (gen : Nat) : Choice :=
136174
{ c with gen := Nat.max gen c.gen }
@@ -160,25 +198,35 @@ private partial def matchArgsPrefix? (c : Choice) (p : Expr) (e : Expr) : Option
160198
else
161199
matchArgs? c p (e.getAppPrefix pn)
162200

201+
-- Helper function for `processMatch` and `processGenMatch`
202+
@[inline] private def isCandidate (n : ENode) (pFn : Expr) (pNumArgs : Nat) (maxGeneration : Nat) : Bool :=
203+
-- Remark: we use `<` because the instance generation is the maximum term generation + 1
204+
n.generation < maxGeneration
205+
-- uses heterogeneous equality or is the root of its congruence class
206+
&& (n.heqProofs || n.isCongrRoot)
207+
&& eqvFunctions pFn n.self.getAppFn
208+
&& n.self.getAppNumArgs == pNumArgs
209+
210+
private def assignGenInfo? (genInfo? : Option GenPatternInfo) (c : Choice) (x : Expr) : OptionT GoalM Choice := do
211+
let some genInfo := genInfo? | return c
212+
genInfo.assign? c x
213+
163214
/--
164215
Matches pattern `p` with term `e` with respect to choice `c`.
165216
We traverse the equivalence class of `e` looking for applications compatible with `p`.
166217
For each candidate application, we match the arguments and may update `c`s assignments and constraints.
167218
We add the updated choices to the choice stack.
168219
-/
169-
private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit := do
220+
private partial def processMatch (c : Choice) (genInfo? : Option GenPatternInfo) (p : Expr) (e : Expr) : M Unit := do
170221
let maxGeneration ← getMaxGeneration
171222
let pFn := p.getAppFn
172223
let numArgs := p.getAppNumArgs
173224
let mut curr := e
174225
repeat
175226
let n ← getENode curr
176227
-- Remark: we use `<` because the instance generation is the maximum term generation + 1
177-
if n.generation < maxGeneration
178-
-- uses heterogeneous equality or is the root of its congruence class
179-
&& (n.heqProofs || n.isCongrRoot)
180-
&& eqvFunctions pFn curr.getAppFn
181-
&& curr.getAppNumArgs == numArgs then
228+
if isCandidate n pFn numArgs maxGeneration then
229+
if let some c ← assignGenInfo? genInfo? c e |>.run then
182230
if let some c ← matchArgs? c p curr |>.run then
183231
pushChoice (c.updateGen n.generation)
184232
curr ← getNext curr
@@ -187,32 +235,34 @@ private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit :=
187235
/--
188236
Matches offset pattern `pArg+k` with term `e` with respect to choice `c`.
189237
-/
190-
private partial def processOffset (c : Choice) (pArg : Expr) (k : Nat) (e : Expr) : M Unit := do
238+
private partial def processOffset (c : Choice) (genInfo? : Option GenPatternInfo) (pArg : Expr) (k : Nat) (e : Expr) : M Unit := do
191239
let maxGeneration ← getMaxGeneration
192240
let mut curr := e
193241
repeat
194242
let n ← getENode curr
195243
if n.generation < maxGeneration then
196244
if let some (eArg, k') ← isOffset? curr |>.run then
197-
if k' < k then
198-
let c := c.updateGen n.generation
199-
pushChoice { c with cnstrs := .offset pArg (k - k') eArg :: c.cnstrs }
200-
else if k' == k then
201-
if let some c ← matchArg? c pArg eArg |>.run then
202-
pushChoice (c.updateGen n.generation)
203-
else if k' > k then
204-
let eArg' := mkNatAdd eArg (mkNatLit (k' - k))
205-
let eArg' ← shareCommon (← canon eArg')
206-
internalize eArg' (n.generation+1)
207-
if let some c ← matchArg? c pArg eArg' |>.run then
208-
pushChoice (c.updateGen n.generation)
245+
if let some c ← assignGenInfo? genInfo? c e |>.run then
246+
if k' < k then
247+
let c := c.updateGen n.generation
248+
pushChoice { c with cnstrs := .offset none pArg (k - k') eArg :: c.cnstrs }
249+
else if k' == k then
250+
if let some c ← matchArg? c pArg eArg |>.run then
251+
pushChoice (c.updateGen n.generation)
252+
else if k' > k then
253+
let eArg' := mkNatAdd eArg (mkNatLit (k' - k))
254+
let eArg' ← shareCommon (← canon eArg')
255+
internalize eArg' (n.generation+1)
256+
if let some c ← matchArg? c pArg eArg' |>.run then
257+
pushChoice (c.updateGen n.generation)
209258
else if let some k' ← evalNat curr |>.run then
210-
if k' >= k then
211-
let eArg' := mkNatLit (k' - k)
212-
let eArg' ← shareCommon (← canon eArg')
213-
internalize eArg' (n.generation+1)
214-
if let some c ← matchArg? c pArg eArg' |>.run then
215-
pushChoice (c.updateGen n.generation)
259+
if let some c ← assignGenInfo? genInfo? c e |>.run then
260+
if k' >= k then
261+
let eArg' := mkNatLit (k' - k)
262+
let eArg' ← shareCommon (← canon eArg')
263+
internalize eArg' (n.generation+1)
264+
if let some c ← matchArg? c pArg eArg' |>.run then
265+
pushChoice (c.updateGen n.generation)
216266
curr ← getNext curr
217267
if isSameExpr curr e then break
218268

@@ -280,7 +330,7 @@ private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Na
280330
check proof
281331
let mut prop ← inferType proof
282332
let mut proof := proof
283-
if Match.isMatchEqnTheorem (← getEnv) thm.origin.key then
333+
if (← isMatchEqLikeDeclName thm.origin.key) then
284334
prop ← annotateMatchEqnType prop (← read).initApp
285335
-- We must add a hint here because `annotateMatchEqnType` introduces `simpMatchDiscrsOnly` and
286336
-- `Grind.PreMatchCond` which are not reducible.
@@ -302,13 +352,43 @@ private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : Opti
302352
reportIssue! "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
303353
failure
304354

355+
private def preprocessGeneralizedPatternRHS (lhs : Expr) (rhs : Expr) (origin : Origin) (expectedType : Expr) : OptionT (StateT Choice M) Expr := do
356+
assert! (← alreadyInternalized lhs)
357+
-- We use `dsimp` here to ensure terms such as `Nat.succ x` are normalized as `x+1`.
358+
let rhs ← preprocessLight (← dsimpCore rhs)
359+
internalize rhs (← getGeneration lhs)
360+
processNewFacts
361+
if (← isEqv lhs rhs) then
362+
return rhs
363+
else
364+
reportIssue! "invalid generalized pattern at `{← origin.pp}`\nwhen processing argument with type{indentExpr expectedType}\nfailed to prove{indentExpr lhs}\nis equal to{indentExpr rhs}"
365+
failure
366+
367+
private def assignGeneralizedPatternProof (mvarId : MVarId) (eqProof : Expr) (origin : Origin) : OptionT (StateT Choice M) Unit := do
368+
unless (← mvarId.checkedAssign eqProof) do
369+
reportIssue! "invalid generalized pattern at `{← origin.pp}`\nfailed to assign {mkMVar mvarId}\nwith{indentExpr eqProof}"
370+
failure
371+
305372
private def applyAssignment (mvars : Array Expr) : OptionT (StateT Choice M) Unit := do
306373
let thm := (← read).thm
307374
let numParams := thm.numParams
308375
for h : i in [:mvars.size] do
309376
let bidx := numParams - i - 1
310377
let mut v := (← get).assignment[bidx]!
311-
unless isSameExpr v unassigned do
378+
if isSameExpr v delayedEqProof then
379+
let mvarId := mvars[i].mvarId!
380+
let mvarIdType ← instantiateMVars (← mvarId.getType)
381+
match_expr mvarIdType with
382+
| Eq α lhs rhs =>
383+
let rhs ← preprocessGeneralizedPatternRHS lhs rhs thm.origin mvarIdType
384+
assignGeneralizedPatternProof mvarId (← mkEqProof lhs rhs) thm.origin
385+
| HEq α lhs β rhs =>
386+
let rhs ← preprocessGeneralizedPatternRHS lhs rhs thm.origin mvarIdType
387+
assignGeneralizedPatternProof mvarId (← mkHEqProof lhs rhs) thm.origin
388+
| _ =>
389+
reportIssue! "invalid generalized pattern at `{← thm.origin.pp}`\nequality type expected{indentExpr mvarIdType}"
390+
failure
391+
else if !isSameExpr v unassigned then
312392
let mvarId := mvars[i].mvarId!
313393
let mvarIdType ← mvarId.getType
314394
let vType ← inferType v
@@ -388,8 +468,8 @@ private def processChoices : M Unit := do
388468
if c.gen < maxGeneration then
389469
match c.cnstrs with
390470
| [] => instantiateTheorem c
391-
| .match p e :: cnstrs => processMatch { c with cnstrs } p e
392-
| .offset p k e :: cnstrs => processOffset { c with cnstrs } p k e
471+
| .match genInfo? p e :: cnstrs => processMatch { c with cnstrs } genInfo? p e
472+
| .offset genInfo? p k e :: cnstrs => processOffset { c with cnstrs } genInfo? p k e
393473
| .continue p :: cnstrs => processContinue { c with cnstrs } p
394474

395475
private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do

0 commit comments

Comments
 (0)