{-# OPTIONS_GHC -Wall #-}

module Dvda.Algorithm.Construct ( Algorithm(..)
                                , AlgOp(..)
                                , Node(..)
                                , InputIdx(..)
                                , OutputIdx(..)
                                , constructAlgorithm
                                , squashWorkVector
                                ) where

import qualified Data.Foldable as F
import Data.Maybe ( fromMaybe )
import qualified Data.Traversable as T
import qualified Data.IntMap as IM
import qualified Data.Vector as V
import qualified Data.HashMap.Lazy as HM

import Dvda.Expr
import Dvda.Algorithm.FunGraph ( FunGraph(..), Node(..), toFunGraph )

-- | Index into the algorithm's input vector.
newtype InputIdx = InputIdx Int deriving Show

-- | Index into the algorithm's output vector.
newtype OutputIdx = OutputIdx Int deriving Show

-- | One step of an 'Algorithm'.
data AlgOp a = InputOp Node InputIdx          -- ^ node holds the value of an input slot
             | OutputOp Node OutputIdx        -- ^ node's value is emitted to an output slot
             | NormalOp Node (GExpr a Node)   -- ^ node is computed from an expression over other nodes
             deriving Show

-- | A list of ops together with the sizes of the vectors it uses.
data Algorithm a = Algorithm { algInDims :: Int     -- ^ number of input slots
                             , algOutDims :: Int    -- ^ number of output slots
                             , algOps :: [AlgOp a]  -- ^ the ops, processed front to back
                             , algWorkSize :: Int   -- ^ 1 + the largest node index used by 'algOps'
                             }

-- | A register (slot) of the squashed work vector.
newtype LiveNode = LiveNode Int deriving Show

-- | An 'IM.IntMap' keyed by 'Node'.
newtype NodeMap a = NodeMap (IM.IntMap a) deriving Show

nmEmpty :: NodeMap a
nmEmpty = NodeMap IM.empty

nmInsertWith :: (a -> a -> a) -> Node -> a -> NodeMap a -> NodeMap a
nmInsertWith combine (Node key) new (NodeMap m) = NodeMap $ IM.insertWith combine key new m

nmLookup :: Node -> NodeMap a -> Maybe a
nmLookup (Node key) (NodeMap m) = IM.lookup key m

nmInsert :: Node -> a -> NodeMap a -> NodeMap a
nmInsert (Node key) new (NodeMap m) = NodeMap (IM.insert key new m)

-- | Account for one read of SSA node @k@, whose value sits in register @reg@.
-- The node's pending-read count is decremented; if that was its last read,
-- @reg@ is pushed onto the front of the free-register pool.
releaseRef :: Node -> Int -> (NodeMap Int, [LiveNode]) -> (NodeMap Int, [LiveNode])
releaseRef k reg (counts, free) =
  case nmLookup k counts of
    Nothing -> error "squashWorkVec': node not in access map"
    Just 0 -> error "squashWorkVec': accessed something with 0 references"
    Just 1 -> (nmInsert k 0 counts, LiveNode reg : free)
    Just n -> (nmInsert k (n-1) counts, free)

-- | Rewrite an SSA op list so that nodes are replaced by reusable registers.
--
-- Arguments:
--
--   * pending-read counts per SSA node;
--   * the register currently holding each SSA node;
--   * the pool of free registers, always allocated from the front;
--   * the ops still to process.
--
-- When a node's last read is processed its register goes back to the front
-- of the pool. A 'NormalOp' frees its operands before taking its result
-- register, so the result may land in the register of one of its operands.
squashWorkVec' :: NodeMap Int -> NodeMap LiveNode -> [LiveNode] -> [AlgOp a] -> [AlgOp a]
squashWorkVec' counts live0 (LiveNode reg : free) (InputOp k inIdx : rest) =
  InputOp (Node reg) inIdx : squashWorkVec' counts live free rest
  where
    -- the input takes the first free register; writing an SSA node twice is
    -- an error (raised lazily, when the conflicting entry is demanded)
    live = nmInsertWith writtenTwice k (LiveNode reg) live0
    writtenTwice = error "SSA node written to more than once"
squashWorkVec' counts0 live0 free0 (OutputOp k outIdx : rest) =
  OutputOp (Node reg) outIdx : squashWorkVec' counts live0 free rest
  where
    -- emit from whichever register currently holds node k
    LiveNode reg = fromMaybe missing (nmLookup k live0)
      where
        missing = error "OutputOp couldn't find node in live map"
    (counts, free) = releaseRef k reg (counts0, free0)
squashWorkVec' counts0 live0 free0 (NormalOp k gexpr0 : rest) =
  NormalOp (Node dest) gexpr : squashWorkVec' counts live free rest
  where
    -- release every operand first, then take the head of the pool for the result
    ((counts, LiveNode dest : free), gexpr) = T.mapAccumL readDep (counts0, free0) gexpr0
    live = nmInsert k (LiveNode dest) live0
    -- translate one operand node to its register, recording the read
    readDep acc dep = acc' `seq` (acc', Node depReg)
      where
        acc' = releaseRef dep depReg acc
        LiveNode depReg = fromMaybe (error "depsLiveKs missing") (nmLookup dep live0)
squashWorkVec' _ _ _ [] = []
squashWorkVec' _ _ [] _ = error "squashWorkVec': empty pool"

-- | Converts SSA to live variables.
-- This reduces the size of the work vector by re-using dead registers.
-- Does this break if it's called more than once?
-- Maybe these should have different types.
--
-- NOTE(review): it can break. 'squashWorkVec'' assumes every node is written
-- exactly once. In a squashed algorithm a register may be reused by a later
-- 'InputOp', which makes a second pass raise "SSA node written to more than
-- once". Read counts would also be summed over all lifetimes of a register,
-- so registers may never be reclaimed. A distinct type for squashed
-- algorithms would let the compiler rule this out.
squashWorkVector :: Algorithm a -> Algorithm a
squashWorkVector alg =
  alg { algOps = newAlgOps
      , algWorkSize = workVectorSize newAlgOps
      }
  where
    newAlgOps = squashWorkVec' accesses nmEmpty (map LiveNode [0..]) (algOps alg)

    -- pending reads per node: one per 'OutputOp' and one per operand
    -- occurrence inside a 'NormalOp'; 'InputOp's read nothing
    accesses = countAccesses nmEmpty (algOps alg)

    addOne k = nmInsertWith (+) k (1 :: Int)

    countAccesses accMap (InputOp _ _ : xs) = countAccesses accMap xs
    countAccesses accMap (OutputOp k _ : xs) = countAccesses (addOne k accMap) xs
    countAccesses accMap (NormalOp _ gexpr : xs) = countAccesses (F.foldr addOne accMap gexpr) xs
    countAccesses accMap [] = accMap

-- | Linearise a reified graph into a list of 'AlgOp's, preserving graph order.
--
-- Each 'GSym' node becomes an 'InputOp'; its symbol must be in @inSyms@,
-- otherwise this errors. Every other node becomes a 'NormalOp'.
-- Each node is immediately followed by one 'OutputOp' for every output slot
-- in @outIdxs@ that refers to it. A node may feed several output slots,
-- e.g. when the same expression is returned twice.
graphToAlg :: [(Node,GExpr a Node)] -> V.Vector (Sym,InputIdx) -> V.Vector (Node,OutputIdx) -> [AlgOp a]
graphToAlg rgr0 inSyms outIdxs = concatMap toOps rgr0
  where
    inSymMap = HM.fromList (F.toList inSyms :: [(Sym,InputIdx)])

    -- node key -> every output slot it feeds, in ascending slot order.
    -- (IM.fromList would keep only the last slot for a repeated node and
    -- leave the other output slots unwritten.)
    outIdxMap :: IM.IntMap [OutputIdx]
    outIdxMap = IM.fromListWith (flip (++)) [ (k, [outIdx]) | (Node k, outIdx) <- F.toList outIdxs ]

    outputsOf k@(Node k') = map (OutputOp k) (IM.findWithDefault [] k' outIdxMap)

    toOps (k, GSym s) =
      case HM.lookup s inSymMap of
        Nothing -> error "toAlg: symbolic is not in inputs"
        -- sym is an input (and possibly also one or more outputs)
        Just inIdx -> InputOp k inIdx : outputsOf k
    -- ordinary node (possibly also one or more outputs)
    toOps (k, x) = NormalOp k x : outputsOf k

-- | Number of work-vector slots an op list needs.
-- This is one more than the largest node index it mentions, or 0 for an
-- empty list.
workVectorSize :: [AlgOp a] -> Int
workVectorSize = (+ 1) . F.foldl' max (-1) . map opIndex
  where
    opIndex (InputOp (Node m) _) = m
    opIndex (OutputOp (Node m) _) = m
    opIndex (NormalOp (Node m) _) = m

-- | Create an SSA algorithm from a vector of symbolic inputs and outputs.
--
-- Input slot @i@ is the @i@-th entry of the function graph's inputs, and
-- output slot @j@ is the @j@-th entry of its outputs.
constructAlgorithm :: V.Vector (Expr a) -> V.Vector (Expr a) -> IO (Algorithm a)
constructAlgorithm inputVecs outputVecs = do
  fg <- toFunGraph inputVecs outputVecs
  let inputIdxs = V.map (\(k, x) -> (x, InputIdx k)) (V.indexed (fgInputs fg))
      outputIdxs = V.map (\(k, x) -> (x, OutputIdx k)) (V.indexed (fgOutputs fg))
      ops = graphToAlg (fgReified fg) inputIdxs outputIdxs
  return Algorithm { algInDims = V.length inputIdxs
                   , algOutDims = V.length outputIdxs
                   , algOps = ops
                   , algWorkSize = workVectorSize ops
                   }