{-# OPTIONS_GHC -Wall #-} module Dvda.Algorithm.FunGraph ( FunGraph(..) , Node(..) , toFunGraph ) where import Control.Applicative ( (<$>) ) import Data.Foldable ( Foldable ) import qualified Data.Foldable as F import qualified Data.Graph as Graph import qualified Data.HashSet as HS import Data.Traversable ( Traversable ) import Dvda.Expr import Dvda.Algorithm.Reify ( ReifyGraph(..), Node(..), reifyGraph ) data FunGraph f g a = FunGraph { fgInputs :: f Sym , fgOutputs :: g Node , fgReified :: [(Node, GExpr a Node)] , fgTopSort :: [(Node, GExpr a Node)] } -- | find any symbols which are parents of outputs, but are not supplied by the user detectMissingInputs :: Foldable f => f (Expr a) -> [(Node, GExpr a Node)] -> [Sym] detectMissingInputs exprs gr = HS.toList $ HS.difference allGraphInputs allUserInputs where allUserInputs = let f (ESym name) acc = name : acc f _ _ = error $ "detectMissingInputs given non-ESym input" -- \"" ++ show e ++ "\"" in HS.fromList $ F.foldr f [] exprs allGraphInputs = let f (_, GSym name) acc = name : acc f _ acc = acc in HS.fromList $ foldr f [] gr -- | if the same input symbol (like ESym "x") is given at two different places throw an exception findConflictingInputs :: Foldable f => f Sym -> [Sym] findConflictingInputs syms = HS.toList redundant where redundant = snd $ F.foldl f (HS.empty, HS.empty) syms where f (knownExprs, redundantExprs) s | HS.member s knownExprs = (knownExprs, HS.insert s redundantExprs) | otherwise = (HS.insert s knownExprs, redundantExprs) -- | Take inputs and outputs and traverse the outputs reifying all expressions -- and creating a hashmap of StableNames. Once the hashmap is created, -- lookup the provided inputs and return a FunGraph which contains an -- expression graph, input/output indices, and other useful functions. -- StableNames may be non-deterministic so this function may return graphs -- with greater or fewer CSE's eliminated. -- If CSE is then performed on the graph, the result is deterministic. toFunGraph :: (Functor f, Foldable f, Traversable g) => f (Expr a) -> g (Expr a) -> IO (FunGraph f g a) toFunGraph inputExprs outputExprs = do -- reify the outputs (ReifyGraph rgr, outputIndices) <- reifyGraph outputExprs let userInputSyms = fmap f inputExprs where f (ESym s) = s f _ = error $ "ERROR: toFunGraph given non-ESym input" -- \"" ++ show x ++ "\"" fg = FunGraph { fgInputs = userInputSyms , fgOutputs = outputIndices , fgReified = reverse rgr , fgTopSort = topSort } -- make sure all the inputs are symbolic, and find their indices in the Expr graph (gr, lookupVertex, lookupKey) = Graph.graphFromEdges $ map (\(k,gexpr) -> (gexpr, k, F.toList gexpr)) rgr lookupG k = (\(g,_,_) -> g) <$> lookupVertex <$> lookupKey k topSort = map lookup' $ reverse $ map ((\(_,k,_) -> k) . lookupVertex) $ Graph.topSort gr lookup' k = case lookupG k of Nothing -> error "DVDA internal error" Just g -> (k,g) return $ case (detectMissingInputs inputExprs rgr, findConflictingInputs userInputSyms) of ([],[]) -> fg (xs,[]) -> error $ "toFunGraph found inputs that were not provided by the user: " ++ show xs ( _,xs) -> error $ "toFunGraph found conflicting inputs: " ++ show xs