Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions brat/Data/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,11 @@ holeOp :: Int -> FunctionType -> CustomOp
holeOp idx sig = CustomOp "BRAT" "Hole" sig
[TANat idx, TAType (HTFunc (PolyFuncType [] sig))]

isHole :: HugrOp -> Maybe (Int, FunctionType)
isHole (OpCustom (CustomOp "BRAT" "Hole" sig args)) =
let [TANat idx, _] = args in Just (idx, sig) -- crash rather than return false for bad args
isHole _ = Nothing

-- TYPE ARGS:
-- * A length-2 sequence comprising:
-- * A sequence of types (the inputs of outerSig)
Expand Down
105 changes: 100 additions & 5 deletions brat/Data/HugrGraph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ module Data.HugrGraph(NodeId,
freshNode,
setFirstChildren,
setOp, getParent, getOp,
addEdge, addOrderEdge,
edgeList, serialize
addEdge, addOrderEdge, edgeList,
splice, inlineDFG,
serialize
) where

import Brat.Naming (Namespace, Name, fresh, split)
import Brat.Naming (Namespace, Name(..), fresh, split)
import Bwd
import Data.Hugr hiding (const)

import Control.Monad.State (State, execState, state)
import Control.Monad.State (State, execState, state, get, put, modify)
import Data.Foldable (for_)
import Data.Functor ((<&>))
import Data.Bifunctor (first)
import Data.Maybe (fromMaybe)
import qualified Data.Map as M
Expand Down Expand Up @@ -99,6 +101,99 @@ getParent HugrGraph {parents} n = parents M.! n
getOp :: HugrGraph -> NodeId -> HugrOp
getOp HugrGraph {nodes} n = nodes M.! n

--- Replaces the specified node of the first Hugr, with the second Hugr.
splice :: HugrGraph -> NodeId -> HugrGraph -> HugrGraph
splice host hole add = case (M.lookup hole (nodes host) >>= isHole) of
Just (_, sig) -> case M.lookup (root add) (nodes add) of
Just (OpDFG (DFG sig' _)) | sig == sig' -> {-inlineDFG hole-} host {
-- prefer host entry for parent of (`hole` == root of `add`)
parents = union (parents host) (M.mapKeys k $ M.map k $ parents add),
-- override host `nodes` for `hole` with new (DFG)
nodes = M.union (M.mapKeys k (nodes add)) (nodes host),
edges_in = union (edges_in host) $ M.fromList [(k tgt, [(Port (k srcNode) srcPort, tgtPort)
| (Port srcNode srcPort, tgtPort) <- in_edges ])
| (tgt, in_edges ) <- M.assocs (edges_in add)],
edges_out = union (edges_out host) $ M.fromList [(k src, [(srcPort, Port (k tgtNode) tgtPort)
| (srcPort, Port tgtNode tgtPort) <- out_edges])
| (src, out_edges) <- M.assocs (edges_out add)],
first_children = union (first_children host) (M.mapKeys k $ M.map (k <$>) $ first_children add)
}
other -> error $ "Expected DFG with sig " ++ show sig ++ "\nBut found: " ++ show other
other -> error $ "Expected a hole, found " ++ show other
where
prefixRoot :: NodeId -> NodeId
prefixRoot (NodeId (MkName ids)) = let NodeId (MkName rs) = hole in NodeId $ MkName (rs ++ ids)

keyMap :: M.Map NodeId NodeId -- translate `add` keys into `host` by prefixing with `hole`.
-- parent is definitive list of non-root nodes
keyMap = M.fromList $ (root add, hole):[(k, prefixRoot k) | k <- M.keys (parents add)]

union = M.unionWith (\_ _ -> error "keys not disjoint")
k = (keyMap M.!)

inlineDFG :: NodeId -> State HugrGraph ()
inlineDFG dfg = get >>= \h -> case M.lookup dfg (nodes h) of
(Just (OpDFG _)) -> do
let newp = (parents h) M.! dfg
let [inp, out] = (first_children h) M.! dfg
-- rewire edges
dfg_in_map <- takeInEdgeMap dfg
input_out_map <- takeOutEdges inp
for_ input_out_map $ \(outp, dest) -> addEdge (dfg_in_map M.! outp, dest)
dfg_out_map <- takeOutEdges dfg
output_in_map <- takeInEdgeMap out
for_ dfg_out_map $ \(outp, dest) -> addEdge (output_in_map M.! outp, dest)
-- remove dfg, inp, out; reparent children of dfg
let to_remove = [dfg, inp, out]
modify $ \h -> h {
first_children = M.delete dfg (first_children h), -- inp/out shouldn't have any children
nodes = foldl (flip M.delete) (nodes h) to_remove,
-- TODO this is O(size of hugr) reparenting. Either add a child map,
-- or combine with splicing so we only iterate through the inserted
-- hugr (which we do anyway) rather than the host.
parents = M.fromList [(n, if p==dfg then newp else p)
| (n,p) <- M.assocs (parents h), not (elem n to_remove)]
}
other -> error $ "Expected DFG, found " ++ show other
where
takeInEdgeMap n = takeInEdges n <&> \es -> M.fromList [(p, src) | (src, p) <- es]

takeInEdges :: NodeId -> State HugrGraph [(PortId NodeId, Int)]
takeInEdges tgt = do
h <- get
let (removed_edges, edges_in') = first (fromMaybe []) $ M.updateLookupWithKey
(\_ _ -> Nothing) tgt (edges_in h)
let edges_out' = foldl removeFromOutMap (edges_out h) removed_edges
put h {edges_in=edges_in', edges_out=edges_out'}
pure removed_edges
where
removeFromOutMap :: M.Map NodeId [(Int, PortId NodeId)] -> (PortId NodeId, Int) -> M.Map NodeId [(Int, PortId NodeId)]
removeFromOutMap eos (Port src outport, inport) = M.alter (\(Just es) -> Just $ removeFromOutList es (outport, Port tgt inport)) src eos

removeFromOutList :: [(Int, PortId NodeId)] -> (Int, PortId NodeId) -> [(Int, PortId NodeId)]
removeFromOutList [] _ = error "Out-edge not found"
removeFromOutList (e:es) e' | e == e' = es
removeFromOutList ((outport, _):_) (outport', _) | outport == outport' = error "Wrong out-edge"
removeFromOutList (e:es) r = e:(removeFromOutList es r)

takeOutEdges :: NodeId -> State HugrGraph [(Int, PortId NodeId)]
takeOutEdges src = do
h <- get
let (removed_edges, edges_out') = first (fromMaybe []) $ M.updateLookupWithKey
(\_ _ -> Nothing) src (edges_out h)
let edges_in' = foldl removeFromInMap (edges_in h) removed_edges
put h {edges_in=edges_in', edges_out=edges_out'}
pure removed_edges
where
removeFromInMap :: M.Map NodeId [(PortId NodeId, Int)] -> (Int, PortId NodeId) -> M.Map NodeId [(PortId NodeId, Int)]
removeFromInMap eis (outport, Port tgt inport) = M.alter (\(Just es) -> Just $ removeFromInList es (Port src outport, inport)) tgt eis

removeFromInList:: [(PortId NodeId, Int)] -> (PortId NodeId, Int) -> [(PortId NodeId, Int)]
removeFromInList [] _ = error "In-edge not found"
removeFromInList (e:es) e' | e==e' = es
removeFromInList ((_, inport):_) (_,inport') | inport == inport' = error "Wrong in-edge"
removeFromInList (e:es) r = e:(removeFromInList es r)

serialize :: HugrGraph -> Hugr Int
serialize hugr = renameAndSort (execState (for_ orderEdges addOrderEdge) hugr)
where
Expand Down Expand Up @@ -141,7 +236,7 @@ renameAndSort hugr@(HugrGraph {root, first_children=fc, nodes, parents}) = Hugr
nodeStackAndIndices = let just_root = (B0 :< (root, nodes M.! root), M.singleton root 0)
init = foldl addNode just_root (first_children root)
in foldl addNode init (M.keys parents)

addNode :: StackAndIndices -> NodeId -> StackAndIndices
addNode ins n = case M.lookup n (snd ins) of
(Just _) -> ins
Expand Down
4 changes: 3 additions & 1 deletion brat/brat.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ test-suite tests
Test.Elaboration,
Test.Failure,
Test.Graph,
Test.HugrGraph,
Test.Libs,
Test.Parsing,
Test.Naming,
Expand All @@ -175,7 +176,8 @@ test-suite tests
Test.TypeArith,
Test.Util

build-depends: base <5,
build-depends: aeson,
base <5,
brat,
tasty,
tasty-hunit,
Expand Down
3 changes: 3 additions & 0 deletions brat/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Test.Graph
import Test.Compile.Hugr
import Test.Elaboration
import Test.Failure
import Test.HugrGraph
import Test.Libs
import Test.Naming
import Test.Parsing
Expand Down Expand Up @@ -64,6 +65,7 @@ main = do
parsingTests <- getParsingTests
compilationTests <- setupCompilationTests
graphTests <- getGraphTests
spliceTests <- getSpliceTests
let coroTests = testGroup "coroutine"
[testCase "coroT1" $ assertChecking coroT1
,testCase "coroT2" $ assertCheckingFail "Typechecking blocked on" coroT2
Expand All @@ -82,4 +84,5 @@ main = do
,compilationTests
,typeArithTests
,coroTests
,spliceTests
]
75 changes: 75 additions & 0 deletions brat/test/Test/HugrGraph.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
module Test.HugrGraph(getSpliceTests) where

import Brat.Naming as N
import Data.HugrGraph as H
import Data.Hugr

import Control.Monad.State (State, execState, get, runState)
import Data.Aeson (encode)
import Data.Functor ((<&>))
import Data.Maybe (isJust, isNothing)
import Data.List (find)
import qualified Data.ByteString.Lazy as BS
import System.Directory (createDirectoryIfMissing)
import System.FilePath
import Test.Tasty
import Test.Tasty.HUnit

prefix = "test/hugr"
outputDir = prefix </> "output"

addNode :: String -> NodeId -> HugrOp -> State HugrGraph NodeId
addNode nam parent op = do
name <- H.freshNode parent nam
H.setOp name op
pure name

getSpliceTests :: IO TestTree
getSpliceTests = do
createDirectoryIfMissing True outputDir
pure $ testGroup "splice" [testSplice False, testSplice True]

testSplice :: Bool -> TestTree
testSplice inline = testCaseInfo name $ do
let (h, holeId) = host
let outPrefix = outputDir </> name
BS.writeFile (outPrefix ++ "_host.json") (encode $ H.serialize h)
BS.writeFile (outPrefix ++ "_insertee.json") (encode $ H.serialize dfgHugr)
let spliced = H.splice h holeId dfgHugr
let resHugr@(Hugr (ns, _)) = H.serialize $ if inline
then execState (inlineDFG holeId) spliced else spliced
let outFile = outPrefix ++ "_result.json"
BS.writeFile outFile $ encode resHugr
assertBool "Should be no holes now" $ isNothing $ find (isJust . isHole) $ snd <$> ns
-- if inline, we should assert there's no DFG either
pure $ "Written to " ++ outFile ++ " pending validation"
where
name = if inline then "inline" else "noinline"
host :: (HugrGraph, NodeId)
host = swap $ flip runState (H.new N.root "root" rootDefn) $ do
root <- get <&> H.root
input <- addNode "inp" root (OpIn (InputNode tys []))
output <- addNode "out" root (OpOut (OutputNode tys []))
setFirstChildren root [input, output]
hole <- addNode "hole" root (OpCustom $ holeOp 0 tq_ty)
H.addEdge (Port input 0, Port hole 0)
H.addEdge (Port input 1, Port hole 1)
H.addEdge (Port hole 0, Port output 0)
H.addEdge (Port hole 1, Port output 1)
pure hole
dfgHugr :: HugrGraph
dfgHugr = flip execState (H.new N.root "root" rootDfg) $ do
root <- get <&> H.root
input <- addNode "inp" root (OpIn (InputNode tys []))
output <- addNode "out" root (OpOut (OutputNode tys []))
setFirstChildren root [input, output]
gate <- addNode "gate" root (OpCustom $ CustomOp "tket" "CX" tq_ty [])
H.addEdge (Port input 0, Port gate 0)
H.addEdge (Port input 1, Port gate 1)
H.addEdge (Port gate 0, Port output 0)
H.addEdge (Port gate 1, Port output 1)
swap (x,y) = (y,x)
tys = [HTQubit, HTQubit]
tq_ty = FunctionType tys tys bratExts
rootDefn = OpDefn $ FuncDefn "main" (PolyFuncType [] tq_ty) []
rootDfg = OpDFG $ DFG tq_ty []
28 changes: 15 additions & 13 deletions brat/tools/validate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@ declare -a FAILED_TEST_MSGS
UNEXPECTED_PASSES=
NUM_FAILURES=0

for json in $(find test/compilation/output -maxdepth 1 -name "*.json"); do
echo Validating "$json"
RESULT=$(cat "$json" | hugr_validator 2>&1)
if [ $? -ne 0 ]; then
FAILED_TEST_NAMES[NUM_FAILURES]=$json
FAILED_TEST_MSGS[NUM_FAILURES]=$RESULT
NUM_FAILURES=$((NUM_FAILURES + 1))
fi
done
for dir in test/compilation/output test/hugr/output; do
for json in $(find $dir -maxdepth 1 -name "*.json"); do
echo Validating "$json"
RESULT=$(cat "$json" | hugr_validator 2>&1)
if [ $? -ne 0 ]; then
FAILED_TEST_NAMES[NUM_FAILURES]=$json
FAILED_TEST_MSGS[NUM_FAILURES]=$RESULT
NUM_FAILURES=$((NUM_FAILURES + 1))
fi
done

for invalid_json in $(find test/compilation/output -maxdepth 1 -name "*.json.invalid"); do
if (hugr_validator < $invalid_json 2>/dev/null > /dev/null); then
UNEXPECTED_PASSES="$UNEXPECTED_PASSES $invalid_json"
fi
for invalid_json in $(find test/compilation/output -maxdepth 1 -name "*.json.invalid"); do
if (hugr_validator < $invalid_json 2>/dev/null > /dev/null); then
UNEXPECTED_PASSES="$UNEXPECTED_PASSES $invalid_json"
fi
done
done

RED='\033[0;31m'
Expand Down