diff --git a/brat/Data/Hugr.hs b/brat/Data/Hugr.hs index 9791134d..6ea8b202 100644 --- a/brat/Data/Hugr.hs +++ b/brat/Data/Hugr.hs @@ -455,6 +455,11 @@ holeOp :: Int -> FunctionType -> CustomOp holeOp idx sig = CustomOp "BRAT" "Hole" sig [TANat idx, TAType (HTFunc (PolyFuncType [] sig))] +isHole :: HugrOp -> Maybe (Int, FunctionType) +isHole (OpCustom (CustomOp "BRAT" "Hole" sig args)) = + let [TANat idx, _] = args in Just (idx, sig) -- crash rather than return false for bad args +isHole _ = Nothing + -- TYPE ARGS: -- * A length-2 sequence comprising: -- * A sequence of types (the inputs of outerSig) diff --git a/brat/Data/HugrGraph.hs b/brat/Data/HugrGraph.hs index 8f914435..a1a4e4e9 100644 --- a/brat/Data/HugrGraph.hs +++ b/brat/Data/HugrGraph.hs @@ -5,16 +5,18 @@ module Data.HugrGraph(NodeId, freshNode, setFirstChildren, setOp, getParent, getOp, - addEdge, addOrderEdge, - edgeList, serialize + addEdge, addOrderEdge, edgeList, + splice, inlineDFG, + serialize ) where -import Brat.Naming (Namespace, Name, fresh, split) +import Brat.Naming (Namespace, Name(..), fresh, split) import Bwd import Data.Hugr hiding (const) -import Control.Monad.State (State, execState, state) +import Control.Monad.State (State, execState, state, get, put, modify) import Data.Foldable (for_) +import Data.Functor ((<&>)) import Data.Bifunctor (first) import Data.Maybe (fromMaybe) import qualified Data.Map as M @@ -99,6 +101,99 @@ getParent HugrGraph {parents} n = parents M.! n getOp :: HugrGraph -> NodeId -> HugrOp getOp HugrGraph {nodes} n = nodes M.! n +--- Replaces the specified node of the first Hugr, with the second Hugr. +splice :: HugrGraph -> NodeId -> HugrGraph -> HugrGraph +splice host hole add = case (M.lookup hole (nodes host) >>= isHole) of + Just (_, sig) -> case M.lookup (root add) (nodes add) of + Just (OpDFG (DFG sig' _)) | sig == sig' -> {-inlineDFG hole-} host { + -- prefer host entry for parent of (`hole` == root of `add`) + parents = union (parents host) (M.mapKeys k $ M.map k $ parents add), + -- override host `nodes` for `hole` with new (DFG) + nodes = M.union (M.mapKeys k (nodes add)) (nodes host), + edges_in = union (edges_in host) $ M.fromList [(k tgt, [(Port (k srcNode) srcPort, tgtPort) + | (Port srcNode srcPort, tgtPort) <- in_edges ]) + | (tgt, in_edges ) <- M.assocs (edges_in add)], + edges_out = union (edges_out host) $ M.fromList [(k src, [(srcPort, Port (k tgtNode) tgtPort) + | (srcPort, Port tgtNode tgtPort) <- out_edges]) + | (src, out_edges) <- M.assocs (edges_out add)], + first_children = union (first_children host) (M.mapKeys k $ M.map (k <$>) $ first_children add) + } + other -> error $ "Expected DFG with sig " ++ show sig ++ "\nBut found: " ++ show other + other -> error $ "Expected a hole, found " ++ show other + where + prefixRoot :: NodeId -> NodeId + prefixRoot (NodeId (MkName ids)) = let NodeId (MkName rs) = hole in NodeId $ MkName (rs ++ ids) + + keyMap :: M.Map NodeId NodeId -- translate `add` keys into `host` by prefixing with `hole`. + -- parent is definitive list of non-root nodes + keyMap = M.fromList $ (root add, hole):[(k, prefixRoot k) | k <- M.keys (parents add)] + + union = M.unionWith (\_ _ -> error "keys not disjoint") + k = (keyMap M.!) + +inlineDFG :: NodeId -> State HugrGraph () +inlineDFG dfg = get >>= \h -> case M.lookup dfg (nodes h) of + (Just (OpDFG _)) -> do + let newp = (parents h) M.! dfg + let [inp, out] = (first_children h) M.! dfg + -- rewire edges + dfg_in_map <- takeInEdgeMap dfg + input_out_map <- takeOutEdges inp + for_ input_out_map $ \(outp, dest) -> addEdge (dfg_in_map M.! outp, dest) + dfg_out_map <- takeOutEdges dfg + output_in_map <- takeInEdgeMap out + for_ dfg_out_map $ \(outp, dest) -> addEdge (output_in_map M.! outp, dest) + -- remove dfg, inp, out; reparent children of dfg + let to_remove = [dfg, inp, out] + modify $ \h -> h { + first_children = M.delete dfg (first_children h), -- inp/out shouldn't have any children + nodes = foldl (flip M.delete) (nodes h) to_remove, + -- TODO this is O(size of hugr) reparenting. Either add a child map, + -- or combine with splicing so we only iterate through the inserted + -- hugr (which we do anyway) rather than the host. + parents = M.fromList [(n, if p==dfg then newp else p) + | (n,p) <- M.assocs (parents h), not (elem n to_remove)] + } + other -> error $ "Expected DFG, found " ++ show other + where + takeInEdgeMap n = takeInEdges n <&> \es -> M.fromList [(p, src) | (src, p) <- es] + +takeInEdges :: NodeId -> State HugrGraph [(PortId NodeId, Int)] +takeInEdges tgt = do + h <- get + let (removed_edges, edges_in') = first (fromMaybe []) $ M.updateLookupWithKey + (\_ _ -> Nothing) tgt (edges_in h) + let edges_out' = foldl removeFromOutMap (edges_out h) removed_edges + put h {edges_in=edges_in', edges_out=edges_out'} + pure removed_edges + where + removeFromOutMap :: M.Map NodeId [(Int, PortId NodeId)] -> (PortId NodeId, Int) -> M.Map NodeId [(Int, PortId NodeId)] + removeFromOutMap eos (Port src outport, inport) = M.alter (\(Just es) -> Just $ removeFromOutList es (outport, Port tgt inport)) src eos + + removeFromOutList :: [(Int, PortId NodeId)] -> (Int, PortId NodeId) -> [(Int, PortId NodeId)] + removeFromOutList [] _ = error "Out-edge not found" + removeFromOutList (e:es) e' | e == e' = es + removeFromOutList ((outport, _):_) (outport', _) | outport == outport' = error "Wrong out-edge" + removeFromOutList (e:es) r = e:(removeFromOutList es r) + +takeOutEdges :: NodeId -> State HugrGraph [(Int, PortId NodeId)] +takeOutEdges src = do + h <- get + let (removed_edges, edges_out') = first (fromMaybe []) $ M.updateLookupWithKey + (\_ _ -> Nothing) src (edges_out h) + let edges_in' = foldl removeFromInMap (edges_in h) removed_edges + put h {edges_in=edges_in', edges_out=edges_out'} + pure removed_edges + where + removeFromInMap :: M.Map NodeId [(PortId NodeId, Int)] -> (Int, PortId NodeId) -> M.Map NodeId [(PortId NodeId, Int)] + removeFromInMap eis (outport, Port tgt inport) = M.alter (\(Just es) -> Just $ removeFromInList es (Port src outport, inport)) tgt eis + + removeFromInList:: [(PortId NodeId, Int)] -> (PortId NodeId, Int) -> [(PortId NodeId, Int)] + removeFromInList [] _ = error "In-edge not found" + removeFromInList (e:es) e' | e==e' = es + removeFromInList ((_, inport):_) (_,inport') | inport == inport' = error "Wrong in-edge" + removeFromInList (e:es) r = e:(removeFromInList es r) + serialize :: HugrGraph -> Hugr Int serialize hugr = renameAndSort (execState (for_ orderEdges addOrderEdge) hugr) where @@ -141,7 +236,7 @@ renameAndSort hugr@(HugrGraph {root, first_children=fc, nodes, parents}) = Hugr nodeStackAndIndices = let just_root = (B0 :< (root, nodes M.! root), M.singleton root 0) init = foldl addNode just_root (first_children root) in foldl addNode init (M.keys parents) - + addNode :: StackAndIndices -> NodeId -> StackAndIndices addNode ins n = case M.lookup n (snd ins) of (Just _) -> ins diff --git a/brat/brat.cabal b/brat/brat.cabal index 81e57265..51ad35c3 100644 --- a/brat/brat.cabal +++ b/brat/brat.cabal @@ -166,6 +166,7 @@ test-suite tests Test.Elaboration, Test.Failure, Test.Graph, + Test.HugrGraph, Test.Libs, Test.Parsing, Test.Naming, @@ -175,7 +176,8 @@ test-suite tests Test.TypeArith, Test.Util - build-depends: base <5, + build-depends: aeson, + base <5, brat, tasty, tasty-hunit, diff --git a/brat/test/Main.hs b/brat/test/Main.hs index 2c67e0cb..f91033f6 100644 --- a/brat/test/Main.hs +++ b/brat/test/Main.hs @@ -7,6 +7,7 @@ import Test.Graph import Test.Compile.Hugr import Test.Elaboration import Test.Failure +import Test.HugrGraph import Test.Libs import Test.Naming import Test.Parsing @@ -64,6 +65,7 @@ main = do parsingTests <- getParsingTests compilationTests <- setupCompilationTests graphTests <- getGraphTests + spliceTests <- getSpliceTests let coroTests = testGroup "coroutine" [testCase "coroT1" $ assertChecking coroT1 ,testCase "coroT2" $ assertCheckingFail "Typechecking blocked on" coroT2 @@ -82,4 +84,5 @@ main = do ,compilationTests ,typeArithTests ,coroTests + ,spliceTests ] diff --git a/brat/test/Test/HugrGraph.hs b/brat/test/Test/HugrGraph.hs new file mode 100644 index 00000000..b9c10580 --- /dev/null +++ b/brat/test/Test/HugrGraph.hs @@ -0,0 +1,75 @@ +module Test.HugrGraph(getSpliceTests) where + +import Brat.Naming as N +import Data.HugrGraph as H +import Data.Hugr + +import Control.Monad.State (State, execState, get, runState) +import Data.Aeson (encode) +import Data.Functor ((<&>)) +import Data.Maybe (isJust, isNothing) +import Data.List (find) +import qualified Data.ByteString.Lazy as BS +import System.Directory (createDirectoryIfMissing) +import System.FilePath +import Test.Tasty +import Test.Tasty.HUnit + +prefix = "test/hugr" +outputDir = prefix "output" + +addNode :: String -> NodeId -> HugrOp -> State HugrGraph NodeId +addNode nam parent op = do + name <- H.freshNode parent nam + H.setOp name op + pure name + +getSpliceTests :: IO TestTree +getSpliceTests = do + createDirectoryIfMissing True outputDir + pure $ testGroup "splice" [testSplice False, testSplice True] + +testSplice :: Bool -> TestTree +testSplice inline = testCaseInfo name $ do + let (h, holeId) = host + let outPrefix = outputDir name + BS.writeFile (outPrefix ++ "_host.json") (encode $ H.serialize h) + BS.writeFile (outPrefix ++ "_insertee.json") (encode $ H.serialize dfgHugr) + let spliced = H.splice h holeId dfgHugr + let resHugr@(Hugr (ns, _)) = H.serialize $ if inline + then execState (inlineDFG holeId) spliced else spliced + let outFile = outPrefix ++ "_result.json" + BS.writeFile outFile $ encode resHugr + assertBool "Should be no holes now" $ isNothing $ find (isJust . isHole) $ snd <$> ns + -- if inline, we should assert there's no DFG either + pure $ "Written to " ++ outFile ++ " pending validation" + where + name = if inline then "inline" else "noinline" + host :: (HugrGraph, NodeId) + host = swap $ flip runState (H.new N.root "root" rootDefn) $ do + root <- get <&> H.root + input <- addNode "inp" root (OpIn (InputNode tys [])) + output <- addNode "out" root (OpOut (OutputNode tys [])) + setFirstChildren root [input, output] + hole <- addNode "hole" root (OpCustom $ holeOp 0 tq_ty) + H.addEdge (Port input 0, Port hole 0) + H.addEdge (Port input 1, Port hole 1) + H.addEdge (Port hole 0, Port output 0) + H.addEdge (Port hole 1, Port output 1) + pure hole + dfgHugr :: HugrGraph + dfgHugr = flip execState (H.new N.root "root" rootDfg) $ do + root <- get <&> H.root + input <- addNode "inp" root (OpIn (InputNode tys [])) + output <- addNode "out" root (OpOut (OutputNode tys [])) + setFirstChildren root [input, output] + gate <- addNode "gate" root (OpCustom $ CustomOp "tket" "CX" tq_ty []) + H.addEdge (Port input 0, Port gate 0) + H.addEdge (Port input 1, Port gate 1) + H.addEdge (Port gate 0, Port output 0) + H.addEdge (Port gate 1, Port output 1) + swap (x,y) = (y,x) + tys = [HTQubit, HTQubit] + tq_ty = FunctionType tys tys bratExts + rootDefn = OpDefn $ FuncDefn "main" (PolyFuncType [] tq_ty) [] + rootDfg = OpDFG $ DFG tq_ty [] diff --git a/brat/tools/validate.sh b/brat/tools/validate.sh index 2a4342a0..e8f234d5 100755 --- a/brat/tools/validate.sh +++ b/brat/tools/validate.sh @@ -12,20 +12,22 @@ declare -a FAILED_TEST_MSGS UNEXPECTED_PASSES= NUM_FAILURES=0 -for json in $(find test/compilation/output -maxdepth 1 -name "*.json"); do - echo Validating "$json" - RESULT=$(cat "$json" | hugr_validator 2>&1) - if [ $? -ne 0 ]; then - FAILED_TEST_NAMES[NUM_FAILURES]=$json - FAILED_TEST_MSGS[NUM_FAILURES]=$RESULT - NUM_FAILURES=$((NUM_FAILURES + 1)) - fi -done +for dir in test/compilation/output test/hugr/output; do + for json in $(find $dir -maxdepth 1 -name "*.json"); do + echo Validating "$json" + RESULT=$(cat "$json" | hugr_validator 2>&1) + if [ $? -ne 0 ]; then + FAILED_TEST_NAMES[NUM_FAILURES]=$json + FAILED_TEST_MSGS[NUM_FAILURES]=$RESULT + NUM_FAILURES=$((NUM_FAILURES + 1)) + fi + done -for invalid_json in $(find test/compilation/output -maxdepth 1 -name "*.json.invalid"); do - if (hugr_validator < $invalid_json 2>/dev/null > /dev/null); then - UNEXPECTED_PASSES="$UNEXPECTED_PASSES $invalid_json" - fi + for invalid_json in $(find test/compilation/output -maxdepth 1 -name "*.json.invalid"); do + if (hugr_validator < $invalid_json 2>/dev/null > /dev/null); then + UNEXPECTED_PASSES="$UNEXPECTED_PASSES $invalid_json" + fi + done done RED='\033[0;31m'