module E.WorkerWrapper(performWorkWrap,workWrapProgram) where
import Control.Monad.Writer.Strict hiding(Product(..))
import Data.Maybe
import DataConstructors
import E.CPR
import E.E
import E.Program
import E.Traverse
import E.TypeCheck
import E.Values
import GenUtil
import Info.Types
import Name.Id
import Name.Name
import Name.Names
import Stats hiding(null)
import Support.CanType
import Util.SetLike
import qualified E.Demand as Demand
import qualified Info.Info as Info
-- | How a single lambda-bound argument of a function is treated when the
-- function is split into a wrapper and a worker.
data Arg
    = Absent
      -- ^ The argument is not forwarded to the worker.
    | Cons Constructor [(Arg,TVr)]
      -- ^ The argument is taken apart with the given constructor; each field
      -- has its own classification and a binder standing for it.
    | Plain
      -- ^ The argument is forwarded to the worker unchanged.
-- | 'True' exactly for the 'Plain' classification.
isPlain :: Arg -> Bool
isPlain arg = case arg of
    Plain -> True
    _     -> False
-- | Turn a sub-demand into an infinite list of per-field demands: the
-- demands listed by a 'Demand.Product' come first (none for 'Demand.None'),
-- followed by 'Demand.lazy' forever.
fsubs sub = listed ++ repeat Demand.lazy where
    listed = case sub of
        Demand.None       -> []
        Demand.Product xs -> xs
-- | Decide whether a definition can be split into a wrapper and a worker.
--
-- Only 'ELam' right-hand sides are accepted. The leading lambdas are peeled
-- off one at a time, each binder being classified from the demand
-- signature found in the binder's info; the CPR annotation found there is
-- followed in step. The computation 'fail's when nothing would be gained.
wrappable :: Monad m =>
    DataTable   -- ^ used to look up product types, slot types and constructors
    -> TVr      -- ^ binder of the definition; its 'tvrInfo' supplies the CPR and demand annotations
    -> E        -- ^ right-hand side of the definition
    -> m (Maybe Name,E,[(Arg,TVr)])  -- ^ (result constructor if any, body under the peeled lambdas, classified binders in order)
wrappable dataTable mtvr e@ELam {} = peel e (sa ++ repeat Demand.lazy) cpr [] where
    annotations = tvrInfo mtvr
    -- Missing annotations fall back to 'Top' and 'Demand.absSig'.
    cpr = fromMaybe Top (Info.lookup annotations)
    Demand.DemandSignature _ (_ Demand.:=> sa) = fromMaybe Demand.absSig (Info.lookup annotations)

    -- Classify one binder given the demand placed on it.
    classify t@TVr { tvrIdent = eid } _ | isEmptyId eid = (Absent,t)
    classify t (Demand.S subs)
        | Just con <- getProduct dataTable ty = (Cons con (fields con),t)
        where
            ty = getType t
            -- One fresh binder per slot, paired with the matching sub-demand.
            fields con =
                [ classify TVr { tvrIdent = n, tvrType = st, tvrInfo = mempty } d
                | (st,n,d) <- zip3 (slotTypes dataTable (conName con) ty)
                                   (tmpNames Val (tvrIdent t))
                                   (fsubs subs)
                ]
    classify t Demand.Absent | isLifted (EVar t) = (Absent,t)
    classify t _ = (Plain,t)

    -- Walk under the lambdas while the CPR value is still a 'Fun'; the
    -- classified binders are accumulated in reverse.
    peel (ELam t inner) (demand:rest) (Fun x) seen = peel inner rest x (classify t demand:seen)
    peel body _ (Tup n _) seen | isCPR n = return (Just n,body,reverse seen)
    peel body _ (Tag [n]) seen | isCPR n = return (Just n,body,reverse seen)
    peel body _ _ seen | not (all (isPlain . fst) seen) = return (Nothing ,body,reverse seen)
    peel _ _ _ _ = fail "not workwrapable"

    isCPR n = isBoxed n && onlyChild dataTable n

    -- The constructor inhabits 's_Star', or inhabits something that does.
    isBoxed n = isJust $ do
        Constructor { conInhabits = parent } <- getConstructor n dataTable
        if parent == s_Star then return () else do
            Constructor { conInhabits = grandparent } <- getConstructor parent dataTable
            if grandparent == s_Star then return () else Nothing
wrappable _ _ _ = fail "Only lambdas are wrappable"
-- | Identifier used for the worker derived from the given identifier: a
-- 'Val' name in the @W\@@ module whose base is @f@ followed by the shown
-- name (or the shown identifier itself when it has no name).
workerName :: Id -> Id
workerName x = toId (toName Val ("W@",'f':shown)) where
    shown = maybe (show x) show (fromId x)
-- | Infinite supply of identifiers derived from the given one: names in the
-- @X\@@ module of the requested name type, whose base is @f@, the shown name
-- (or shown identifier when it has no name), @\@@ and a counter from 1.
tmpNames ns x = [ toId (toName ns ("X@",prefix ++ "@" ++ show i)) | i <- [(1::Int)..] ] where
    prefix = 'f' : maybe (show x) show (fromId x)
-- | Try to split one definition into a wrapper and a worker.
--
-- On success the result is @((wrapper binder, wrapper body), (worker binder, worker body))@
-- and statistics ticks are recorded; in every other case the computation 'fail's.
workWrap' :: MonadStats m => DataTable -> TVr -> E -> m ((TVr,E),(TVr,E))
-- Binders carrying any of the properties in 'badProps' are never split.
workWrap' _dataTable tvr _e
| badProps `intersects` getProperties tvr = fail "Don't workwrap this"
where badProps = fromList [prop_WRAPPER,prop_INLINE,prop_SUPERINLINE,prop_PLACEHOLDER,prop_NOINLINE]
-- Main case, taken only when 'wrappable' succeeds.
workWrap' dataTable tvr e | isJust res = ans where
-- The inner irrefutable pattern lets 'isJust res' be evaluated even when
-- 'wrappable' yields Nothing; cname/body/sargs are only demanded afterwards.
res@(~(Just (cname,body,sargs))) = wrappable dataTable tvr e
-- The original lambda binders, in order.
args = snds sargs
-- Binders the worker is abstracted over: absent arguments are dropped and
-- constructor arguments are replaced (recursively) by their field binders.
args' = concatMap f sargs where
f (Absent,_) = []
f (Plain,t) = [t]
f (Cons c ts,_) = concatMap f ts
-- Bindings put around the worker body so the original binders are still
-- defined there: absent ones are bound to an EError labelled with the
-- function name and argument position, constructor ones are rebuilt from
-- their field binders.
lets = concatMap f (zip sargs (map show naturals)) where
f ((Absent,t),n) = [(t,EError ("WorkWrap.Absent." ++ tvrShowName tvr ++ "." ++ n) (getType t))]
f ((Plain,_),_) = []
f ((Cons c ts,t),n) = [(t,ELit (updateLit dataTable litCons { litName = conName c, litArgs = map EVar (snds ts), litType = getType t }))] ++ concatMap f (zip ts [ n ++ "." ++ show i | i <- naturals])
-- Wraps an expression in one case per constructor argument (nested ones
-- included), binding the field binders; used in the wrapper.
cases e = f sargs where
f [] = e
f ((Absent,_):rs) = f rs
f ((Plain,_):rs) = f rs
f ((Cons c ts,t):rs) = eCase (EVar t) [Alt (updateLit dataTable litCons { litName = conName c, litArgs = snds ts, litType = getType t }) (f (ts ++ rs))] Unknown
-- The worker gets prop_WORKER, and keeps prop_JOINPOINT / prop_ONESHOT
-- if the original binder had them.
nprops = insert prop_WORKER $ getProperties tvr `intersection` fromList [prop_JOINPOINT, prop_ONESHOT]
-- Record the ticks, then return the wrapper (original binder marked
-- prop_WRAPPER) and the worker.
ans = doTicks >> return ((setProperty prop_WRAPPER tvr,wrapper),(tvr',worker))
-- Worker binder: named by 'workerName', empty info, type obtained by
-- typechecking the worker itself.
tvr' = putProperties nprops $ TVr { tvrIdent = workerName (tvrIdent tvr), tvrInfo = mempty, tvrType = wt }
worker = foldr ELam body' (args' ++ navar) where
-- When a result constructor was found, the worker scrutinises the
-- original body and returns the constructor's fields: the single field
-- variable when 'isSingleton', otherwise an unboxed tuple of them.
body' = eLetRec lets $ case cname of
Just cname -> eCase body [cb] Unknown where
cb = Alt (updateLit dataTable litCons { litName = cname, litArgs = vars, litType = bodyTyp }) (if isSingleton then EVar sv else (ELit $ unboxedTuple (map EVar vars)))
Nothing -> body
-- The wrapper keeps the original lambdas, takes constructor arguments
-- apart with 'cases', calls the worker and, when a result constructor
-- was found, re-applies that constructor to what the worker returned.
wrapper = foldr ELam ne args where
workerCall = (foldl EAp (EVar tvr') (map EVar args' ++ navalue))
ne | Just cname <- cname, isSingleton = cases $ eStrictLet sv workerCall (ELit $ updateLit dataTable litCons { litName = cname, litArgs = [EVar sv], litType = bodyTyp })
| Just cname <- cname = let ca = Alt (unboxedTuple vars) (ELit $ updateLit dataTable litCons { litName = cname, litArgs = (map EVar vars), litType = bodyTyp }) in cases $ eCase workerCall [ca] Unknown
| otherwise = cases $ workerCall
-- Only meaningful when a result constructor was found; errors otherwise.
getName (Just x) = x
getName Nothing = error ("workWrap': cname = Nothing: tvr = "++show tvr)
-- Fresh variables for the fields of the result constructor, with ids
-- chosen to avoid 'dontUseThese'. 'sv' is only valid when there is
-- exactly one field (irrefutable pattern).
vars@(~[sv]) = [ tVr i t | t <- slotTypes dataTable (getName cname) bodyTyp | i <- newIds dontUseThese ]
dontUseThese = freeIds (getType tvr) `mappend` freeIds bodyTyp
-- Exactly one field, and the type of that field's type is 'eHash'.
isSingleton = case vars of
[v] -> getType (getType v) == eHash
_ -> False
-- Partial patterns: demanding 'wt' or 'bodyTyp' fails if the
-- corresponding 'typecheck' does not yield a Just.
Just wt = typecheck dataTable worker
Just bodyTyp = typecheck dataTable body
-- If no binders are left for the worker, give it one extra argument of
-- type ltTuple' [] and pass eTuple' [] at the call site.
-- NOTE(review): presumably this keeps the worker a function rather than a
-- plain value -- confirm.
needsArg = null args'
(navar,navalue) = if needsArg then ([tvr { tvrType = ltTuple' []}],[eTuple' []]) else ([],[])
-- One tick for the result constructor (if any) and one per absent or
-- taken-apart argument, with nested constructor names appended to the
-- tick label.
doTicks = do
case cname of
Just n -> mtick ("E.Workwrap.CPR.{" ++ show n ++ "}")
_ -> return ()
let argw cn (Absent,_) = mtick $ cn ++ ".absent"
argw cn (Cons n ts,_) = mtick nname >> mapM_ (argw nname) ts where
nname = cn ++ ".{" ++ show (conName n) ++ "}"
argw _ _ = return ()
mapM_ (argw "E.Workwrap.arg") sargs
-- Fallback when 'wrappable' did not succeed.
workWrap' _dataTable tvr e = fail "not workWrapable"
-- | Run 'performWorkWrap' over the definitions of a whole program, storing
-- the resulting definitions back and appending the collected statistics to
-- the program's existing ones.
workWrapProgram :: Program -> Program
workWrapProgram prog = programSetDs' splitDefs prog { progStats = progStats prog `mappend` ticks }
  where
    (splitDefs,ticks) = performWorkWrap (progDataTable prog) (programDs prog)
-- | Apply 'workWrap'' to every definition in the list and, recursively, to
-- the definitions of every 'ELetRec' reached inside worker bodies and
-- inside bodies that were not split. Returns the new definitions together
-- with the statistics gathered from the successful splits.
performWorkWrap :: DataTable -> [(TVr,E)] -> ([(TVr,E)],Stats.Stat)
performWorkWrap dataTable ds = runWriter (splitAll ds)
  where
    -- Each definition yields one or two definitions; flatten the result.
    splitAll defs = concat `liftM` mapM splitOne defs

    splitOne (tvr,rhs) = case runStatT (workWrap' dataTable tvr rhs) of
        Nothing -> do
            rhs' <- descend rhs
            return [(tvr,rhs')]
        -- The wrapper is emitted as-is; only the worker body is traversed.
        Just ((wrapperDef,(workerVar,workerBody)),ticks) -> do
            tell ticks
            workerBody' <- descend workerBody
            return [wrapperDef,(workerVar,workerBody')]

    -- Let-bound groups are processed like top-level ones; every other
    -- expression is handed to emapE' with this function as the callback.
    descend ELetRec { eDefs = inner, eBody = body } = do
        inner' <- splitAll inner
        body' <- descend body
        return (ELetRec inner' body')
    descend other = emapE' descend other