module DDC.Core.Flow.Transform.Rates.SeriesOfVector
(seriesOfVectorModule
,seriesOfVectorFunction)
where
import DDC.Core.Collect
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp as DDC
import DDC.Core.Flow.Transform.Rates.Combinators as Com
import DDC.Core.Flow.Transform.Rates.CnfFromExp
import DDC.Core.Flow.Transform.Rates.Fail
import DDC.Core.Flow.Transform.Rates.Graph
import qualified DDC.Core.Flow.Transform.Rates.SizeInference as SI
import DDC.Core.Flow.Transform.Rates.Clusters
import DDC.Core.Module
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate
import qualified DDC.Type.Env as Env
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.List (partition)
-- | Run the series-of-vector transform over every top-level let of a module.
--
--   The module body is stripped of annotations and split into its leading
--   lets and its final expression. Each let group is rewritten by
--   'seriesOfVectorLets', the rewritten lets are re-attached to the same
--   final expression, and the new body is re-annotated with @()@.
--   The failures reported by every let group are concatenated and returned
--   alongside the new module.
seriesOfVectorModule :: ModuleF -> (ModuleF, [(Name,Fail)])
seriesOfVectorModule mm
 = (mm { moduleBody = newBody }, concat errss)
 where
  plainBody          = deannotate (const Nothing) (moduleBody mm)
  (topLets, finalX)  = splitXLets plainBody
  (letss, errss)     = unzip (map seriesOfVectorLets topLets)
  newBody            = annotate () (xLets (concat letss) finalX)
-- | Apply 'seriesOfVectorFunction' to one top-level let group.
--
--   * A non-recursive let of a named binder becomes the bindings split off by
--     the transform, followed by the rewritten binding itself. Each failure
--     is tagged with the binder's name.
--   * In a recursive group every right-hand side is rewritten, and the
--     split-off bindings are placed in the same 'LRec' ahead of them.
--     Failures from a recursive group are discarded, not reported.
--   * Any other let is returned unchanged.
seriesOfVectorLets :: LetsF -> ([LetsF], [(Name,Fail)])
seriesOfVectorLets ll
 = case ll of
    LLet b@(BName n _) x
     -> case seriesOfVectorFunction x of
         (x', procs, errs)
          -> ( map (\(pb, px) -> LLet pb px) procs ++ [LLet b x']
             , zip (repeat n) errs )

    LRec bxs
     -> case unzip bxs of
         (bs, xs)
          -> case unzip3 (map seriesOfVectorFunction xs) of
              (xs', procss, _)
               -> ( [LRec (concat procss ++ zip bs xs')]
                  , [] )

    _
     -> ([ll], [])
-- | Try to fuse the vector operations in the body of one function.
--
--   Returns the rewritten function, the extra bindings split off from it
--   (see 'reconstruct'), and any failures.
--
--   * If 'cnfOfExp' rejects the function, it is returned unchanged together
--     with a 'FailCannotConvert' failure.
--   * If size inference ('SI.generate') gives no result, the function is
--     returned unchanged and no failure is reported.
--   * Otherwise the combinator graph is clustered and the function is rebuilt
--     from those clusters.
seriesOfVectorFunction :: ExpF -> (ExpF, [(BindF,ExpF)], [Fail])
seriesOfVectorFunction fun
 = either cannotConvert fromProgram (cnfOfExp fun)
 where
  cannotConvert err
   = (fun, [], [FailCannotConvert err])

  fromProgram prog
   = maybe (fun, [], []) (rebuild prog) (SI.generate prog)

  rebuild prog (env, _)
   = let graph    = graphOfBinds prog env
         clusters = cluster graph (SI.parents prog env)
         (re, ls) = reconstruct fun prog env clusters
     in  (re, ls, [])
-- | Rebuild a function from the clusters chosen for its combinator program.
--
--   The function's outer lambdas and its final result expression are kept.
--   Its original top-level lets are used only to look up binder types. In the
--   result they are replaced by the lets that 'mkLets' generates for each
--   cluster, in cluster order. Those lets are then passed through
--   'extractProcs', and the process bindings it splits off are returned
--   alongside the rebuilt function.
reconstruct
 :: ExpF
 -> Program Name Name
 -> SI.Env Name
 -> [[CName Name Name]]
 -> (ExpF, [(BindF, ExpF)])
reconstruct fun prog env clusters
 = (makeXLamFlags params (xLets newLets resultX), procs)
 where
  (params, body)     = takeXLamFlags_safe fun
  (oldLets, resultX) = splitXLets body
  types              = takeTypes (concatMap valwitBindsOfLets oldLets ++ map snd params)
  clusterLets        = concatMap letsOfCluster clusters
  (newLets, procs)   = extractProcs clusterLets params

  -- Lets for one cluster: every binding of the program whose name is in the
  -- cluster (kept in program order), flagged with whether it is one of the
  -- cluster's outputs.
  letsOfCluster c
   = mkLets types env (seriesInputsOfCluster prog c)
       [ (b, cnameOfBind b `elem` outputs)
       | b <- _binds prog
       , cnameOfBind b `elem` c ]
   where
    outputs = outputsOfCluster prog c
-- | Split the worker lambdas of @runProcess@ applications out of a list of
--   lets, so that each becomes a binding of its own.
--
--   Consider a named let @nm@ whose right-hand side is a @runProcess@
--   application. That application may also sit inside the final lambda of one
--   or more @rateVecsOfVectors@ applications. The final (lambda) argument of
--   @runProcess@ is replaced by a new name, @NameVarMod nm "process"@. That
--   name is applied to the in-scope type variables and then value variables
--   that the lambda mentions. The new binding is returned in the second
--   component of the result. It pairs that name with the original lambda,
--   abstracted over the same variables.
--
--   @env@ lists the binders in scope as @(isTypeBinder, binder)@. A 'True'
--   flag means the binder is checked against the free type variables of the
--   lambda; 'False' means it is checked against the free value variables.
--   Each named let is added to the scope of the lets that follow it. The
--   lambda binders of an enclosing @rateVecsOfVectors@ are in scope for the
--   @runProcess@ inside it.
--
--   Lets that are not a named 'LLet' are passed through unchanged.
extractProcs :: [LetsF] -> [(Bool, DDC.Bind Name)] -> ([LetsF], [(BindF,ExpF)])
extractProcs lets env
 = go lets env
 where
  -- Walk the lets in order, growing the scope as named lets are passed.
  go [] _
   = ([], [])
  go (l:ls) e
   | LLet b x <- l
   , BName nm _ <- b
   = let this = go1 b nm x e
         rest = go ls ((False,b) : e)
     in  this `mappend` rest
   | otherwise
   = ([l],[]) `mappend` go ls e

  -- Rewrite the single named let @b = x@ in scope @e@.
  -- Always yields exactly one let in the first component.
  go1 b nm x e
   -- rateVecsOfVectors: descend into its final lambda, with that lambda's
   -- binders added to the scope, and rebuild the application around the
   -- rewritten body.
   | Just (op, args) <- takeXApps x
   , XVar (UPrim (NameOpSeries (OpSeriesRateVecsOfVectors n)) _) <- op
     -- Split off the last argument, which must be the lambda.
   , (xs, [lam]) <- splitAt (length args - 1) args
   , (lams,body) <- takeXLamFlags_safe lam
   , ([LLet n' x'], binds) <- go1 b nm body (lams ++ e)
   = ([LLet n'
         (xApps (xVarOpSeries (OpSeriesRateVecsOfVectors n))
                (xs ++ [makeXLamFlags lams x']))]
     , binds)

   -- runProcess: lift its final lambda out as a separate binding.
   | Just (op, args) <- takeXApps x
   , XVar (UPrim (NameOpSeries OpSeriesRunProcess) _) <- op
     -- Split off the last argument, which must be the process lambda.
   , (xs, [lam]) <- splitAt (length args - 1) args
   = let fsX = freeX Env.empty lam
         fsT = freeT Env.empty lam
         -- Is this in-scope binder free in the lambda? Type binders are
         -- checked against the free type variables, value binders against
         -- the free value variables.
         isMentioned (ty,bo)
          | Just bo' <- takeSubstBoundOfBind bo
          = if ty
             then Set.member bo' fsT
             else Set.member bo' fsX
          | otherwise
          = False
         os  = filter isMentioned e
         ss  = takeSubstBoundsOfBinds . map snd
         osT = filter fst os
         osX = filter (not . fst) os
         nm' = NameVarMod nm "process"
         -- Replace the lambda by the new name applied to the captured type
         -- variables, then the captured value variables.
         x'  = xApps op (xs ++
                 [xApps (XVar $ UName nm')
                        ( map (XType . TVar) (ss osT)
                       ++ map XVar (ss osX))])
         -- The lifted binding abstracts the lambda over the same variables,
         -- type variables first.
         p'  = makeXLamFlags (osT ++ osX) lam
     -- NOTE(review): the lifted binder is given type @tBot kData@; presumably
     -- a placeholder for a later pass to fill in -- confirm.
     in  ([LLet b x'], [(BName nm' (tBot kData), p')])

   | otherwise
   = ([LLet b x], [])
-- | Generate the Core Flow lets for one cluster of combinator bindings.
--
--   Each binding is paired with a flag that 'process' treats as "this array
--   binding is an output of the cluster": it allocates and fills a vector for
--   it.
--
--   * A cluster containing an external ('Ext') binding must consist of that
--     binding alone. Its expression is bound unchanged under its own name,
--     with the type looked up in @types@.
--   * Any other non-empty cluster is passed to 'process'. Each binding is
--     tagged with its name and wrapped in 'Left' (scalar) or 'Right' (array).
--   * An empty cluster, or an external clustered with other bindings, is an
--     internal error in the clustering algorithm.
mkLets :: Map Name TypeF -> SI.Env Name -> [Name] -> [(Com.Bind Name Name, Bool)] -> [LetsF]
mkLets types env arrIns bs
 | null bs
 = error ("ddc-core-flow:DDC.Core.Flow.Transform.Rates.SeriesOfVector impossible\n" ++
          "a cluster was created with no bindings.\n" ++
          "this means there must be a bug in the clustering algorithm.\n" ++
          show bs)
 | any (isExt . fst) bs
 = externalLets
 | otherwise
 = process types env arrIns (map tagBind bs)
 where
  -- A lone external binding is bound as-is.
  externalLets
   = case bs of
      [(Ext (NameArray  b) xx _, _)] -> bindAs b xx
      [(Ext (NameScalar b) xx _, _)] -> bindAs b xx
      _ -> error ("ddc-core-flow:DDC.Core.Flow.Transform.Rates.SeriesOfVector impossible\n" ++
                  "an external node has been clustered with another node.\n" ++
                  "this means there must be a bug in the clustering algorithm.\n" ++
                  show bs)

  bindAs b xx = [LLet (BName b (types Map.! b)) xx]

  isExt Ext{} = True
  isExt _     = False

  tagBind (SBind s b, out) = ((s, Left b), out)
  tagBind (ABind a b, out) = ((a, Right b), out)
  tagBind (Ext{}, _)       = error "ddc-core-flow.mkLets: impossible!"
-- | Generate the lets for one cluster of scalar/array bindings, fused into a
--   single @runProcess@.
--
--   Each element of @bs@ is @((name, binding), isOutput)@. The binding is a
--   scalar binding ('Left') or an array binding ('Right'). The flag marks
--   array bindings whose result vector must be allocated and filled.
--   @arrIns@ are the array inputs that the cluster reads as series.
--
--   The result consists of:
--
--   * pre-lets ('getPre'): a @ref@ for each fold, created with the @z@
--     expression of the fold's 'Scalar' argument, and a new vector for each
--     output array binding;
--   * one unit-typed let, named after the first binding with modifier
--     @"runproc"@. Its right-hand side is built from the outside in by
--     'getRateVecs', 'getGenerates', 'runProcs', the series views of the
--     inputs ('getInSeries') and the per-binding processes ('mkProcs');
--   * post-lets ('getPost'): each fold's result read back out of its @ref@.
--
--   Uses 'head' on @bs@, so @bs@ must be non-empty.
process :: Map Name TypeF
        -> SI.Env Name
        -> [Name]
        -> [((Name, Either (SBind Name Name) (ABind Name Name)), Bool)]
        -> [LetsF]
process types env arrIns bs
 = let  pres  = concatMap getPre bs
        mid   = getRateVecs arrIns
              $ getGenerates bs
              $ runProcs
              $ xLets (map getInSeries arrIns)
              $ mkProcs bs
        posts = concatMap getPost bs
   in   pres ++ [LLet (BName (NameVarMod outname "runproc") tUnit) mid] ++ posts
 where
  -- Lets that must be bound before the process runs: a ref for each fold,
  -- and an output vector (of length 'allocSize') for each output array.
  getPre b
   | ((s, Left (Fold _ (Scalar z _) _)), _) <- b
   = [LLet (BName (NameVarMod s "ref") $ tRef $ tyOf s) (xNew (tyOf s) z)]
   | ((v, Right _), True) <- b
   = [LLet (BName v $ tVector $ sctyOf v) (xNewVector (sctyOf v) allocSize)]
   | _ <- b
   = []

  -- For every Generate binding, wrap the expression in a rateVecsOfVectors
  -- over zero vectors. It is given the generate's size and binds the
  -- generate's rate variable (klokV n).
  -- NOTE(review): presumably this is how a rate is introduced for a
  -- generated series -- confirm.
  getGenerates [] innerX
   = innerX
  getGenerates (b:rest) innerX
   | ((n, Right (Generate (Scalar sz _) _)), _) <- b
   , rest' <- getGenerates rest innerX
   = xApps (xVarOpSeries (OpSeriesRateVecsOfVectors 0))
           [ XType tUnit
           , sz
           , XLAM (BName (klokV n) kRate) rest' ]
   | otherwise
   = getGenerates rest innerX

  -- Group the array inputs by their inferred size type (SI.lookupV). Each
  -- group is converted by one rateVecsOfVectors, which binds a shared rate
  -- variable (klokV of the group's first input) and one "rv" variable per
  -- input in the group.
  -- NOTE(review): an input with no inferred size type is skipped here, but
  -- getInSeries still refers to its "rv" variable -- confirm this cannot
  -- happen.
  getRateVecs [] innerX
   = innerX
  getRateVecs (a:as) innerX
   | Just t <- SI.lookupV env a
   , (same,others) <- partition ((==Just t) . SI.lookupV env) as
   , rest' <- getRateVecs others innerX
   , these <- (a : same)
   , nums <- length these
   , op <- OpSeriesRateVecsOfVectors nums
   , flags <- map (\n -> (False, BName (NameVarMod n "rv") (tRateVec (klokT a) (sctyOf n)))) these
   = xApps (xVarOpSeries op)
           ( map xsctyOf these ++ [XType tUnit]
           ++ map var these
           ++[ makeXLamFlags ((True, BName (klokV a) kRate) : flags)
                             rest' ]
           )
   | otherwise
   = getRateVecs as innerX

  -- View the rate vector of input n as a series "s" of this process.
  getInSeries n
   = LLet (BName (NameVarMod n "s") (tSeries procT (klokT n) (sctyOf n)))
          (xApps (xVarOpSeries OpSeriesSeriesOfRateVec)
                 [ procX, klokX n, xsctyOf n, var $ NameVarMod n "rv" ] )

  -- After the process: read each fold's value out of its ref, binding it
  -- under the fold's own name.
  getPost b
   | ((s, Left (Fold _ (Scalar _ _) _)), _) <- b
   = [ LLet (BName s $ sctyOf s) (xRead (tyOf s) (var $ NameVarMod s "ref")) ]
   | _ <- b
   = []

  -- Apply runProcess at processRate to a lambda that binds the process type
  -- variable (procName, of kind kProc) and then a unit argument.
  runProcs body
   = let flags = [ (True, BName procName kProc)
                 , (False, BNone tUnit) ]
     in  xApps (xVarOpSeries OpSeriesRunProcess)
               ([XType processRate, makeXLamFlags flags body])

  -- Chain one let per binding, in order, ending with the join of all result
  -- processes.
  --  * fold:  a reduce process "proc" that reduces the input series of ain
  --           into the fold's ref;
  --  * array: a series "s" built by map, filter (a map to flags, then pack
  --           inside a mkSel that binds the filtered rate and selector),
  --           generate or gather; for an output binding, also a fill
  --           process "proc" that writes that series into its vector.
  mkProcs (b:rs)
   | ((s, Left (Fold (Fun xf _) (Scalar xs _) ain)), _) <- b
   = let rest = mkProcs rs
     in  XLet (LLet (BName (NameVarMod s "proc") $ tProcess procT $ klokT ain)
                  $ xApps (xVarOpSeries OpSeriesReduce)
                      [ procX, klokX ain, xtyOf s, var (NameVarMod s "ref")
                      , xf, xs, var (NameVarMod ain "s")])
              rest

   | ((n, Right abind), out) <- b
   = let rest   = mkProcs rs
         n'proc = NameVarMod n "proc"
         n's    = NameVarMod n "s"
         n'flag = NameVarMod n "flags"
         n'sel  = NameVarMod n "sel"

         llet nm t x1
          = XLet (LLet (BName nm t) x1)

         -- Continuation after this binding's series: add a fill process if
         -- the binding is a cluster output.
         go | out
            = llet n'proc (tProcess procT $ klokT n)
                ( xApps (xVarOpSeries OpSeriesFill) [procX, klokX n, xsctyOf n, var $ n, var $ n's] )
                rest
            | otherwise
            = rest

     in  case abind of
          MapN (Fun xf _) ains
           -> llet n's (tSeries procT (klokT n) $ sctyOf n)
                ( xApps (xVarOpSeries (OpSeriesMap (length ains)))
                    ([procX, klokX n] ++ (map xsctyOf ains) ++ [xsctyOf n, xf]
                     ++ map (var . flip NameVarMod "s") ains) )
                go

          Filter (Fun xf _) ain
           -> llet n'flag (tSeries procT (klokT ain) tBool)
                ( xApps (xVarOpSeries (OpSeriesMap 1))
                    [ procX, klokX ain, xsctyOf n, XType tBool, xf, var $ NameVarMod ain "s"] )
              $ xApps (xVarOpSeries $ OpSeriesMkSel 1)
                  [ procX, klokX ain, XType processRate, var n'flag
                  , XLAM (BName (klokV n) kRate)
                  $ XLam (BName n'sel (tSel1 procT (klokT ain) (klokT n)))
                  $ llet n's (tSeries procT (klokT n) $ sctyOf n)
                      ( xApps (xVarOpSeries OpSeriesPack)
                          [ procX, klokX ain, klokX n, xsctyOf n
                          , var n'sel, var $ NameVarMod ain "s"] )
                      go ]

          Generate _sz (Fun xf _)
           -> llet n's (tSeries procT (klokT n) $ sctyOf n)
                ( xApps (xVarOpSeries OpSeriesGenerate)
                    [ procX, klokX n, xsctyOf n, xf ])
                go

          Gather v ix
           -> llet n's (tSeries procT (klokT n) $ sctyOf n)
                ( xApps (xVarOpSeries OpSeriesGather)
                    ([ procX, klokX v, klokX ix, xsctyOf v
                     , var $ NameVarMod v "rv", var $ NameVarMod ix "s"]) )
                go

          Cross _a _b
           -> error "ddc-core-flow.process: Cross combinator not implemented yet"

   | otherwise
   = error "ddc-core-flow.process: impossible!"

  -- End of the chain: join all result processes (getProc).
  -- NOTE(review): recomputes `concatMap getProc bs` instead of reusing procs.
  mkProcs []
   = let procs = concatMap getProc bs
     in  case procs of
          (_:_) -> foldl1 mkJoin $ concatMap getProc bs
          []    -> error "ddc-core-flow.process: cluster with no outputs?"

  mkJoin p q
   = xApps (xVarOpSeries OpSeriesJoin) [p, q]

  -- Result processes of the cluster: every fold and every output array
  -- binding, each passed through resizeProc.
  getProc b@((s, Left _), _)
   = [resizeProc b $ var $ NameVarMod s "proc"]
  getProc b@((a, _), True)
   = [resizeProc b $ var $ NameVarMod a "proc"]
  getProc _
   = []

  -- Start at the binding's own name and walk the bindings from last to
  -- first, following each one back to its (first) input. At every filter the
  -- process is wrapped in resizeProc with a resizeSel1 of that filter's
  -- selector; presumably this brings it to the filter's input rate (confirm).
  -- The walk stops at a generate.
  resizeProc b v
   = goResize (fst $ fst b) v (reverse bs)

  goResize _ v []
   = v
  goResize n v (((n',b),_):rest)
   | n == n'
   = case b of
      Left (Fold _ _ ain)
       -> goResize ain v rest
      Right (MapN _ (i:_))
       -> goResize i v rest
      Right (MapN _ [])
       -> error "ddc-core-flow.process: Map with no inputs."
      Right (Filter _ ain)
       -> goResize ain
            ( xApps (xVarOpSeries OpSeriesResizeProc)
                [ procX, klokX n, klokX ain
                , xApps (xVarOpSeries OpSeriesResizeSel1)
                    [ procX, klokX n, klokX ain, klokX n
                    , var $ NameVarMod n "sel"
                    , xApps (xVarOpSeries OpSeriesResizeId)
                        [ procX, klokX n ]
                    ]
                , v ]) rest
      Right (Generate _ _)
       -> v
      Right (Gather _ ain)
       -> goResize ain v rest
      Right (Cross a _b)
       -> goResize a v rest
   | otherwise
   = goResize n v rest

  -- Length for newly allocated output vectors: the length of the first array
  -- input if there is one, otherwise the size of the first generate.
  -- NOTE(review): the error text has no separator between "known" and the
  -- shown inputs.
  allocSize
   | (i:_) <- arrIns
   = xApps (xVarOpVector OpVectorLength) [xsctyOf i, var i]
   | (Scalar sz _:_) <- concatMap findGenerateSize bs
   = sz
   | otherwise
   = error ("ddc-core-flow: allocSize, but no size known" ++ show arrIns ++ "\n" ++ show bs)

  -- Rate the whole process runs at, taken from the first binding: the input
  -- rate for folds, filters and crosses, and the binding's own rate
  -- otherwise.
  processRate
   = bindRate (head bs)

  bindRate b
   = let k = klokT (fst $ fst b)
     in  case snd $ fst b of
          Left (Fold _ _ ain)
           -> klokT ain
          Right (MapN _ _)
           -> k
          Right (Filter _ ain)
           -> klokT ain
          Right (Generate _ _)
           -> k
          Right (Gather _ _)
           -> k
          Right (Cross ain _)
           -> klokT ain

  -- The process type variable (bound by runProcs) and rate variables.
  procName = NameVarMod outname "PROC"
  procT = TVar $ UName $ procName
  procX = XType $ procT
  klokV = getKlok env
  klokT = TVar . UName . klokV
  klokX = XType . klokT

  findGenerateSize ((_, Right (Generate sz _)), _)
   = [sz]
  findGenerateSize _
   = []

  outname
   = fst $ fst $ head bs

  -- Type and expression helpers; tyOf fails if a name has no entry in types.
  tyOf n = types Map.! n
  sctyOf = getScalarType . tyOf
  xtyOf = XType . tyOf
  xsctyOf = XType . sctyOf
  var n = XVar $ UName n
  xVarOpSeries n = XVar (UPrim (NameOpSeries n) (typeOpSeries n))
  xVarOpVector n = XVar (UPrim (NameOpVector n) (typeOpVector n))
-- | Element type of a flow vector type: @Vector sc@ gives @sc@. Any other
--   type is returned unchanged.
getScalarType :: TypeF -> TypeF
getScalarType tt
 | Just (NameTyConFlow TyConFlowVector, [sc]) <- takePrimTyConApps tt
 = sc
 | otherwise
 = tt
-- | Map each named binder ('BName') to its type; other binders are skipped.
--   If a name is bound more than once, 'Map.fromList' keeps the type of the
--   last occurrence.
takeTypes :: [DDC.Bind Name] -> Map Name TypeF
takeTypes binds
 = Map.fromList [ (n, t) | BName n t <- binds ]
-- | Name of the rate ("klok") type variable to use for a name.
--
--   If the size environment has a size type for @n@, the name is derived from
--   that type. A size variable gives its own name, with one @"'"@ modifier
--   per @SI.K'@ wrapper. A cross @SI.TCross ta tb@ gives the name for @ta@
--   modified by @show tb@. Otherwise @n@ itself is used. In both cases the
--   result gets the extra modifier @"k"@.
getKlok :: SI.Env Name -> Name -> Name
getKlok e n
 = NameVarMod base "k"
 where
  base = maybe n nameOfType (SI.lookupV e n)

  nameOfType (SI.TVar kv)      = nameOfVar kv
  nameOfType (SI.TCross ta tb) = NameVarMod (nameOfType ta) (show tb)

  nameOfVar (SI.KV v)  = v
  nameOfVar (SI.K' kv) = NameVarMod (nameOfVar kv) "'"