Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 39 additions & 30 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Brat.Naming
import Brat.QualName
-- import Brat.Search
import Brat.Syntax.Abstractor (NormalisedAbstractor(..), normaliseAbstractor)
import Brat.Syntax.CircuitProperties (CircuitProperties(..), Properties)
import Brat.Syntax.Common
import Brat.Syntax.Core
import Brat.Syntax.FuncDecl (FunBody(..))
Expand Down Expand Up @@ -174,6 +175,7 @@ check :: (CheckConstraints m k
,EvMode m
,TensorOutputs (Outputs m d)
,?my :: Modey m
,?props :: Properties m
, DIRY d)
=> WC (Term d k)
-> ChkConnectors m d k
Expand All @@ -186,6 +188,7 @@ check' :: forall m d k
,EvMode m
,TensorOutputs (Outputs m d)
,?my :: Modey m
,?props :: Properties m
, DIRY d)
=> Term d k
-> ChkConnectors m d k
Expand Down Expand Up @@ -218,7 +221,7 @@ check' (Lambda c@(WC abstFC abst, body) cs) (overs, unders) = do
-- We'll check the first variant against a Hypo node (omitted from compilation)
-- to work out how many overs/unders it needs, and then check it again (in Chk)
-- with the other clauses, as part of the body.
(ins :->> outs) <- mkSig usedOvers unders
(FunTy _ ins outs) <- mkSig usedOvers unders
(allFakeUnders, rightFakeUnders, tgtMap) <- suppressHoles $ suppressGraph $ do
(_, [], fakeOvers, fakeAcc) <- anext "lambda_fake_source" Hypo (S0, Some (Zy :* S0)) R0 ins
-- Hypo `check` calls need an environment, even just to compute leftovers;
Expand Down Expand Up @@ -264,18 +267,19 @@ check' (Lambda c@(WC abstFC abst, body) cs) (overs, unders) = do
portNamesToBoundNames :: [(String, (Src, BinderType m))] -> [(String, (Src, BinderType m))]
portNamesToBoundNames = fmap (\(n, (src, ty)) -> (n, (NamedPort (end src) n, ty)))

mkSig :: ToEnd t => [(Src, BinderType m)] -> [(NamedPort t, BinderType m)] -> Checking (CTy m Z)
mkSig :: ToEnd t => [(Src, BinderType m)] -> [(NamedPort t, BinderType m)] -> Checking (FunTy m Z)
mkSig overs unders = rowToRo ?my (retuple <$> overs) S0 >>=
\(Some (inRo :* endz)) -> rowToRo ?my (retuple <$> unders) endz >>=
\(Some (outRo :* _)) -> pure (inRo :->> outRo)
\(Some (outRo :* _)) -> pure (FunTy ?props inRo outRo)


retuple (NamedPort e p, ty) = (p, e, ty)

mkWires overs unders = case zipSameLength overs unders of
Nothing -> err $ InternalError "Trying to wire up different sized lists of wires"
Just conns -> traverse (\((src, ty), (tgt, _)) -> wire (src, binderToValue ?my ty, tgt)) conns

checkClauses cty@(ins :->> outs) overs all_cs = do
checkClauses cty@(FunTy _ ins outs) overs all_cs = do
let clauses = NE.zip (NE.fromList [0..]) all_cs <&>
\(i, (abs, tm)) -> Clause i (normaliseAbstractor <$> abs) tm
clauses <- traverse (checkClause ?my "lambda" cty) clauses
Expand All @@ -289,7 +293,7 @@ check' (Pull ports t) (overs, unders) = do
unders <- pullPortsRow ports unders
check t (overs, unders)
check' (t ::: outs) (overs, ()) | Braty <- ?my = do
(ins :->> outs) :: CTy Brat Z <- kindCheckAnnotation Braty ":::" outs
(ins :->> outs) :: FunTy Brat Z <- kindCheckAnnotation Braty () ":::" outs
(_, hungries, danglies, _) <- next "id" Id (S0,Some (Zy :* S0)) ins outs
((), leftOvers) <- noUnders $ check t (overs, hungries)
pure (((), danglies), (leftOvers, ()))
Expand All @@ -312,13 +316,13 @@ check' (Th tm) ((), u@(hungry, ty):unders) = case (?my, ty) of
checkThunk :: forall m. (CheckConstraints m UVerb, EvMode m)
=> Modey m
-> String
-> CTy m Z
-> FunTy m Z
-> WC (Term Chk UVerb)
-> Checking Src
checkThunk m name cty tm = do
checkThunk m name cty@(FunTy props _ _) tm = do
((dangling, _), ()) <- let ?my = m in makeBox name cty $
\(thOvers, thUnders) -> do
(((), ()), leftovers) <- check tm (thOvers, thUnders)
(((), ()), leftovers) <- let ?props = props in check tm (thOvers, thUnders)
case leftovers of
([], []) -> pure ()
([], unders) -> err (ThunkLeftUnders (showRow unders))
Expand All @@ -332,34 +336,38 @@ check' (TypedTh t) ((), ()) = case ?my of
Braty -> do
-- but the computation in it could be either Brat or Kern
brat <- catchErr $ check t ((), ())
kern <- catchErr $ let ?my = Kerny in check t ((), ())
-- TODO: What if the programmer intends a less strict kernel type?
let props = PControllable
kern <- catchErr $ let ?my = Kerny in let ?props = props in check t ((), ())
case (brat, kern) of
(Left e, Left _) -> req $ Throw e -- pick an error arbitrarily
-- I don't believe that there is any syntax that could synthesize
-- both a classical type and a kernel type, but just in case:
-- (pushing down Emb(TypedTh(v)) to Thunk(Emb+Forget(v)) would help in Checkable cases)
(Right _, Right _) -> typeErr "TypedTh could be either Brat or Kernel"
(Left _, Right (conns, ((), ()))) -> let ?my = Kerny in createThunk conns
(Right (conns, ((), ())), Left _) -> createThunk conns
(Left _, Right (conns, ((), ()))) -> let ?my = Kerny in createThunk props conns
(Right (conns, ((), ())), Left _) -> createThunk () conns
where
createThunk :: (CheckConstraints m2 Noun, ?my :: Modey m2, EvMode m2)
=> SynConnectors m2 Syn KVerb
=> Properties m2
-> SynConnectors m2 Syn KVerb
-> Checking (SynConnectors Brat Syn Noun
,ChkConnectors Brat Syn Noun)
createThunk (ins, outs) = do
createThunk ps (ins, outs) = do
Some (ez :* inR) <- mkArgRo ?my S0 (first (fmap toEnd) <$> ins)
Some (_ :* outR) <- mkArgRo ?my ez (first (fmap toEnd) <$> outs)
(thunkOut, ()) <- makeBox "thunk" (inR :->> outR) $
(thunkOut, ()) <- makeBox "thunk" (FunTy ps inR outR) $
\(thOvers, thUnders) -> do
-- if these ensureEmpty's fail then its a bug!
checkInputs t thOvers ins >>= ensureEmpty "TypedTh inputs"
checkOutputs t thUnders outs >>= ensureEmpty "TypedTh outputs"
pure (((), [thunkOut]), ((), ()))
check' (Force th) ((), ()) = do
(((), outs), ((), ())) <- let ?my = Braty in check th ((), ())
(((), outs), ((), ())) <- let ?my = Braty in let ?props = () in check th ((), ())
-- pull a bunch of thunks (only!) out of here
(_, thInputs, thOutputs) <- getThunks ?my outs
(_, thInputs, thOutputs) <- getThunks ?my ?props outs
pure ((thInputs, thOutputs), ((), ()))

check' (Forget kv) (overs, unders) = do
((ins, outs), ((), rightUnders)) <- check kv ((), unders)
leftOvers <- checkInputs kv overs ins
Expand All @@ -384,9 +392,9 @@ check' (Arith op l r) ((), u@(hungry, ty):unders) = case (?my, ty) of
let inRo = RPr ("left", ty) $ RPr ("right", ty) R0
let outRo = RPr ("out", ty) R0
(_, [lunders, runders], [(dangling, _)], _) <- next (show op) (ArithNode op) (S0, Some $ Zy :* S0) inRo outRo
(((), ()), ((), leftUnders)) <- check l ((), [lunders])
(((), ()), ((), leftUnders)) <- let ?props = () in check l ((), [lunders])
ensureEmpty "arith unders" leftUnders
(((), ()), ((), leftUnders)) <- check r ((), [runders])
(((), ()), ((), leftUnders)) <- let ?props = () in check r ((), [runders])
ensureEmpty "arith unders" leftUnders
wire (dangling, ty, hungry)
pure (((), ()), ((), unders))
Expand Down Expand Up @@ -684,13 +692,13 @@ data Clause = Clause
-- refined overs)
checkClause :: forall m. (CheckConstraints m UVerb, EvMode m) => Modey m
-> String
-> CTy m Z
-> FunTy m Z
-> Clause
-> Checking
( TestMatchData m -- TestMatch data (LHS)
, Name -- Function node (RHS)
)
checkClause my fnName cty clause = modily my $ do
checkClause my fnName cty@(FunTy props _ _) clause = modily my $ do
let clauseName = fnName ++ "." ++ show (index clause)

-- First, we check the patterns on the LHS. This requires some overs,
Expand All @@ -708,14 +716,14 @@ checkClause my fnName cty clause = modily my $ do
Some (_ :* outRo) <- mkArgRo my patEz (first (fmap toEnd) <$> unders)
let match = TestMatchData my $ MatchSequence overs tests (snd <$> sol)
let vars = fst <$> sol
pure (vars, match, patRo :->> outRo)
pure (vars, match, FunTy props patRo outRo)

-- Now actually make a box for the RHS and check it
((boxPort, _ty), _) <- let ?my = my in makeBox (clauseName ++ "_rhs") rhsCty $ \(rhsOvers, rhsUnders) -> do
let abstractor = foldr ((:||:) . APat . Bind) AEmpty vars
let ?my = my in do
env <- abstractAll rhsOvers abstractor
localEnv env $ check @m (rhs clause) ((), rhsUnders)
localEnv env $ let ?props = props in check @m (rhs clause) ((), rhsUnders)
let NamedPort {end=Ex rhsNode _} = boxPort
pure (match, rhsNode)

Expand All @@ -724,17 +732,17 @@ checkClause my fnName cty clause = modily my $ do
checkBody :: (CheckConstraints m UVerb, EvMode m, ?my :: Modey m)
=> String -- The function name
-> FunBody Term UVerb
-> CTy m Z -- Function type
-> FunTy m Z -- Function type
-> Checking Src
checkBody fnName body cty = do
checkBody fnName body cty@(FunTy props _ _) = do
(tm, (absFC, tmFC)) <- case body of
NoLhs tm -> pure (tm, (fcOf tm, fcOf tm))
Clauses (c@(abs, tm) :| cs) -> do
fc <- req AskFC
pure (WC fc (Lambda c cs), (fcOf abs, fcOf tm))
Undefined -> err (InternalError "Checking undefined clause")
((src, _), _) <- makeBox (fnName ++ ".box") cty $ \conns@(_, unders) -> do
(((), ()), leftovers) <- check tm conns
(((), ()), leftovers) <- let ?props = props in check tm conns
case leftovers of
([], []) -> pure ()
([], rightUnders) -> localFC tmFC $
Expand Down Expand Up @@ -825,14 +833,14 @@ kindCheck ((hungry, Star []):unders) (C (ss :-> ts)) = do
let val = VFun Braty (inRo :->> outRo)
defineTgt hungry val
pure ([val], unders)
kindCheck ((hungry, Star []):unders) (K (ss :-> ts)) = do
kindCheck ((hungry, Star []):unders) (K ps (ss :-> ts)) = do
-- N.B. Kernels can't bind so we don't need to pass around a stack of ends
ss <- kindCheckRow Kerny "" ss
ts <- kindCheckRow Kerny "" ts
case (ss, ts) of
(Some ss, Some ts) -> case kernelNoBind ss of
Refl -> do
let val = VFun Kerny (ss :->> ts)
let val = VFun Kerny (FunTy ps ss ts)
defineTgt hungry val
pure ([val], unders)

Expand Down Expand Up @@ -930,10 +938,11 @@ kindCheckRow my name r = do
-- Checks that an annotation is a valid row, returning the
-- evaluation of the type of an Id node passing through such values
kindCheckAnnotation :: Modey m
-> Properties m
-> String -- for node name
-> [(PortName, ThunkRowType m)]
-> Checking (CTy m Z)
kindCheckAnnotation my name outs = do
-> Checking (FunTy m Z)
kindCheckAnnotation my ps name outs = do
trackM "kca"
name <- req (Fresh $ "__kca_" ++ name)
kindCheckRow' my (Zy :* S0) M.empty (name, 0) outs >>= \case
Expand All @@ -944,7 +953,7 @@ kindCheckAnnotation my name outs = do
-- but persuades the Haskell typechecker it's ok to use the copy
-- as return types (that happen not to mention the argument types).
case varChangerThroughRo (ParToInx (AddZ n) s) ins of
Some (_ :* outs) -> pure (ins :->> outs)
Some (_ :* outs) -> pure (FunTy ps ins outs)

kindCheckRow' :: forall m n
. Modey m
Expand Down
72 changes: 41 additions & 31 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ module Brat.Checker.Helpers {-(pullPortsRow, pullPortsSig
,ensureEmpty, noUnders
,rowToSig
,showMode, getVec
,mkThunkTy
,wire
,next, knext, anext
,kindType, getThunks
Expand All @@ -27,7 +26,8 @@ import Brat.Eval (eval, EvMode(..), kindType)
import Brat.FC (FC)
import Brat.Graph (Node(..), NodeType(..))
import Brat.Naming (Name, FreshMonad(..))
import Brat.Syntax.Common
import Brat.Syntax.CircuitProperties (CircuitProperties(..), Properties)
import Brat.Syntax.Common hiding (pattern PNone)
import Brat.Syntax.Core (Term(..))
import Brat.Syntax.Simple
import Brat.Syntax.Port (ToEnd(..))
Expand All @@ -36,6 +36,7 @@ import Bwd
import Hasochism
import Util (log2)

import Control.Monad (when)
import Control.Monad.State.Lazy (StateT(..), runStateT)
import Control.Monad.Freer (req)
import Data.Bifunctor
Expand Down Expand Up @@ -149,15 +150,6 @@ type family ThunkRowType (m :: Mode) where
ThunkRowType Brat = KindOr (Term Chk Noun)
ThunkRowType Kernel = Term Chk Noun

mkThunkTy :: Modey m
-> ThunkFCType m
-> [(PortName, ThunkRowType m)]
-> [(PortName, ThunkRowType m)]
-> Term Chk Noun
-- mkThunkTy Braty fc ss ts = C (WC fc (ss :-> ts))
mkThunkTy Braty _ ss ts = C (ss :-> ts)
mkThunkTy Kerny () ss ts = K (ss :-> ts)

anext :: forall m i j k
. EvMode m
=> String
Expand Down Expand Up @@ -238,38 +230,56 @@ wire (src, ty, tgt) = do
-- This is the dual notion to the overs and unders used for typechecking against
-- Hence, we return them here in the opposite order to `check`'s connectors
getThunks :: Modey m
-> Properties m
-> [(Src, BinderType Brat)] -- A row of 0 or more function types in the same mode
-> Checking ([Name]
,Unders m Chk
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty ((src, Right ty):rest) = do
getThunks _ _ [] = pure ([], [], [])
getThunks Braty () ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, ss :->> ts) <- vectorise Braty (src, ty)
(node, unders, overs, _) <- let ?my = Braty in
anext "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
(nodes, unders', overs') <- getThunks Braty () rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Kerny ((src, Right ty):rest) = do
getThunks Kerny expectedProps ((src, Right ty):rest) = do
ty <- eval S0 ty
(src, ss :->> ts) <- vectorise Kerny (src,ty)
(src, sig@(FunTy props ss ts)) <- vectorise Kerny (src,ty)
when (props < expectedProps)
(err (TypeErr (unwords ["Expected all kernel types to be"
,expected expectedProps
,"but kernel with signature"
,show sig
,"was"
,actual props
])))
(node, unders, overs, _) <- let ?my = Kerny in anext "Splice" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
(nodes, unders', overs') <- getThunks Kerny props rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks Braty ((src, Left (Star args)):rest) = do
where
expected PControllable = "controllable and reversible"
expected PReversible = "reversible"
expected _ = undefined

actual PNone = "neither"
actual PReversible = "just reversible, not controllable"
actual _ = undefined

getThunks Braty () ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Some (_ :* stk) -> do
let (ri,ro) = kindArgRows stk
(node, unders, overs, _) <- next "Eval" (Eval (end src)) (S0, Some (Zy :* S0)) ri ro
pure (node, unders, overs)
(nodes, unders', overs') <- getThunks Braty rest
(nodes, unders', overs') <- getThunks Braty () rest
pure (node:nodes, unders <> unders', overs <> overs')
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)
getThunks m _ ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Modey m -> Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,CTy m Z -- The function type at the end
,FunTy m Z -- The function type at the end
)
vecLayers my (TVec ty (VNum n)) = do
src <- mkStaticNum n
Expand Down Expand Up @@ -324,14 +334,14 @@ mkStaticNum n@(NumValue c gro) = do
wire (oneSrc, TNat, rhs)
pure src

vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, CTy m Z)
vectorise :: forall m. Modey m -> (Src, Val Z) -> Checking (Src, FunTy m Z)
vectorise my (src, ty) = do
(layers, cty) <- vecLayers my ty
modily my $ foldrM mkMapFun (src, cty) layers
where
mkMapFun :: (Src, NumVal (VVar Z)) -- Layer to apply
-> (Src, CTy m Z) -- The input to this level of mapfun
-> Checking (Src, CTy m Z)
-> (Src, FunTy m Z) -- The input to this level of mapfun
-> Checking (Src, FunTy m Z)
mkMapFun (lenSrc, len) (valSrc, cty) = do
let weak1 = changeVar (Thinning (ThDrop ThNull))
vecFun <- vectorisedFun len my cty
Expand All @@ -342,17 +352,17 @@ vectorise my (src, ty) = do
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
let vecCTy = case (my,my',cty) of
let vecFunTy = case (my,my',cty) of
(Braty,Braty,cty) -> cty
(Kerny,Kerny,cty) -> cty
_ -> error "next returned wrong mode of computation type to that passed in"
pure (vectorSrc, vecCTy)
pure (vectorSrc, vecFunTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
vectorisedFun :: NumVal (VVar Z) -> Modey m -> FunTy m Z -> Checking (Val Z)
vectorisedFun nv my (FunTy ps ss ts) = do
(ss', ny) <- vectoriseRo True nv Zy ss
(ts', _) <- vectoriseRo False nv ny ts
pure $ modily my $ VFun my (ss' :->> ts')
pure $ modily my $ VFun my (FunTy ps ss' ts')

-- We don't allow existentials in vectorised functions, so the boolean says
-- whether we are in the input row and can allow binding
Expand Down Expand Up @@ -394,10 +404,10 @@ declareTgt tgt my ty = req (Declare (InEnd (end tgt)) my ty)
-- Build a box corresponding to the inside of a thunk
makeBox :: (?my :: Modey m, EvMode m)
=> String -- A label for the nodes we create
-> CTy m Z
-> FunTy m Z
-> ((Overs m UVerb, Unders m Chk) -> Checking a) -- checks + builds the body using srcs/tgts from the box
-> Checking ((Src, BinderType Brat), a)
makeBox name cty@(ss :->> ts) body = do
makeBox name cty@(FunTy _ ss ts) body = do
(src, _, overs, ctx) <- anext (name ++ "/in") Source (S0, Some (Zy :* S0)) R0 ss
(tgt, unders, _, _) <- anext (name ++ "/out") Target ctx ts R0
case (?my, body) of
Expand Down
Loading
Loading