module Dvda.FunGraph ( FunGraph
, ToFunGraph
, NumT
, (:*)(..)
, MVS(..)
, toFunGraph
, countNodes
, fgInputs
, fgOutputs
, fgLookupGExpr
, fgReified
, topSort
, nodelistToFunGraph
, exprsToFunGraph
) where
import Control.Applicative
import Data.Foldable ( Foldable )
import qualified Data.Foldable as F
import qualified Data.Graph as Graph
import Data.Hashable ( Hashable )
import qualified Data.HashSet as HS
import Data.Traversable ( Traversable )
import qualified Data.Traversable as T
import Dvda.Expr
import Dvda.Reify ( ReifyGraph(..), reifyGraphs )
-- | A function graph: a "Data.Graph" graph over reified expression nodes,
--   bundled with its inputs, outputs and the lookup functions needed to move
--   between node keys, graph vertices and the nodes themselves.
--   Values are assembled by 'nodelistToFunGraph'.
data FunGraph a = FunGraph
  { fgGraph :: Graph.Graph
    -- ^ graph produced by 'Graph.graphFromEdges' from the node list
  , fgInputs :: [MVS (GExpr a Int)]
    -- ^ inputs, grouped as scalars \/ vectors \/ matrices
  , fgOutputs :: [MVS Int]
    -- ^ outputs, grouped as scalars \/ vectors \/ matrices
    --   (presumably node keys, as returned by @reifyGraphs@ -- TODO confirm)
  , fgReified :: [(Int, GExpr a Int)]
    -- ^ the @(key, node)@ list the graph was built from
  , fgLookupGExpr :: Int -> Maybe (GExpr a Int)
    -- ^ node stored under a key, 'Nothing' when the key is not in the graph
  , fgVertexFromKey :: Int -> Maybe Int
    -- ^ graph vertex for a key, 'Nothing' when the key is not in the graph
  , fgNodeFromVertex :: Int -> (GExpr a Int, Int, [Int])
    -- ^ @(node, key, keys returned by getParents)@ for a graph vertex
  }
-- | Multi-line rendering: a @FunGraph@ header followed by the inputs, the
--   outputs and the underlying graph, each under its own label.
instance Show a => Show (FunGraph a) where
  show fg = concat
    [ "FunGraph\ninputs:\n"
    , show (fgInputs fg)
    , "\noutputs:\n"
    , show (fgOutputs fg)
    , "\ngraph:\n"
    , show (fgGraph fg)
    ]
-- | A matrix (list of rows), a vector, or a single scalar of @a@.
data MVS a = Mat [[a]] | Vec [a] | Sca a deriving Show

-- | Applies the function to every element, keeping the shape.
instance Functor MVS where
  fmap g mvs = case mvs of
    Sca x -> Sca (g x)
    Vec xs -> Vec (map g xs)
    Mat rows -> Mat (map (map g) rows)

-- | Folds over the elements; matrix elements are visited row by row.
instance Foldable MVS where
  foldr g z (Sca x) = g x z
  foldr g z (Vec xs) = foldr g z xs
  foldr g z (Mat rows) = foldr (\row acc -> foldr g acc row) z rows

-- | Runs an effectful function over every element, keeping the shape.
instance Traversable MVS where
  traverse g mvs = case mvs of
    Sca x -> fmap Sca (g x)
    Vec xs -> fmap Vec (T.traverse g xs)
    Mat rows -> fmap Mat (T.traverse (T.traverse g) rows)
-- | Types that can be flattened into a list of scalar \/ vector \/ matrix
--   groups of expressions.
class ToFunGraph a where
  -- | The type parameter of the 'Expr's that 'toMVSList' produces.
  type NumT a

  -- | Flatten a value into its list of 'MVS' groups.
  toMVSList :: a -> [MVS (Expr (NumT a))]
-- | A single expression becomes one scalar group.
instance ToFunGraph (Expr a) where
  type NumT (Expr a) = a
  toMVSList scalar = pure (Sca scalar)
-- | A list of expressions becomes one vector group.
instance ToFunGraph [Expr a] where
  -- same type as @NumT (Expr a)@, written out directly
  type NumT [Expr a] = a
  toMVSList vector = pure (Vec vector)
-- | A list of lists of expressions becomes one matrix group.
instance ToFunGraph [[Expr a]] where
  -- same type as @NumT [Expr a]@, written out directly
  type NumT [[Expr a]] = a
  toMVSList matrix = pure (Mat matrix)
-- | Right-associative pairing operator used to chain several inputs or
--   outputs into one value, e.g. @x :* ys :* zs@.
data (:*) a b = a :* b
  deriving Show

infixr 6 :*
-- | A pair flattens to the groups of its left component followed by the
--   groups of its right component; both sides must share the same 'NumT'.
instance (ToFunGraph a, ToFunGraph b, NumT a ~ NumT b) => ToFunGraph (a :* b) where
  type NumT (a :* b) = NumT a
  toMVSList (lhs :* rhs) = concat [toMVSList lhs, toMVSList rhs]
-- | Symbols that appear as 'GSym' nodes in the reified graph but were not
--   listed among the user-supplied inputs.
--
--   * @exprs@ - the user-supplied inputs; every element must be an 'ESym'
--   * @gr@    - the reified @(key, node)@ list
--
--   Calls 'error' when a user-supplied input is not an 'ESym'; the message
--   shows the offending expression.
detectMissingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [(Int,GExpr a Int)] -> [GExpr a Int]
detectMissingInputs exprs gr = HS.toList $ HS.difference allGraphInputs allUserInputs
  where
    allUserInputs = HS.fromList $ map toGSym (concatMap F.toList exprs)

    toGSym (ESym name) = GSym name
    -- report the bad input itself, not the symbols collected so far
    toGSym e = error $ "detectMissingInputs given non-ESym input \"" ++ show e ++ "\""

    -- the pattern in the generator silently skips every non-GSym node
    allGraphInputs = HS.fromList [GSym name | (_, GSym name) <- gr]
-- | User-supplied inputs that occur more than once.
--
--   @exprs@ are the user-supplied inputs; every element must be an 'ESym'.
--   Each repeated symbol is returned once, in unspecified (hash set) order.
--
--   Calls 'error' when an input is not an 'ESym'; the message shows the
--   offending expression.
findConflictingInputs :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [Expr a]
findConflictingInputs exprs = HS.toList redundant
  where
    -- strict left fold: the (seen, duplicates) pair is evaluated at every
    -- step instead of building a chain of thunks over the whole input list
    redundant = snd $ F.foldl' step (HS.empty, HS.empty) (concatMap F.toList exprs)

    step (knownExprs, redundantExprs) expr@(ESym _)
      | HS.member expr knownExprs = (knownExprs, HS.insert expr redundantExprs)
      | otherwise = (HS.insert expr knownExprs, redundantExprs)
    step _ e = error $ "findConflictingInputs saw non-ESym input \"" ++ show e ++ "\""
-- | Build a 'FunGraph' from user-facing inputs and outputs.
--   Both arguments are flattened with 'toMVSList' and handed to
--   'mvsToFunGraph'; they must agree on the numeric type @a@.
toFunGraph :: (Eq a, Hashable a, Show a, ToFunGraph b, ToFunGraph c, NumT b ~ a, NumT c ~ a)
              => b -> c -> IO (FunGraph a)
toFunGraph inputs outputs = mvsToFunGraph inputGroups outputGroups
  where
    inputGroups = toMVSList inputs
    outputGroups = toMVSList outputs
-- | Reify the output expressions into a node list and wrap it, together with
--   the given inputs, in a 'FunGraph'.
--
--   * @inputMVSExprs@  - user-supplied inputs; every element must be an 'ESym'
--   * @outputMVSExprs@ - output expressions to reify
--
--   The inputs are validated when this action runs, not when the returned
--   graph is first inspected. It calls 'error' if the graph contains a symbol
--   that is not among the inputs, or if the same input is given more than
--   once; when both problems are present the duplicate is the one reported.
--   A non-'ESym' input also ends in 'error'.
mvsToFunGraph :: (Eq a, Hashable a, Show a) => [MVS (Expr a)] -> [MVS (Expr a)] -> IO (FunGraph a)
mvsToFunGraph inputMVSExprs outputMVSExprs = do
  (ReifyGraph rgr, outputMVSIndices) <- reifyGraphs outputMVSExprs
  let toGSym (ESym name) = GSym name
      toGSym x = error $ "ERROR: mvsToFunGraph given non-ESym input \"" ++ show x ++ "\""
      inputMVSGExprs = map (fmap toGSym) inputMVSExprs
      fg = nodelistToFunGraph rgr inputMVSGExprs outputMVSIndices
  -- The case picks the IO action to run, so its scrutinee is forced here.
  -- Wrapping the whole case in `return` would defer the checks until some
  -- later consumer happened to force the graph.
  case (detectMissingInputs inputMVSExprs rgr, findConflictingInputs inputMVSExprs) of
    ([], []) -> return fg
    (xs, []) -> error $ "mvsToFunGraph found inputs that were not provided by the user: " ++ show xs
    (_, xs) -> error $ "mvsToFunGraph found idential inputs set more than once: " ++ show xs
-- | Assemble a 'FunGraph' from an already reified node list.
--
--   * @rgr@              - @(key, node)@ pairs
--   * @inputMVSIndices@  - inputs, stored unchanged in 'fgInputs'
--   * @outputMVSIndices@ - outputs, stored unchanged in 'fgOutputs'
--
--   The graph and both lookup functions come from 'Graph.graphFromEdges';
--   each node's outgoing edges are the keys returned by @getParents@.
nodelistToFunGraph :: [(Int,GExpr a Int)] -> [MVS (GExpr a Int)] -> [MVS Int] -> FunGraph a
nodelistToFunGraph rgr inputMVSIndices outputMVSIndices = FunGraph
  { fgGraph = depGraph
  , fgInputs = inputMVSIndices
  , fgOutputs = outputMVSIndices
  , fgReified = rgr
  , fgLookupGExpr = gexprFromKey
  , fgVertexFromKey = vertexFromKey
  , fgNodeFromVertex = nodeFromVertex
  }
  where
    (depGraph, nodeFromVertex, vertexFromKey) =
      Graph.graphFromEdges [ (gexpr, key, getParents gexpr) | (key, gexpr) <- rgr ]

    -- key -> vertex -> node; Nothing when the key is not part of the graph
    gexprFromKey key = case vertexFromKey key of
      Nothing -> Nothing
      Just vertex -> let (gexpr, _, _) = nodeFromVertex vertex in Just gexpr
-- | Number of vertices in the underlying graph.
countNodes :: FunGraph a -> Int
countNodes fg = length (Graph.vertices (fgGraph fg))
-- | Node keys listed in the order 'Graph.topSort' gives for the underlying
--   graph. Each vertex is translated back to its key with 'fgNodeFromVertex'.
--   NOTE(review): edges are built from @getParents@, so which end of a
--   dependency comes first depends on what that function returns -- confirm.
topSort :: FunGraph a -> [Int]
topSort fg =
  [ key
  | vertex <- Graph.topSort (fgGraph fg)
  , let (_, key, _) = fgNodeFromVertex fg vertex
  ]
-- | Build a 'FunGraph' from output expressions alone. The inputs are the
--   'ESym' symbols gathered from the outputs with @foldExpr@, wrapped back
--   into 'ESym' and passed as one vector input to 'toFunGraph'. Their order
--   is whatever the hash set yields.
exprsToFunGraph :: (Eq a, Show a, Hashable a) => [Expr a] -> IO (FunGraph a)
exprsToFunGraph outputs = toFunGraph inputs outputs
  where
    inputs = map ESym (getSyms outputs)

    getSyms :: [Expr a] -> [Sym]
    getSyms exprs = HS.toList (foldr (flip (foldExpr collect)) HS.empty exprs)
      where
        -- presumably foldExpr takes the accumulator before the expression,
        -- hence the flip to fit foldr's (element, accumulator) order;
        -- verify against Dvda.Expr
        collect (ESym s) syms = HS.insert s syms
        collect _ syms = syms