-- | Converting DDC expressions to and from Combinator Normal Form.
module DDC.Core.Flow.Transform.Rates.Combinators
        ( Fun(..), Bind(..), ABind(..), SBind(..), Scalar(..)
        , Program(..)
        , CName(..)
        , lookupA, lookupS, lookupB
        , envOfBind
        , freeOfBind, cnameOfBind
        , outputsOfCluster, inputsOfCluster
        , seriesInputsOfCluster
        ) where

import DDC.Base.Pretty
import DDC.Core.Flow.Exp (ExpF)

import Data.Maybe (catMaybes)
import Data.List (nub)
import Prelude hiding ((<$>))


-----------------------------------
-- = Combinator normal form.


-- | Worker function.
--   A worker may only reference scalars in the environment, never arrays.
--   It carries the expression of the function together with the list of
--   free scalars referenced inside that expression.
--   The expression must be a function from scalar to scalar.
data Fun s a
        = Fun ExpF [s]
        deriving (Show)


-- | Array, scalar and external bindings.
--
--   * Array bindings are those whose value is an array, such as map or filter.
--
--   * Scalar bindings have scalar values; currently the only one is fold.
--
--   * External bindings are expressions that cannot be converted to
--     primitive combinators. They take a single expression that computes
--     the output, together with the lists of free scalar and array inputs.
data Bind s a
        = ABind a (ABind s a)
        | SBind s (SBind s a)
        | Ext
        { _beOut :: CName s a
        , _beExp :: ExpF
        , _beIns :: ([s], [a])
        }
        deriving (Show)


-- | An array-valued binding.
data ABind s a
        -- | map_n    :: (a_1 ... a_n -> b) -> Array a_1 ... Array a_n -> Array b
        = MapN (Fun s a) [a]

        -- | filter   :: (a -> Bool) -> Array a -> Array a
        | Filter (Fun s a) a

        -- | generate :: Nat -> (Nat -> a) -> Array a
        | Generate (Scalar s a) (Fun s a)

        -- | gather   :: Array a -> Array Nat -> Array a
        | Gather a a

        -- | cross    :: Array a -> Array b -> Array (a, b)
        | Cross a a
        deriving (Show)


-- | Scalars can either be a literal such as "0", or a named scalar reference.
--   If it's not a named scalar reference, we need to keep the expression so we can reconstruct it later.
--   (We do not have array literals, so this is only necessary for scalars)
data Scalar s a
 = Scalar ExpF (Maybe s)
 deriving Show


-- | A scalar-valued binding
data SBind s a
 -- | fold :: (a -> a -> a) -> a -> Array a -> a
 = Fold (Fun s a) (Scalar s a) a
 deriving Show


-- | An entire program/function to find a fusion clustering for
data Program s a
 = Program
 { -- | Scalar and array names that are printed as the program's parameters.
   _ins   :: ([s], [a])
   -- | The bindings of the program, in order.
 , _binds :: [Bind s a]
   -- | Scalar and array names that the program returns.
 , _outs  :: ([s], [a])
 }
 deriving Show


-- | Name of a combinator.
--   This will also be the name of the corresponding node of the graph.
data CName s a
 = NameScalar s
 | NameArray  a
 deriving (Eq, Ord, Show)


-- | Find the first array binding in the program with the given array name.
lookupA :: Eq a => Program s a -> a -> Maybe (ABind s a)
lookupA p a
 = go $ _binds p
 where
  go [] = Nothing
  go (ABind a' b : _)
   | a == a'
   = Just b
  go (_ : bs) = go bs


-- | Find the first scalar binding in the program with the given scalar name.
lookupS :: Eq s => Program s a -> s -> Maybe (SBind s a)
lookupS p s
 = go $ _binds p
 where
  go [] = Nothing
  go (SBind s' b : _)
   | s == s'
   = Just b
  go (_ : bs) = go bs


-- | Find the first binding (array, scalar or external) in the program
--   that binds the given combinator name.
lookupB :: (Eq s, Eq a) => Program s a -> CName s a -> Maybe (Bind s a)
lookupB p nm
 = go (_binds p)
 where
  go [] = Nothing
  go (b@(ABind a _) : _)
   | NameArray a' <- nm
   , a == a'
   = Just b
  go (b@(SBind s _) : _)
   | NameScalar s' <- nm
   , s == s'
   = Just b
  go (b@(Ext nm' _ _) : _)
   | nm == nm'
   = Just b
  go (_ : bs) = go bs


-- | The name bound by a binding, as a pair of (scalar names, array names).
--   Exactly one of the two lists holds a single name.
envOfBind :: Bind s a -> ([s], [a])
envOfBind (SBind s _)              = ([s], [])
envOfBind (ABind a _)              = ([], [a])
envOfBind (Ext (NameScalar s) _ _) = ([s], [])
envOfBind (Ext (NameArray a)  _ _) = ([], [a])


-- | The combinator name bound by a binding.
cnameOfBind :: Bind s a -> CName s a
cnameOfBind (SBind s _) = NameScalar s
cnameOfBind (ABind a _) = NameArray a
cnameOfBind (Ext n _ _) = n


-- | All names a binding mentions: the free scalars of its worker function,
--   any named scalar argument, and its array arguments.
freeOfBind :: Bind s a -> [CName s a]
freeOfBind b
 = case b of
    SBind _ (Fold fun i a)
     -> ffun fun ++ fscalar i ++ [fa a]
    ABind _ (MapN fun as)
     -> ffun fun ++ map fa as
    ABind _ (Filter fun a)
     -> ffun fun ++ [fa a]
    ABind _ (Generate s fun)
     -> ffun fun ++ fscalar s
    ABind _ (Gather x y)
     -> [fa x, fa y]
    ABind _ (Cross x y)
     -> [fa x, fa y]
    Ext _ _ (inS,inA)
     -> map fs inS ++ map fa inA
 where
  -- Free scalars recorded on a worker function.
  ffun (Fun _ f) = map fs f

  -- A literal scalar mentions no name; a named one mentions exactly that name.
  fscalar (Scalar _ Nothing)  = []
  fscalar (Scalar _ (Just s)) = [NameScalar s]

  fs = NameScalar
  fa = NameArray


-- | Get inputs that must be converted to series or rate vectors
seriesInputOfBind :: Bind s a -> [a]
seriesInputOfBind b
 = case b of
    SBind _ (Fold _fun _i a)
     -> [a]
    ABind _ (MapN _fun as)
     -> as
    ABind _ (Filter _fun a)
     -> [a]
    ABind _ (Generate _s _fun)
     -> []
    ABind _ (Gather _v ix)
     -- Only the indices array is consumed series-wise.
     -- The vector is random access, so it must not be reported as a
     -- series input: it need not have the same size as the indices.
     -> [ix]
    -- Cross product's first is consumed in series, but second is consumed multiple times,
    -- so only the first is a series input.
    ABind _ (Cross x _y)
     -> [x]
    -- Externals do not require series inputs.
    Ext _ _ (_inS,_inA)
     -> []


-- | For a given program and list of nodes that will be clustered together,
--   find a list of the nodes that are used afterwards.
--   Only these nodes must be made manifest.
--   The output nodes is a subset of the input cluster nodes.
outputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [CName s a]
outputsOfCluster prog cluster
 -- Get all bindings in the program that aren't in this cluster
 = let notin   = filter (not . flip elem cluster . cnameOfBind) (_binds prog)
       -- And find their free variables
       frees   = concatMap freeOfBind notin
       -- Convert the returns of the program to CNames
       (ss,as) = _outs prog
       pouts   = map NameScalar ss ++ map NameArray as
       -- We want to look in both returns and free variables of bindings
       alls    = frees ++ pouts
       -- Now search through and find those in the cluster
       found   = filter (flip elem cluster) alls
   in  nub found


-- | For a given program and list of nodes that will be clustered together,
--   find a list of the nodes that are used as inputs.
--   The input nodes will not mention any of the cluster nodes.
inputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [CName s a]
inputsOfCluster prog cluster
 = nub external
 where
  -- Bindings of the cluster's nodes; names without a binding are dropped.
  clusterBinds = catMaybes (map (lookupB prog) cluster)

  -- Everything those bindings mention.
  mentioned    = concatMap freeOfBind clusterBinds

  -- Keep only the names that come from outside the cluster.
  external     = filter (`notElem` cluster) mentioned


-- | For a given program and list of nodes that will be clustered together,
--   find a list of the inputs that need to be converted to series.
--   If the cluster is correct, these should all be the same size.
seriesInputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [a]
seriesInputsOfCluster prog cluster
 = nub external
 where
  -- Bindings of the cluster's nodes; names without a binding are dropped.
  clusterBinds = catMaybes (map (lookupB prog) cluster)

  -- Arrays those bindings consume as series.
  consumed     = concatMap seriesInputOfBind clusterBinds

  -- Keep only the arrays that are produced outside the cluster.
  external     = filter (\arr -> NameArray arr `notElem` cluster) consumed


-----------------------------------
-- == Pretty printing
-- This is just the notation I used in the prototype.
-- | A worker function is shown as the brace-enclosed list of its free scalars;
--   the worker expression itself is not printed.
instance (Pretty s, Pretty a) => Pretty (Fun s a) where
 ppr (Fun _ ss)
  = encloseSep lbrace rbrace space
  $ map ppr ss


-- | A named scalar is shown by its name, a literal scalar as "-".
instance (Pretty s, Pretty a) => Pretty (Scalar s a) where
 ppr (Scalar _ Nothing)  = text "-"
 ppr (Scalar _ (Just s)) = ppr s


-- | Bindings are shown as @name = combinator args@.
instance (Pretty s, Pretty a) => Pretty (Bind s a) where
 ppr (SBind n (Fold f i a))
  = bind (ppr n) "reduce"   (ppr f <+> ppr i <+> ppr a)
 ppr (ABind n (MapN f as))
  = bind (ppr n) "mapN"     (ppr f <+> hsep (map ppr as))
 ppr (ABind n (Filter f a))
  = bind (ppr n) "filter"   (ppr f <+> ppr a)
 ppr (ABind n (Gather a b))
  = bind (ppr n) "gather"   (ppr a <+> ppr b)
 ppr (ABind n (Generate sz f))
  = bind (ppr n) "generate" (ppr sz <+> ppr f)
 ppr (ABind n (Cross a b))
  = bind (ppr n) "cross"    (ppr a <+> ppr b)
 ppr (Ext out _ ins)
  = bind (ppr out) "external" (binds ins)
  where
   -- Scalar inputs in braces, followed by the array inputs.
   -- The arrays are separated by spaces (as for mapN) so that adjacent
   -- names do not run together.
   binds (ss,as)
    = encloseSep lbrace rbrace space (map ppr ss)
    <+> hsep (map ppr as)


-- | Lay out a binding as @nm = com args@, nesting the right-hand side by 4.
bind :: Doc -> String -> Doc -> Doc
bind nm com args
 = nm <+> nest 4 (equals <+> text com <+> args)


-- | A program is shown as its parameters, then its bindings, then its returns.
instance (Pretty s, Pretty a) => Pretty (Program s a) where
 ppr (Program ins binds outs)
  =   params
  <$> vcat (map ppr binds)
  <$> returns
  where
   params
    =   vcat (map (\i -> text "param scalar" <+> ppr i) (fst ins))
    <$> vcat (map (\i -> text "param array"  <+> ppr i) (snd ins))

   returns
    =   vcat (map (\i -> text "return" <+> ppr i) (fst outs))
    <$> vcat (map (\i -> text "return" <+> ppr i) (snd outs))


-- | Scalar names are wrapped in braces; array names are shown bare.
instance (Pretty s, Pretty a) => Pretty (CName s a) where
 ppr (NameScalar s) = text "{" <> ppr s <> text "}"
 ppr (NameArray  a) = ppr a