module DDC.Core.Flow.Transform.Rates.Combinators
( Fun(..), Bind(..), ABind(..), SBind(..), Scalar(..)
, Program(..)
, CName(..)
, lookupA, lookupS, lookupB
, envOfBind
, freeOfBind, cnameOfBind
, outputsOfCluster, inputsOfCluster
, seriesInputsOfCluster
)
where
import DDC.Base.Pretty
import DDC.Core.Flow.Exp (ExpF)
import Data.Maybe (catMaybes)
import Data.List (nub)
import Prelude hiding ((<$>))
-- | A worker function used inside a combinator: an 'ExpF' expression
--   together with the list of scalar names it refers to.
--
--   'freeOfBind' reports every name in the list as a free scalar, and the
--   pretty-printer shows only this list, enclosed in braces.
--   The @a@ type parameter is not used by the representation.
data Fun s a
 = Fun ExpF [s]
 deriving Show
-- | One binding of a 'Program'.
data Bind s a
 -- An array binding: the name of the array and the array combinator
 -- that defines it.
 = ABind a (ABind s a)
 -- A scalar binding: the name of the scalar and the scalar combinator
 -- that defines it.
 | SBind s (SBind s a)
 -- An external binding. Its expression is not inspected by the helpers
 -- in this module: 'freeOfBind' takes its free names from '_beIns' alone,
 -- and @seriesInputOfBind@ gives it no series inputs.
 | Ext
   { _beOut :: CName s a       -- ^ name of the value the binding defines
   , _beExp :: ExpF            -- ^ the external expression
   , _beIns :: ([s], [a])      -- ^ scalar and array names the binding reads
   }
 deriving Show
-- | Array combinators. The name of the array being defined is held by
--   the enclosing 'ABind' constructor of 'Bind'.
data ABind s a
 = MapN     (Fun s a) [a]           -- ^ @mapN f xs..@: worker and its argument arrays
 | Filter   (Fun s a) a             -- ^ @filter f xs@: worker (presumably the predicate) and input array
 | Generate (Scalar s a) (Fun s a)  -- ^ @generate sz f@: presumably a size argument, and a worker
 | Gather   a a                     -- ^ @gather v ix@: presumably the source array, then the index array
 | Cross    a a                     -- ^ @cross x y@: two input arrays
 deriving Show
-- | A scalar argument to a combinator: an 'ExpF' expression and,
--   optionally, the name of a scalar binding it refers to.
--
--   Only the optional name is used in this module: 'freeOfBind' reports
--   it as a free scalar, and the pretty-printer shows it (or @-@ when it
--   is absent).
data Scalar s a
 = Scalar ExpF (Maybe s)
 deriving Show
-- | Scalar combinators. The name of the scalar being defined is held by
--   the enclosing 'SBind' constructor of 'Bind'.
data SBind s a
 = Fold (Fun s a) (Scalar s a) a    -- ^ printed as @reduce f z xs@: worker, presumably the initial value, and input array
 deriving Show
-- | A whole program in combinator form: its parameters, its bindings
--   and its results.
data Program s a
 = Program
 { _ins   :: ([s], [a])   -- ^ scalar and array parameters
 , _binds :: [Bind s a]   -- ^ bindings; the lookup functions take the first match
 , _outs  :: ([s], [a])   -- ^ scalar and array names returned by the program
 }
 deriving Show
-- | A name that refers to either a scalar or an array, for use wherever
--   both kinds of binding can appear. Scalar names are pretty-printed
--   enclosed in braces.
data CName s a
 = NameScalar s
 | NameArray a
 deriving (Eq, Ord, Show)
-- | Find the combinator that defines the array named @a@.
--
--   Only 'ABind' entries of '_binds' are searched, in order; the first
--   one with a matching name wins. 'Ext' bindings are never returned.
lookupA :: Eq a => Program s a -> a -> Maybe (ABind s a)
lookupA p a
 = case [ abind | ABind a' abind <- _binds p, a == a' ] of
        (abind : _) -> Just abind
        []          -> Nothing
-- | Find the combinator that defines the scalar named @s@.
--
--   Only 'SBind' entries of '_binds' are searched, in order; the first
--   one with a matching name wins. 'Ext' bindings are never returned.
lookupS :: Eq s => Program s a -> s -> Maybe (SBind s a)
lookupS p s
 = case [ sbind | SBind s' sbind <- _binds p, s == s' ] of
        (sbind : _) -> Just sbind
        []          -> Nothing
-- | Find the first binding in '_binds' whose defined name (as given by
--   'cnameOfBind') equals @nm@. Unlike 'lookupA' and 'lookupS', this also
--   finds 'Ext' bindings, and it returns the whole 'Bind'.
lookupB :: (Eq s, Eq a) => Program s a -> CName s a -> Maybe (Bind s a)
lookupB p nm
 = case [ b | b <- _binds p, cnameOfBind b == nm ] of
        (b : _) -> Just b
        []      -> Nothing
-- | The name a binding defines, returned as a (scalars, arrays) pair.
--   One component is a singleton list and the other is empty.
envOfBind :: Bind s a -> ([s], [a])
envOfBind b
 = case cnameOfBind b of
        NameScalar s -> ([s], [])
        NameArray  a -> ([],  [a])
-- | The name a binding defines, tagged as a scalar or an array name.
cnameOfBind :: Bind s a -> CName s a
cnameOfBind b
 = case b of
        ABind arr _          -> NameArray  arr
        SBind sca _          -> NameScalar sca
        Ext { _beOut = out } -> out
-- | Every name a binding reads. This covers:
--
--   * the scalars its worker function captures;
--   * the scalar named inside a 'Scalar' argument, if any;
--   * its array arguments.
--
--   For 'Ext' bindings the result is exactly the names recorded in
--   '_beIns'. The list may contain duplicates.
freeOfBind :: Bind s a -> [CName s a]
freeOfBind b
 = case b of
    SBind _ (Fold f z xs) -> funFrees f ++ scalarFrees z ++ [NameArray xs]
    ABind _ ab            -> arrayFrees ab
    Ext _ _ (ss, xs)      -> map NameScalar ss ++ map NameArray xs
 where
  -- Free names of each array combinator, worker first.
  arrayFrees ab
   = case ab of
      MapN f xss   -> funFrees f ++ map NameArray xss
      Filter f xs  -> funFrees f ++ [NameArray xs]
      Generate n f -> funFrees f ++ scalarFrees n
      Gather x y   -> [NameArray x, NameArray y]
      Cross x y    -> [NameArray x, NameArray y]

  -- Scalars captured by a worker function.
  funFrees (Fun _ ss) = map NameScalar ss

  -- A scalar argument contributes its name only when it has one.
  scalarFrees (Scalar _ ms) = maybe [] (\s -> [NameScalar s]) ms
-- | Arrays that a binding takes as series inputs.
--
--   'Generate' has no array inputs, and 'Ext' bindings are given none.
--   Every other combinator reports its array arguments in order.
--
--   NOTE(review): both arguments of 'Gather' and 'Cross' are reported
--   here. If gather reads its first argument by index (rather than in
--   order), or cross re-traverses its second argument, those arrays are
--   not consumed as series and arguably should be excluded. Confirm this
--   against how the result of 'seriesInputsOfCluster' is consumed.
seriesInputOfBind :: Bind s a -> [a]
seriesInputOfBind (SBind _ (Fold _ _ xs)) = [xs]
seriesInputOfBind (ABind _ ab)
 = case ab of
    MapN _ xss   -> xss
    Filter _ xs  -> [xs]
    Generate _ _ -> []
    Gather v ix  -> [v, ix]
    Cross x y    -> [x, y]
seriesInputOfBind (Ext _ _ (_, _)) = []
-- | Names defined inside @cluster@ that are needed outside it.
--
--   A name counts if it is read by some binding whose own name is not in
--   the cluster, or if it is one of the program's outputs. Duplicates are
--   removed, keeping the first occurrence.
outputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [CName s a]
outputsOfCluster prog cluster
 = nub [ n | n <- usedOutside ++ programOuts, n `elem` cluster ]
 where
  inCluster b  = cnameOfBind b `elem` cluster

  usedOutside  = concat [ freeOfBind b | b <- _binds prog, not (inCluster b) ]

  (outS, outA) = _outs prog
  programOuts  = map NameScalar outS ++ map NameArray outA
-- | Names read by the bindings of @cluster@ that are not themselves
--   members of the cluster, with duplicates removed.
--
--   Cluster names that have no binding in the program ('lookupB' returns
--   'Nothing') contribute nothing.
inputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [CName s a]
inputsOfCluster prog cluster
 = nub [ n | b <- clusterBinds, n <- freeOfBind b, n `notElem` cluster ]
 where
  clusterBinds = catMaybes (map (lookupB prog) cluster)
-- | Arrays taken as series inputs (per @seriesInputOfBind@) by the
--   bindings of @cluster@. Arrays defined inside the cluster are excluded,
--   and duplicates are removed.
seriesInputsOfCluster :: (Eq s, Eq a) => Program s a -> [CName s a] -> [a]
seriesInputsOfCluster prog cluster
 = nub (filter definedOutside consumed)
 where
  clusterBinds       = catMaybes (map (lookupB prog) cluster)
  consumed           = concatMap seriesInputOfBind clusterBinds
  definedOutside arr = NameArray arr `notElem` cluster
-- | A worker function prints as its captured scalars in braces; its
--   expression is not shown.
instance (Pretty s, Pretty a) => Pretty (Fun s a) where
 ppr fun
  = encloseSep lbrace rbrace space (map ppr (capturedOf fun))
  where
   capturedOf (Fun _ ss) = ss
-- | A scalar argument prints as the name it refers to, or @-@ when it has
--   no name. Its expression is not shown.
instance (Pretty s, Pretty a) => Pretty (Scalar s a) where
 ppr (Scalar _ name)
  = maybe (text "-") ppr name
-- | Each binding prints as @name = combinator args@ via 'bind'.
--   External bindings print their scalar inputs in braces, followed by
--   their array inputs.
instance (Pretty s, Pretty a) => Pretty (Bind s a) where
 ppr (SBind n (Fold f i a))
  = bind (ppr n) "reduce" (ppr f <+> ppr i <+> ppr a)
 ppr (ABind n (MapN f as))
  = bind (ppr n) "mapN" (ppr f <+> hsep (map ppr as))
 ppr (ABind n (Filter f a))
  = bind (ppr n) "filter" (ppr f <+> ppr a)
 ppr (ABind n (Gather a b))
  = bind (ppr n) "gather" (ppr a <+> ppr b)
 ppr (ABind n (Generate sz f))
  = bind (ppr n) "generate" (ppr sz <+> ppr f)
 ppr (ABind n (Cross a b))
  = bind (ppr n) "cross" (ppr a <+> ppr b)
 ppr (Ext out _ ins)
  = bind (ppr out) "external" (binds ins)
  where
   -- Array inputs are separated by spaces ('hsep'), as in the mapN case
   -- above. 'hcat' would run adjacent names together, so inputs xs and
   -- ys would print as "xsys".
   binds (ss,as)
    = encloseSep lbrace rbrace space (map ppr ss) <+> hsep (map ppr as)
-- | Lay out a binding as @name = combinator args@. The right-hand side,
--   starting at the @=@, is wrapped in @nest 4@.
bind :: Doc -> String -> Doc -> Doc
bind name combinator args
 = name <+> nest 4 rhs
 where
  rhs = equals <+> text combinator <+> args
-- | A program prints in three sections, joined with the pretty-printer's
--   vertical composition '<$>':
--
--   * one @param scalar@ line per scalar input, then one @param array@
--     line per array input;
--   * the bindings, one after another;
--   * one @return@ line per output, scalars first, then arrays.
instance (Pretty s, Pretty a) => Pretty (Program s a) where
 ppr (Program ins bs outs)
  = params <$> body <$> returns
  where
   (inS,  inA)  = ins
   (outS, outA) = outs

   body    = vcat (map ppr bs)
   params  = each "param scalar" inS  <$> each "param array" inA
   returns = each "return"       outS <$> each "return"      outA

   -- One line per name, each prefixed by the given keyword.
   each keyword names = vcat [ text keyword <+> ppr x | x <- names ]
-- | Scalar names print enclosed in braces; array names print bare.
instance (Pretty s, Pretty a) => Pretty (CName s a) where
 ppr name
  = case name of
     NameScalar s -> text "{" <> ppr s <> text "}"
     NameArray  a -> ppr a