module Language.SequentCore.Simpl.ExprSize (
ExprSize(..), valueSize, contSize, commandSize
) where
import Language.SequentCore.Syntax
import Bag
import DataCon
import DynFlags
import Id
import IdInfo
import Literal
import Outputable
import PrelNames ( buildIdKey, augmentIdKey )
import PrimOp
import Type
import TysPrim ( realWorldStatePrimTy )
import Unique
import Util ( count )
import qualified Data.ByteString as BS
-- | Exported size estimate for an expression, or 'Nothing' from the
-- size functions when the expression exceeded the size cap.
data ExprSize
  = ExprSize
      { esBase       :: !Int    -- ^ base size, before any discounts
      , esArgDiscs   :: ![Int]  -- ^ one discount per top-level value binder,
                                --   in the same order as those binders
      , esResultDisc :: !Int    -- ^ discount attributed to the result
      }
-- | Internal accumulator used while walking an expression.
data BodySize
  -- | The accumulated size exceeded the cap; absorbs anything it is
  -- combined with.
  = TooBig
  -- | A size still under the cap.
  | BodySize
      { bsBase       :: !Int            -- base size
      , bsArgDiscs   :: Bag (Id, Int)   -- discounts keyed by binder (lazy)
      , bsResultDisc :: !Int            -- discount on the result
      }
-- | Debug rendering: base size and result discount (argument discounts
-- are not shown).
instance Outputable BodySize where
  ppr bs = case bs of
    TooBig -> text "TooBig"
    BodySize { bsBase = base, bsResultDisc = resDisc } ->
      brackets (int base <+> int resDisc)
-- | Debug rendering: @[base [argDiscs] resultDisc]@.
instance Outputable ExprSize where
  ppr es = brackets (int (esBase es) <+> argDocs <+> int (esResultDisc es))
    where
      argDocs = brackets (hsep (map int (esArgDiscs es)))
-- | Turn an accumulated 'BodySize' into the exported 'ExprSize'.
--
-- 'TooBig' becomes 'Nothing'. Otherwise the bag of per-binder discounts is
-- resolved into one discount per element of the binder list, in the same
-- order. Several discounts recorded against a function-typed binder are
-- combined with 'max'; for any other binder they are summed.
body2ExprSize :: [Id] -> BodySize -> Maybe ExprSize
body2ExprSize _ TooBig = Nothing
body2ExprSize topArgs (BodySize base discs resDisc)
  = Just (ExprSize { esBase       = base
                   , esArgDiscs   = map discountFor topArgs
                   , esResultDisc = resDisc
                   })
  where
    discountFor :: Id -> Int
    discountFor arg = foldlBag step 0 discs
      where
        merge :: Int -> Int -> Int
        merge | isFunTy (idType arg) = max
              | otherwise            = (+)
        step acc (arg', d)
          | arg == arg' = acc `merge` d
          | otherwise   = acc
-- | The empty size: no cost, no discounts.
sizeZero :: BodySize
sizeZero = BodySize 0 emptyBag 0
-- | A plain cost of the given size, with no discounts attached.
sizeN :: Int -> BodySize
sizeN base = BodySize base emptyBag 0
-- | The larger of two sizes by base size. The second one wins ties.
-- 'TooBig' on either side yields 'TooBig' (the second argument is not
-- inspected if the first is already 'TooBig').
maxSize :: BodySize -> BodySize -> BodySize
maxSize lhs rhs = case (lhs, rhs) of
  (BodySize b1 _ _, BodySize b2 _ _)
    | b1 > b2   -> lhs
    | otherwise -> rhs
  _ -> TooBig
-- | Smart constructor for 'BodySize'.
--
-- @mkBodySize cap base discs resDisc@ gives 'TooBig' when the base size net
-- of the result discount (@base - resDisc@) exceeds @cap@. Otherwise it
-- builds the size unchanged.
--
-- The original guard read @cap < b r@, which applies an 'Int' to an 'Int'.
-- The intended subtraction has been restored.
mkBodySize :: Int -> Int -> Bag (Id, Int) -> Int -> BodySize
mkBodySize cap b as r
  | b - r > cap = TooBig
  | otherwise   = BodySize { bsBase = b, bsArgDiscs = as, bsResultDisc = r }
-- | Size of a value. Its outermost lambdas are peeled off, and their value
-- (non-type) binders become the arguments that earn argument discounts.
-- The 'Int' is the size cap: exceeding it yields 'Nothing'.
--
-- Previously a local @let cap = ufCreationThreshold dflags@ shadowed the
-- @cap@ parameter, so the caller's cap was silently ignored (including the
-- cap forwarded by 'commandSize'). The parameter is now honoured, as in
-- 'contSize' and 'commandSize'.
valueSize :: DynFlags -> Int -> SeqCoreValue -> Maybe ExprSize
valueSize dflags cap val
  = body2ExprSize valBinders $ bodySize dflags cap valBinders (C body)
  where
    (xs, body) = collectLambdas val
    valBinders = filter isId xs

-- | Size of a continuation under the given cap. It has no argument binders,
-- so no argument discounts are produced.
contSize :: DynFlags -> Int -> SeqCoreCont -> Maybe ExprSize
contSize dflags cap cont = body2ExprSize [] $ bodySize dflags cap [] (K cont)

-- | Size of a command under the given cap. A command that is just a value
-- (as reported by 'asValueCommand') is sized as that value, so its lambdas
-- can earn argument discounts.
commandSize :: DynFlags -> Int -> SeqCoreCommand -> Maybe ExprSize
commandSize dflags cap comm
  | Just v <- asValueCommand comm = valueSize dflags cap v
  | otherwise                     = body2ExprSize [] $
                                    bodySize dflags cap [] (C comm)

-- | Any of the three syntactic categories, so that one sizing function can
-- walk values, commands and continuations alike.
data Expr = V SeqCoreValue | C SeqCoreCommand | K SeqCoreCont

instance Outputable Expr where
  ppr (V val)  = ppr val
  ppr (C comm) = ppr comm
  ppr (K cont) = ppr cont

bodySize :: DynFlags -> Int -> [Id] -> Expr -> BodySize
-- | Walk an expression and compute its 'BodySize'.
--
-- @cap@ bounds the result. Every combination goes through 'mkBodySize',
-- which collapses to 'TooBig' once base minus result discount exceeds @cap@,
-- and 'TooBig' absorbs whatever it is combined with.
--
-- @topArgs@ are the value binders of the enclosing function. Any of the
-- following records an argument discount against such a binder:
--
-- * a @case@ on it;
-- * a call of it with value arguments (see 'funSize');
-- * passing it as the first argument of a class-op call (see 'classOpSize').
--
-- Fixes relative to the previous version:
--
-- * The lambda case's "erased" test was inverted. It treated real value
--   lambdas as free and gave the lambda result discount to type and
--   @State# RealWorld@ lambdas.
-- * The case-on-argument discount read @20 + totBase maxBase@ (an 'Int'
--   applied to an 'Int'); the subtraction is restored.
-- * The local @max@ that shadowed 'Prelude.max' is renamed.
bodySize dflags cap topArgs expr
  = cap `seq` size expr
  where
    infixl 6 `addSizeN`, `addSizeNSD`, `addSizeOfCont`

    -- Values
    size (V (Type _))       = sizeZero
    size (V (Coercion _))   = sizeZero
    size (V (Var _))        = sizeZero
    size (V (Compute comm)) = size (C comm)
    size (V (Cont cont))    = size (K cont)
    size (V (Lit lit))      = sizeN (litSize lit)
    size (V (Cons dc args)) = sizeArgs args `addSizeNSD`
                              sizeCall (dataConWorkId dc) args voids
      where voids = count isRealWorldValue args
    size (V (Lam x comm))
      | erased    = size (C comm)
      | otherwise = lamScrutDiscount dflags (size (C comm))
      where
        -- Type binders and State# RealWorld binders are treated as free.
        -- Only a lambda over a genuine value binder gets the lambda
        -- result discount.
        erased = not (isId x) || isRealWorldId x

    -- Continuations
    size (K Return)               = sizeZero
    size (K (Jump _))             = sizeZero
    size (K (Cast _ cont))        = size (K cont)
    size (K (Tick _ cont))        = size (K cont)
    size (K (App arg cont))       = sizeArg arg `addSizeNSD` size (K cont)
    size (K (Case _ _ alts cont)) = sizeAlts alts `addSizeOfCont` cont

    -- Commands: the let-bindings, then the value cut against its continuation
    size (C comm) = sizeLets (cmdLet comm) `addSizeNSD`
                    sizeCut (cmdValue comm) (cmdCont comm)

    sizeCut :: SeqCoreValue -> SeqCoreCont -> BodySize
    -- A variable applied to arguments: charge the non-erased arguments plus
    -- the call, then whatever continuation follows the arguments.
    sizeCut (Var f) cont@(App {})
      = let (args, cont') = collectArgs cont
            realArgs      = filter (not . isErasedValue) args
            voids         = count isRealWorldValue realArgs
        in sizeArgs realArgs `addSizeNSD` sizeCall f realArgs voids
             `addSizeOfCont` cont'
    -- A case on one of the function's own arguments. Record a discount of
    -- 20 plus the total size of all alternatives minus the largest one
    -- (presumably because only one alternative survives when the argument
    -- is known at the call site).
    sizeCut (Var x) (Case _b _ty alts cont')
      | x `elem` topArgs
      = combineSizes total biggest `addSizeOfCont` cont'
      where
        altSizes = map sizeAlt alts
        total    = foldr addAltSize sizeZero altSizes
        biggest  = foldr maxSize sizeZero altSizes
        combineSizes (BodySize totBase totArgDiscs totResDisc)
                     (BodySize maxBase _ _)
          = BodySize totBase
                     (unitBag (x, 20 + totBase - maxBase)
                        `unionBags` totArgDiscs)
                     totResDisc
        combineSizes tot _ = tot
    sizeCut val cont
      = size (V val) `addSizeOfCont` cont

    sizeArg :: SeqCoreValue -> BodySize
    sizeArg arg = size (V arg)

    sizeArgs :: [SeqCoreValue] -> BodySize
    sizeArgs args = foldr addSizeNSD sizeZero (map sizeArg args)

    -- Cost of calling @fun@, dispatched on what kind of identifier it is.
    sizeCall :: Id -> [SeqCoreValue] -> Int -> BodySize
    sizeCall fun valArgs voids
      = case idDetails fun of
          FCallId _        -> sizeN (10 * (1 + length valArgs))
          DataConWorkId dc -> conSize dc (length valArgs)
          PrimOpId op      -> primOpSize op (length valArgs)
          ClassOpId _      -> classOpSize dflags topArgs valArgs
          _                -> funSize dflags topArgs fun (length valArgs) voids

    -- Each alternative costs its right-hand side plus 10.
    sizeAlt :: SeqCoreAlt -> BodySize
    sizeAlt (Alt _ _ rhs) = size (C rhs) `addSizeN` 10

    sizeAlts :: [SeqCoreAlt] -> BodySize
    sizeAlts alts = foldr addAltSize sizeZero (map sizeAlt alts)

    -- A binding costs its right-hand side plus 10 for each binding of
    -- lifted type (non-recursive unlifted bindings are charged nothing
    -- extra).
    sizeBind :: SeqCoreBind -> BodySize
    sizeBind (NonRec x rhs)
      = size (V rhs) `addSizeN` allocSize
      where
        allocSize
          | isUnLiftedType (idType x) = 0
          | otherwise                 = 10
    sizeBind (Rec pairs)
      = foldr (addSizeNSD . pairSize) (sizeN allocSize) pairs
      where
        allocSize = 10 * length pairs
        pairSize (_x, rhs) = size (V rhs)

    sizeLets :: [SeqCoreBind] -> BodySize
    sizeLets binds = foldr (addSizeNSD . sizeBind) sizeZero binds

    addSizeN :: BodySize -> Int -> BodySize
    addSizeN TooBig            _ = TooBig
    addSizeN (BodySize b as r) d = mkBodySize cap (b + d) as r

    -- Combine sibling alternatives: bases, argument discounts and result
    -- discounts all accumulate.
    addAltSize :: BodySize -> BodySize -> BodySize
    addAltSize (BodySize b1 as1 r1) (BodySize b2 as2 r2)
      = mkBodySize cap (b1 + b2) (as1 `unionBags` as2) (r1 + r2)
    addAltSize _ _ = TooBig

    -- Sequential combination: the first operand's result discount is
    -- dropped ("No Scrutinee Discount"); only the second's is kept.
    addSizeNSD :: BodySize -> BodySize -> BodySize
    addSizeNSD TooBig _      = TooBig
    addSizeNSD _      TooBig = TooBig
    addSizeNSD (BodySize b1 as1 _) (BodySize b2 as2 r2)
      = mkBodySize cap (b1 + b2) (as1 `unionBags` as2) r2

    -- Add the cost of a continuation, unless it is a pass-through one
    -- (see 'isPassThroughCont'), which leaves the size unchanged.
    addSizeOfCont :: BodySize -> SeqCoreCont -> BodySize
    addSizeOfCont size1 cont
      | isPassThroughCont cont = size1
      | otherwise              = size1 `addSizeNSD` size (K cont)

    -- A continuation made only of ticks, casts and erased-argument
    -- applications, ending in Return.
    isPassThroughCont :: Cont b -> Bool
    isPassThroughCont Return          = True
    isPassThroughCont (Tick _ cont)   = isPassThroughCont cont
    isPassThroughCont (Cast _ cont)   = isPassThroughCont cont
    isPassThroughCont (App arg cont)  = isErasedValue arg
                                        && isPassThroughCont cont
    isPassThroughCont _               = False
-- | Size of a literal. 'LitInteger' costs 100. A string costs 10 plus 10 for
-- every started group of 4 bytes. Every other literal is free.
litSize :: Literal -> Int
litSize lit = case lit of
  LitInteger {} -> 100
  MachStr bytes -> 10 + 10 * ((BS.length bytes + 3) `div` 4)
  _             -> 0
-- | Size of a class-method selector applied to value arguments.
--
-- With no arguments it is free. Otherwise it costs 20 plus 10 for each
-- argument after the first. If the first argument (the dictionary) is one
-- of the top-level binders, that binder earns the dictionary discount from
-- 'ufDictDiscount'.
classOpSize :: DynFlags -> [Id] -> [SeqCoreValue] -> BodySize
classOpSize dflags topArgs args = case args of
    []             -> sizeZero
    dictArg : rest -> BodySize (20 + 10 * length rest) (dictDiscount dictArg) 0
  where
    dictDiscount (Var d) | d `elem` topArgs = unitBag (d, ufDictDiscount dflags)
    dictDiscount _                          = emptyBag
-- | Size of a call to an ordinary function @fun@ with @n_val_args@ value
-- arguments, @voids@ of which are @State# RealWorld@ values.
--
-- * Calls to 'build' and 'augment' get the fixed sizes 'buildSize' and
--   'augmentSize'.
-- * Otherwise a call with value arguments costs 10 for the call plus 10 per
--   argument that is not a @State# RealWorld@ value; a call with no value
--   arguments is free.
-- * Applying one of @top_args@ to value arguments records an argument
--   discount against it.
-- * Supplying fewer arguments than @fun@'s arity earns a result discount.
--
-- The size previously read @10 * (1 + n_val_args voids)@, which applies an
-- 'Int' to an 'Int'. The void arguments are now subtracted as intended.
funSize :: DynFlags -> [Id] -> Id -> Int -> Int -> BodySize
funSize dflags top_args fun n_val_args voids
  | fun `hasKey` buildIdKey   = buildSize
  | fun `hasKey` augmentIdKey = augmentSize
  | otherwise                 = BodySize size arg_discount res_discount
  where
    some_val_args = n_val_args > 0

    -- State# RealWorld arguments are not charged.
    size | some_val_args = 10 * (1 + n_val_args - voids)
         | otherwise     = 0

    arg_discount
      | some_val_args && fun `elem` top_args
      = unitBag (fun, ufFunAppDiscount dflags)
      | otherwise
      = emptyBag

    res_discount
      | idArity fun > n_val_args = ufFunAppDiscount dflags
      | otherwise                = 0
-- | Size of a primop application: the primop's code size from
-- 'primOpCodeSize', plus the argument count when the primop is out-of-line.
primOpSize :: PrimOp -> Int -> BodySize
primOpSize op n_val_args
  | primOpOutOfLine op = sizeN (codeSize + n_val_args)
  | otherwise          = sizeN codeSize
  where
    codeSize = primOpCodeSize op
-- | Fixed size for a call to 'build': no base cost, result discount 40.
-- NOTE(review): presumably chosen to encourage inlining so that foldr/build
-- fusion can fire; confirm against GHC's CoreUnfold.
buildSize :: BodySize
buildSize = BodySize { bsBase = 0, bsArgDiscs = emptyBag, bsResultDisc = 40 }
-- | Fixed size for a call to 'augment': no base cost, result discount 40
-- (the same as 'buildSize').
augmentSize :: BodySize
augmentSize = BodySize { bsBase = 0, bsArgDiscs = emptyBag, bsResultDisc = 40 }
-- | Size of a data-constructor application with @n_val_args@ value
-- arguments.
--
-- A nullary use has base 0 and result discount 10. Otherwise the result
-- discount is 10 plus 10 per argument. The base is 0 for unboxed-tuple
-- constructors and 10 for all others.
conSize :: DataCon -> Int -> BodySize
conSize dc n_val_args
  | n_val_args == 0 = BodySize 0 emptyBag 10
  | otherwise       = BodySize allocCost emptyBag (10 * (1 + n_val_args))
  where
    allocCost | isUnboxedTupleCon dc = 0
              | otherwise            = 10
-- | Replace a size's result discount with 'ufFunAppDiscount', keeping the
-- base and argument discounts unchanged. 'TooBig' stays 'TooBig'.
lamScrutDiscount :: DynFlags -> BodySize -> BodySize
lamScrutDiscount dflags bs = case bs of
  TooBig      -> TooBig
  BodySize {} -> bs { bsResultDisc = ufFunAppDiscount dflags }
-- | Whether an identifier's type is exactly @State# RealWorld@.
isRealWorldId :: Id -> Bool
isRealWorldId = (`eqType` realWorldStatePrimTy) . idType
-- | Whether a value is a variable of type @State# RealWorld@. Any other
-- form of value gives 'False'.
isRealWorldValue :: Value b -> Bool
isRealWorldValue v = case v of
  Var x -> isRealWorldId x
  _     -> False