module Dvda.Algorithm.Construct
( Algorithm(..)
, AlgOp(..)
, Node(..)
, InputIdx(..)
, OutputIdx(..)
, constructAlgorithm
, squashWorkVector
) where
import qualified Data.Foldable as F
import Data.Maybe ( fromMaybe )
import qualified Data.Traversable as T
import qualified Data.IntMap as IM
import qualified Data.Vector as V
import qualified Data.HashMap.Lazy as HM
import Dvda.Expr
import Dvda.Algorithm.FunGraph ( FunGraph(..), Node(..), toFunGraph )
-- | Position of an input, as assigned by 'constructAlgorithm' from the
-- element's index in the function graph's input vector.
newtype InputIdx = InputIdx Int
  deriving (Show)
-- | Position of an output, as assigned by 'constructAlgorithm' from the
-- element's index in the function graph's output vector.
newtype OutputIdx = OutputIdx Int
  deriving (Show)
-- | One step of an 'Algorithm'.
--
-- Every constructor carries the 'Node' it refers to.  After
-- 'squashWorkVector' has run, that 'Node' holds a reused slot number rather
-- than the original graph node id.
data AlgOp a
  = InputOp Node InputIdx
    -- ^ Associates a node with an input position (presumably "load input
    -- into the node's slot" -- confirm against the code that executes ops).
  | OutputOp Node OutputIdx
    -- ^ Associates a node with an output position (presumably "store the
    -- node's slot into the output" -- confirm against the executor).
  | NormalOp Node (GExpr a Node)
    -- ^ Defines a node by an expression whose operands are other nodes.
  deriving (Show)
-- | A linear list of operations together with its dimensions.
data Algorithm a = Algorithm
  { algInDims   :: Int
    -- ^ Number of inputs.
  , algOutDims  :: Int
    -- ^ Number of outputs.
  , algOps      :: [AlgOp a]
    -- ^ The operations, in order.
  , algWorkSize :: Int
    -- ^ Work-vector size, as computed from the node indices used by 'algOps'.
  }
-- | A slot number handed out from the free-slot pool while squashing the
-- work vector; distinct from 'Node' so the two numberings cannot be mixed up.
newtype LiveNode = LiveNode Int
  deriving (Show)
-- | An 'IM.IntMap' keyed by the 'Int' wrapped inside a 'Node'.
newtype NodeMap a = NodeMap (IM.IntMap a)
  deriving (Show)
-- | The 'NodeMap' with no entries.
nmEmpty :: NodeMap a
nmEmpty =
  NodeMap IM.empty
-- | Insert a value for a node; when the node is already present the stored
-- value becomes @combine new old@ (the semantics of 'IM.insertWith').
nmInsertWith :: (a -> a -> a) -> Node -> a -> NodeMap a -> NodeMap a
nmInsertWith combine (Node key) new (NodeMap table) =
  NodeMap $ IM.insertWith combine key new table
-- | Look up the value stored for a node, if any.
nmLookup :: Node -> NodeMap a -> Maybe a
nmLookup (Node key) (NodeMap table) =
  IM.lookup key table
-- | Insert a value for a node, replacing any existing entry.
nmInsert :: Node -> a -> NodeMap a -> NodeMap a
nmInsert (Node key) new (NodeMap table) =
  NodeMap (IM.insert key new table)
-- | Worker for 'squashWorkVector': rewrites every op so that its 'Node'
-- refers to a slot taken from a pool of free slots, returning a slot to the
-- pool once the last read of the node it holds has been processed.
--
-- Arguments, in order:
--
-- * remaining read count per original node (each read decrements it);
-- * mapping from original node to the slot currently holding it;
-- * pool of free slots (the head is handed out next);
-- * the ops still to be rewritten.
--
-- Calls 'error' when an 'InputOp' node is assigned twice, when a node is
-- read that has no slot or no remaining read count, or when an 'InputOp'
-- is reached with an empty pool.
squashWorkVec' :: NodeMap Int -> NodeMap LiveNode -> [LiveNode] -> [AlgOp a] -> [AlgOp a]
-- An input takes the next free slot and records it as the node's home.
squashWorkVec' accessMap liveMap0 (LiveNode pool0:pools) (InputOp k inIdx:xs) =
  InputOp (Node pool0) inIdx : squashWorkVec' accessMap liveMap pools xs
  where
    -- 'err' is only forced if the key was already present.
    liveMap = nmInsertWith err k (LiveNode pool0) liveMap0
    err = error "SSA node written to more than once"
-- An output reads a node: emit it against the node's slot and count the read.
squashWorkVec' accessMap0 liveMap0 pool0 (OutputOp k outIdx:xs) =
  OutputOp (Node lk) outIdx : squashWorkVec' accessMap liveMap0 pool xs
  where
    (LiveNode lk) = fromMaybe noLiveErr (nmLookup k liveMap0)
      where noLiveErr = error "OutputOp couldn't find node in live map"
    (accessMap, pool) = case nmLookup k accessMap0 of
      Just 0 -> error "squashWorkVec': accessed something with 0 references"
      -- Last read: the slot goes back to the front of the pool.
      Just 1 -> (nmInsert k 0 accessMap0, LiveNode lk:pool0)
      Just n -> (nmInsert k (n - 1) accessMap0, pool0)
      Nothing -> error "squashWorkVec': node not in access map"
-- A normal op reads each operand, then takes a slot for its own result.
squashWorkVec' accessMap0 liveMap0 pool0 (NormalOp k gexpr0:xs) =
  NormalOp (Node retLiveK) gexpr : squashWorkVec' accessMap liveMap pool xs
  where
    -- Count one read of operand 'depk' and replace it by its slot.
    decrement (am0, p0) depk = case nmLookup depk am0 of
      Just 0 -> error "squashWorkVec': accessed something with 0 references"
      Just 1 -> ((nmInsert depk 0 am0, LiveNode lk:p0), Node lk)
      Just n -> ((nmInsert depk (n - 1) am0, p0), Node lk)
      Nothing -> error "squashWorkVec': node not in access map"
      where
        LiveNode lk = fromMaybe (error "depsLiveKs missing") (nmLookup depk liveMap0)
    -- The result slot is taken from the pool *after* operands were released,
    -- so it may reuse the slot of an operand whose last read is this op.
    ((accessMap, LiveNode retLiveK:pool), gexpr) =
      T.mapAccumL decrement (accessMap0, pool0) gexpr0
    liveMap = nmInsert k (LiveNode retLiveK) liveMap0
squashWorkVec' _ _ _ [] = []
squashWorkVec' _ _ [] _ = error "squashWorkVec': empty pool"
-- | Renumber the nodes of an algorithm so that slots are reused once the
-- node they hold is no longer read, and recompute 'algWorkSize' from the
-- renumbered ops.  'algInDims' and 'algOutDims' are carried over unchanged.
squashWorkVector :: Algorithm a -> Algorithm a
squashWorkVector alg =
  alg { algOps = squashedOps
      , algWorkSize = workVectorSize squashedOps
      }
  where
    originalOps = algOps alg

    -- Record one more read of the given node.
    bump node = nmInsertWith (+) node (1 :: Int)

    -- Reads come from outputs and from the operands of normal ops;
    -- inputs read nothing.
    tally counts op = case op of
      InputOp _ _ -> counts
      OutputOp node _ -> bump node counts
      NormalOp _ gexpr -> F.foldr bump counts gexpr

    useCounts = F.foldl' tally nmEmpty originalOps

    -- The pool of free slots starts as 0, 1, 2, ...
    squashedOps =
      squashWorkVec' useCounts nmEmpty (map LiveNode [0..]) originalOps
-- | Turn a list of @(node, expression)@ pairs into a list of operations.
--
-- A 'GSym' node becomes an 'InputOp' using the index found for its symbol in
-- the input vector; any other node becomes a 'NormalOp'.  Each op is
-- followed by one 'OutputOp' for every output position that refers to that
-- node, so several outputs sharing a node are all emitted (in the order
-- they appear in the output vector).
--
-- Calls 'error' if a 'GSym' is not among the inputs.
graphToAlg :: [(Node,GExpr a Node)] -> V.Vector (Sym,InputIdx) -> V.Vector (Node,OutputIdx)
           -> [AlgOp a]
graphToAlg rgr0 inSyms outIdxs = concatMap emit rgr0
  where
    inSymMap = HM.fromList (F.toList inSyms :: [(Sym,InputIdx)])

    -- Keep *all* output indices per node.  A right fold with
    -- @insertWith (++)@ prepends earlier entries, preserving vector order.
    outIdxMap :: IM.IntMap [OutputIdx]
    outIdxMap =
      F.foldr (\(Node k, outIdx) -> IM.insertWith (++) k [outIdx]) IM.empty outIdxs

    outputsOf k@(Node k') = map (OutputOp k) (IM.findWithDefault [] k' outIdxMap)

    emit (k, GSym s) = case HM.lookup s inSymMap of
      Nothing -> error "toAlg: symbolic is not in inputs"
      Just inIdx -> InputOp k inIdx : outputsOf k
    emit (k, x) = NormalOp k x : outputsOf k
-- | Size of the work vector needed by a list of ops: one more than the
-- largest node index any op refers to, and 0 for an empty list.
workVectorSize :: [AlgOp a] -> Int
workVectorSize ops = 1 + F.foldl' max (-1) (map slotOf ops)
  where
    -- Seeding the maximum with -1 makes the empty case come out as 0 and
    -- a single slot 0 come out as 1.
    slotOf (NormalOp (Node m) _) = m
    slotOf (InputOp (Node m) _) = m
    slotOf (OutputOp (Node m) _) = m
-- | Build an 'Algorithm' from vectors of input and output expressions.
--
-- Calls 'toFunGraph', tags each graph input and output with its position in
-- its vector, and converts the reified graph with 'graphToAlg'.  The work
-- size is computed from the ops as produced; 'squashWorkVector' is not
-- applied here.
constructAlgorithm :: V.Vector (Expr a) -> V.Vector (Expr a) -> IO (Algorithm a)
constructAlgorithm inputVecs outputVecs = do
  fg <- toFunGraph inputVecs outputVecs
  let taggedInputs = V.imap (\i sym -> (sym, InputIdx i)) (fgInputs fg)
      taggedOutputs = V.imap (\i node -> (node, OutputIdx i)) (fgOutputs fg)
      ops = graphToAlg (fgReified fg) taggedInputs taggedOutputs
  return Algorithm
    { algInDims = V.length taggedInputs
    , algOutDims = V.length taggedOutputs
    , algOps = ops
    , algWorkSize = workVectorSize ops
    }