@@ -17,9 +17,9 @@ namespace EMatch
1717/-- We represent an `E-matching` problem as a list of constraints. -/
1818inductive Cnstr where
1919 | /-- Matches pattern `pat` with term `e` -/
20- «match » (pat : Expr) (e : Expr)
20+ «match » (gen? : Option GenPatternInfo) ( pat : Expr) (e : Expr)
2121 | /-- Matches offset pattern `pat+k` with term `e` -/
22- offset (pat : Expr) (k : Nat) (e : Expr)
22+ offset (gen? : Option GenPatternInfo) ( pat : Expr) (k : Nat) (e : Expr)
2323 | /-- This constraint is used to encode multi-patterns. -/
2424 «continue » (pat : Expr)
2525 deriving Inhabited
@@ -30,6 +30,12 @@ This is a small hack to avoid one extra level of indirection by using `Option Ex
3030-/
3131private def unassigned : Expr := mkConst (Name.mkSimple "[grind_unassigned]" )
3232
33+ /--
34+ Internal "marker" for representing equality proofs for generalized patterns.
35+ They must be synthesized after we have the partial assignment.
36+ -/
37+ private def delayedEqProof : Expr := mkConst (Name.mkSimple "[grind_delayed_eq_proof]" )
38+
3339private def assignmentToMessageData (assignment : Array Expr) : Array MessageData :=
3440 assignment.reverse.map fun e =>
3541 if isSameExpr e unassigned then m!"_" else m!"{e}"
@@ -89,6 +95,22 @@ private def assign? (c : Choice) (bidx : Nat) (e : Expr) : OptionT GoalM Choice
8995 -- `Choice` was not properly initialized
9096 unreachable!
9197
98+ /--
99+ Assigns `bidx` with the marker for a delayed equality proof for generalized patterns.
100+ The proof is assigned after we have the complete assignment.
101+ -/
102+ private def assignDelayedEqProof? (c : Choice) (bidx : Nat) : OptionT GoalM Choice := do
103+ if h : bidx < c.assignment.size then
104+ let v := c.assignment[bidx]
105+ if isSameExpr v unassigned then
106+ return { c with assignment := c.assignment.set bidx delayedEqProof }
107+ else
108+ return c
109+ else
110+ -- `Choice` was not properly initialized
111+ unreachable!
112+
113+
92114private def unassign (c : Choice) (bidx : Nat) : Choice :=
93115 { c with assignment := c.assignment.set! bidx unassigned }
94116
@@ -100,6 +122,11 @@ private def eqvFunctions (pFn eFn : Expr) : Bool :=
100122 (pFn.isFVar && pFn == eFn)
101123 || (pFn.isConst && eFn.isConstOf pFn.constName!)
102124
125+ protected def _root_.Lean.Meta.Grind.GenPatternInfo.assign? (genInfo : GenPatternInfo) (c : Choice) (x : Expr) : OptionT GoalM Choice := do
126+ let c ← assign? c genInfo.xIdx x
127+ let c ← assignDelayedEqProof? c genInfo.hIdx
128+ return c
129+
103130/-- Matches a pattern argument. See `matchArgs?`. -/
104131private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM Choice := do
105132 if isPatternDontCare pArg then
@@ -128,9 +155,20 @@ private def matchArg? (c : Choice) (pArg : Expr) (eArg : Expr) : OptionT GoalM C
128155 else if let some (pArg, k) := isOffsetPattern? pArg then
129156 assert! Option.isNone <| isOffsetPattern? pArg
130157 assert! !isPatternDontCare pArg
131- return { c with cnstrs := .offset pArg k eArg :: c.cnstrs }
158+ return { c with cnstrs := .offset none pArg k eArg :: c.cnstrs }
159+ else if let some (genInfo, pArg) := isGenPattern? pArg then
160+ if pArg.isBVar then
161+ let c ← assign? c pArg.bvarIdx! eArg
162+ genInfo.assign? c eArg
163+ else if let some pArg := groundPattern? pArg then
164+ guard (← isEqv pArg eArg <||> withReducibleAndInstances (isDefEq pArg eArg))
165+ genInfo.assign? c eArg
166+ else if let some (pArg, k) := isOffsetPattern? pArg then
167+ return { c with cnstrs := .offset (some genInfo) pArg k eArg :: c.cnstrs }
168+ else
169+ return { c with cnstrs := .match (some genInfo) pArg eArg :: c.cnstrs }
132170 else
133- return { c with cnstrs := .match pArg eArg :: c.cnstrs }
171+ return { c with cnstrs := .match none pArg eArg :: c.cnstrs }
134172
135173private def Choice.updateGen (c : Choice) (gen : Nat) : Choice :=
136174 { c with gen := Nat.max gen c.gen }
@@ -160,25 +198,35 @@ private partial def matchArgsPrefix? (c : Choice) (p : Expr) (e : Expr) : Option
160198 else
161199 matchArgs? c p (e.getAppPrefix pn)
162200
201+ -- Helper function for `processMatch` and `processGenMatch`
202+ @[inline] private def isCandidate (n : ENode) (pFn : Expr) (pNumArgs : Nat) (maxGeneration : Nat) : Bool :=
203+ -- Remark: we use `<` because the instance generation is the maximum term generation + 1
204+ n.generation < maxGeneration
205+ -- uses heterogeneous equality or is the root of its congruence class
206+ && (n.heqProofs || n.isCongrRoot)
207+ && eqvFunctions pFn n.self.getAppFn
208+ && n.self.getAppNumArgs == pNumArgs
209+
210+ private def assignGenInfo? (genInfo? : Option GenPatternInfo) (c : Choice) (x : Expr) : OptionT GoalM Choice := do
211+ let some genInfo := genInfo? | return c
212+ genInfo.assign? c x
213+
163214/--
164215Matches pattern `p` with term `e` with respect to choice `c`.
165216We traverse the equivalence class of `e` looking for applications compatible with `p`.
166217For each candidate application, we match the arguments and may update `c`s assignments and constraints.
167218We add the updated choices to the choice stack.
168219-/
169- private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit := do
220+ private partial def processMatch (c : Choice) (genInfo? : Option GenPatternInfo) ( p : Expr) (e : Expr) : M Unit := do
170221 let maxGeneration ← getMaxGeneration
171222 let pFn := p.getAppFn
172223 let numArgs := p.getAppNumArgs
173224 let mut curr := e
174225 repeat
175226 let n ← getENode curr
176227 -- Remark: we use `<` because the instance generation is the maximum term generation + 1
177- if n.generation < maxGeneration
178- -- uses heterogeneous equality or is the root of its congruence class
179- && (n.heqProofs || n.isCongrRoot)
180- && eqvFunctions pFn curr.getAppFn
181- && curr.getAppNumArgs == numArgs then
228+ if isCandidate n pFn numArgs maxGeneration then
229+ if let some c ← assignGenInfo? genInfo? c e |>.run then
182230 if let some c ← matchArgs? c p curr |>.run then
183231 pushChoice (c.updateGen n.generation)
184232 curr ← getNext curr
@@ -187,32 +235,34 @@ private partial def processMatch (c : Choice) (p : Expr) (e : Expr) : M Unit :=
187235/--
188236Matches offset pattern `pArg+k` with term `e` with respect to choice `c`.
189237-/
190- private partial def processOffset (c : Choice) (pArg : Expr) (k : Nat) (e : Expr) : M Unit := do
238+ private partial def processOffset (c : Choice) (genInfo? : Option GenPatternInfo) ( pArg : Expr) (k : Nat) (e : Expr) : M Unit := do
191239 let maxGeneration ← getMaxGeneration
192240 let mut curr := e
193241 repeat
194242 let n ← getENode curr
195243 if n.generation < maxGeneration then
196244 if let some (eArg, k') ← isOffset? curr |>.run then
197- if k' < k then
198- let c := c.updateGen n.generation
199- pushChoice { c with cnstrs := .offset pArg (k - k') eArg :: c.cnstrs }
200- else if k' == k then
201- if let some c ← matchArg? c pArg eArg |>.run then
202- pushChoice (c.updateGen n.generation)
203- else if k' > k then
204- let eArg' := mkNatAdd eArg (mkNatLit (k' - k))
205- let eArg' ← shareCommon (← canon eArg')
206- internalize eArg' (n.generation+1 )
207- if let some c ← matchArg? c pArg eArg' |>.run then
208- pushChoice (c.updateGen n.generation)
245+ if let some c ← assignGenInfo? genInfo? c e |>.run then
246+ if k' < k then
247+ let c := c.updateGen n.generation
248+ pushChoice { c with cnstrs := .offset none pArg (k - k') eArg :: c.cnstrs }
249+ else if k' == k then
250+ if let some c ← matchArg? c pArg eArg |>.run then
251+ pushChoice (c.updateGen n.generation)
252+ else if k' > k then
253+ let eArg' := mkNatAdd eArg (mkNatLit (k' - k))
254+ let eArg' ← shareCommon (← canon eArg')
255+ internalize eArg' (n.generation+1 )
256+ if let some c ← matchArg? c pArg eArg' |>.run then
257+ pushChoice (c.updateGen n.generation)
209258 else if let some k' ← evalNat curr |>.run then
210- if k' >= k then
211- let eArg' := mkNatLit (k' - k)
212- let eArg' ← shareCommon (← canon eArg')
213- internalize eArg' (n.generation+1 )
214- if let some c ← matchArg? c pArg eArg' |>.run then
215- pushChoice (c.updateGen n.generation)
259+ if let some c ← assignGenInfo? genInfo? c e |>.run then
260+ if k' >= k then
261+ let eArg' := mkNatLit (k' - k)
262+ let eArg' ← shareCommon (← canon eArg')
263+ internalize eArg' (n.generation+1 )
264+ if let some c ← matchArg? c pArg eArg' |>.run then
265+ pushChoice (c.updateGen n.generation)
216266 curr ← getNext curr
217267 if isSameExpr curr e then break
218268
@@ -280,7 +330,7 @@ private def addNewInstance (thm : EMatchTheorem) (proof : Expr) (generation : Na
280330 check proof
281331 let mut prop ← inferType proof
282332 let mut proof := proof
283- if Match.isMatchEqnTheorem (← getEnv) thm.origin.key then
333+ if (← isMatchEqLikeDeclName thm.origin.key) then
284334 prop ← annotateMatchEqnType prop (← read).initApp
285335 -- We must add a hint here because `annotateMatchEqnType` introduces `simpMatchDiscrsOnly` and
286336 -- `Grind.PreMatchCond` which are not reducible.
@@ -302,13 +352,43 @@ private def synthesizeInsts (mvars : Array Expr) (bis : Array BinderInfo) : Opti
302352 reportIssue! "failed to synthesize instance when instantiating {← thm.origin.pp}{indentExpr type}"
303353 failure
304354
355+ private def preprocessGeneralizedPatternRHS (lhs : Expr) (rhs : Expr) (origin : Origin) (expectedType : Expr) : OptionT (StateT Choice M) Expr := do
356+ assert! (← alreadyInternalized lhs)
357+ -- We use `dsimp` here to ensure terms such as `Nat.succ x` are normalized as `x+1`.
358+ let rhs ← preprocessLight (← dsimpCore rhs)
359+ internalize rhs (← getGeneration lhs)
360+ processNewFacts
361+ if (← isEqv lhs rhs) then
362+ return rhs
363+ else
364+ reportIssue! "invalid generalized pattern at `{← origin.pp}`\n when processing argument with type{indentExpr expectedType}\n failed to prove{indentExpr lhs}\n is equal to{indentExpr rhs}"
365+ failure
366+
367+ private def assignGeneralizedPatternProof (mvarId : MVarId) (eqProof : Expr) (origin : Origin) : OptionT (StateT Choice M) Unit := do
368+ unless (← mvarId.checkedAssign eqProof) do
369+ reportIssue! "invalid generalized pattern at `{← origin.pp}`\n failed to assign {mkMVar mvarId}\n with{indentExpr eqProof}"
370+ failure
371+
305372private def applyAssignment (mvars : Array Expr) : OptionT (StateT Choice M) Unit := do
306373 let thm := (← read).thm
307374 let numParams := thm.numParams
308375 for h : i in [:mvars.size] do
309376 let bidx := numParams - i - 1
310377 let mut v := (← get).assignment[bidx]!
311- unless isSameExpr v unassigned do
378+ if isSameExpr v delayedEqProof then
379+ let mvarId := mvars[i].mvarId!
380+ let mvarIdType ← instantiateMVars (← mvarId.getType)
381+ match_expr mvarIdType with
382+ | Eq α lhs rhs =>
383+ let rhs ← preprocessGeneralizedPatternRHS lhs rhs thm.origin mvarIdType
384+ assignGeneralizedPatternProof mvarId (← mkEqProof lhs rhs) thm.origin
385+ | HEq α lhs β rhs =>
386+ let rhs ← preprocessGeneralizedPatternRHS lhs rhs thm.origin mvarIdType
387+ assignGeneralizedPatternProof mvarId (← mkHEqProof lhs rhs) thm.origin
388+ | _ =>
389+ reportIssue! "invalid generalized pattern at `{← thm.origin.pp}`\n equality type expected{indentExpr mvarIdType}"
390+ failure
391+ else if !isSameExpr v unassigned then
312392 let mvarId := mvars[i].mvarId!
313393 let mvarIdType ← mvarId.getType
314394 let vType ← inferType v
@@ -388,8 +468,8 @@ private def processChoices : M Unit := do
388468 if c.gen < maxGeneration then
389469 match c.cnstrs with
390470 | [] => instantiateTheorem c
391- | .match p e :: cnstrs => processMatch { c with cnstrs } p e
392- | .offset p k e :: cnstrs => processOffset { c with cnstrs } p k e
471+ | .match genInfo? p e :: cnstrs => processMatch { c with cnstrs } genInfo? p e
472+ | .offset genInfo? p k e :: cnstrs => processOffset { c with cnstrs } genInfo? p k e
393473 | .continue p :: cnstrs => processContinue { c with cnstrs } p
394474
395475private def main (p : Expr) (cnstrs : List Cnstr) : M Unit := do
0 commit comments