module DDC.Core.Transform.Forward
( ForwardInfo (..)
, FloatControl (..)
, Config(..)
, forwardModule
, forwardX)
where
import DDC.Base.Pretty
import DDC.Core.Analysis.Usage
import DDC.Core.Exp
import DDC.Core.Module
import DDC.Core.Simplifier.Base
import DDC.Core.Transform.Reannotate
import DDC.Core.Fragment
import DDC.Core.Predicates
import DDC.Core.Compounds
import Data.Map (Map)
import Control.Monad
import Control.Monad.Writer (Writer, runWriter, tell)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Typeable
import qualified Data.Map as Map
import qualified DDC.Core.Transform.SubstituteXX as S
-- | Counters describing the work done by one run of the forward transform.
--   Values from sub-terms are combined by field-wise addition through the
--   'Monoid' instance.
data ForwardInfo
        = ForwardInfo
        {
          -- | Number of let-bindings that were examined as candidates for
          --   forwarding.
          infoInspected :: !Int

          -- | Number of variable occurrences that were replaced by the
          --   expression recorded for their name.
        , infoSubsts    :: !Int

          -- | Number of let-bindings removed from their original position,
          --   either by substituting an atomic right-hand side or by
          --   floating the right-hand side towards its uses.
        , infoBindings  :: !Int }
        deriving Typeable
-- | Render the counters as a @Forward:@ heading followed by one indented
--   line per counter.
instance Pretty ForwardInfo where
 ppr (ForwardInfo nInspected nSubsts nBindings)
  = text "Forward:" <$> indent 4 (vcat counterLines)
  where
        -- One line per counter, in the order the fields are declared.
        counterLines
         = [ text "Total bindings inspected: "    <> int nInspected
           , text " Trivial substitutions made: " <> int nSubsts
           , text " Bindings moved forward: "     <> int nBindings ]
-- | Counters start at zero and are combined by adding corresponding fields.
instance Monoid ForwardInfo where
 mempty
  = ForwardInfo
        { infoInspected = 0
        , infoSubsts    = 0
        , infoBindings  = 0 }

 mappend x y
  = ForwardInfo
        { infoInspected = infoInspected x + infoInspected y
        , infoSubsts    = infoSubsts    x + infoSubsts    y
        , infoBindings  = infoBindings  x + infoBindings  y }
-- | Per-binding decision returned by the 'configFloatControl' callback.
data FloatControl
        -- | Float the binding only when its right-hand side is a value or
        --   type abstraction and the only use recorded for it, ignoring
        --   uses inside casts, is a single use as an applied function.
        = FloatAllow

        -- | Never float the binding.
        | FloatDeny

        -- | Always float the binding.
        | FloatForce
        deriving (Eq, Show)
-- | Configuration for the forward transform.
data Config a n
        = Config
        { -- | Consulted for each named @let@ binding that was not already
          --   handled by the let-body or atomic-substitution rules. The
          --   binding is passed with the usage annotations stripped, and
          --   the result decides whether it is floated.
          configFloatControl    :: Lets a n -> FloatControl

          -- | When 'True', an expression of the form @let b = x in b@ is
          --   replaced by the transformed @x@.
        , configFloatLetBody    :: Bool }
-- | Run the forward transform over a module.
--
--   The module is annotated with usage information by 'usageModule' and then
--   traversed starting from an empty substitution map. Progress is reported
--   when at least one substitution or forwarded binding was counted. The
--   result never requests another run.
forwardModule
        :: Ord n
        => Profile n            -- ^ Language profile, threaded through the traversal.
        -> Config a n           -- ^ Floating configuration.
        -> Module a n           -- ^ Module to transform.
        -> TransformResult (Module a n)
forwardModule profile config mm
 = TransformResult
        { result                = mm'
        , resultProgress        = infoSubsts info + infoBindings info > 0
        , resultAgain           = False
        , resultInfo            = TransformInfo info }
 where
        -- Module with each annotation paired with its usage map.
        annotated       = usageModule mm

        -- Transformed module together with the accumulated counters.
        (mm', info)     = runWriter (forwardWith profile config Map.empty annotated)
-- | Run the forward transform over a single expression.
--
--   The expression is annotated with usage information by 'usageX' and then
--   traversed starting from an empty substitution map. Progress is reported
--   when at least one substitution or forwarded binding was counted. The
--   result never requests another run.
forwardX
        :: Ord n
        => Profile n            -- ^ Language profile, threaded through the traversal.
        -> Config a n           -- ^ Floating configuration.
        -> Exp a n              -- ^ Expression to transform.
        -> TransformResult (Exp a n)
forwardX profile config xx
 = TransformResult
        { result                = x'
        , resultProgress        = infoSubsts info + infoBindings info > 0
        , resultAgain           = False
        , resultInfo            = TransformInfo info }
 where
        -- Expression with each annotation paired with its usage map.
        annotated       = usageX xx

        -- Transformed expression together with the accumulated counters.
        (x', info)      = runWriter (forwardWith profile config Map.empty annotated)
-- | Syntactic forms that let-bindings can be forwarded through.
--
--   The input carries annotations paired with usage information; the output
--   carries only the original annotation.
class Forward (c :: * -> * -> *) where
 -- | Traverse a term, replacing variables that have an entry in the given
 --   map and recording what was done in the 'ForwardInfo' writer output.
 forwardWith
        :: Ord n
        => Profile n            -- ^ Language profile; the instances here only pass it along.
        -> Config a n           -- ^ Floating configuration.
        -> Map n (Exp a n)      -- ^ Expressions to substitute for variable names.
        -> c (UsedMap n, a) n   -- ^ Term annotated with usage information.
        -> Writer ForwardInfo (c a n)
-- | Forwarding through a module transforms only its body; every other
--   component is copied over unchanged.
instance Forward Module where
 forwardWith profile config bindings mm
  = liftM rebuild
  $ forwardWith profile config bindings (moduleBody mm)
  where
        -- Reassemble the module around the transformed body.
        rebuild body'
         = ModuleCore
                { moduleName            = moduleName            mm
                , moduleExportTypes     = moduleExportTypes     mm
                , moduleExportValues    = moduleExportValues    mm
                , moduleImportTypes     = moduleImportTypes     mm
                , moduleImportValues    = moduleImportValues    mm
                , moduleDataDefsLocal   = moduleDataDefsLocal   mm
                , moduleBody            = body' }
-- | Forward let-bindings through an expression.
--
--   Variable occurrences whose name has an entry in the substitution map are
--   replaced by the mapped expression. Let-bindings are either substituted
--   directly (atomic right-hand sides), added to the map and dropped
--   (floated), or kept in place.
--
--   The map is keyed by name only, so a name that is rebound by a lambda or
--   by a retained @let@ is removed from the map for the scope of that
--   binder. Without this, occurrences that refer to the inner binder would
--   be replaced by the outer floated expression.
--
--   NOTE(review): free variables of a floated expression can still be
--   captured by binders that sit between the definition and the use site;
--   that is not addressed here.
instance Forward Exp where
 forwardWith profile config bindings xx
  = let -- Traverse a sub-term with the current substitution map.
        down    = forwardWith profile config bindings

        -- Traverse an expression that lies in the scope of the given value
        -- binders, hiding substitutions for the names they rebind.
        downUnder bs
                = forwardWith profile config (foldr dropBind bindings bs)

    in case xx of
        -- Replace the variable if an expression was recorded for its name.
        XVar a u@(UName n)
         -> case Map.lookup n bindings of
                Just xx'
                 -> do  tell mempty { infoSubsts = 1 }
                        return xx'

                Nothing
                 ->     return $ XVar (snd a) u

        XVar a u        -> return $ XVar (snd a) u
        XCon a u        -> return $ XCon (snd a) u

        -- A type abstraction binds no value-level name, so the map is
        -- passed on unchanged.
        XLAM a b x      -> liftM  (XLAM (snd a) b) (down x)

        -- The parameter may rebind a name that has a pending substitution.
        XLam a b x      -> liftM  (XLam (snd a) b) (downUnder [b] x)

        XApp a x1 x2    -> liftM2 (XApp (snd a))   (down x1) (down x2)

        -- let b = x1 in b  ==>  x1, when enabled in the configuration.
        -- No counters are updated for this rewrite.
        XLet _ (LLet b x1) (XVar _ u)
         | boundMatchesBind u b
         , configFloatLetBody config
         -> down x1

        -- Atomic right-hand sides are substituted straight into the body,
        -- without consulting the float control.
        XLet _ (LLet b x1) x2
         | isAtomX x1
         -> do  tell mempty { infoInspected = 1
                            , infoBindings  = 1 }
                down $ S.substituteXX b x1 x2

        -- Named binding: decide whether to float it.
        XLet (UsedMap um, a') lts@(LLet b@(BName n _) x1) x2
         -> do
                -- Ask the client, showing it the binding without usage info.
                let control     = configFloatControl config
                                $ reannotate snd lts

                -- The bound expression is a value or type abstraction.
                let isFun       = isXLam x1 || isXLAM x1

                -- The only use recorded for the name, ignoring uses in
                -- casts, is a single use as an applied function.
                let isApplied
                     | Just usage       <- Map.lookup n um
                     , [UsedFunction]   <- filterUsedInCasts usage
                     = True
                     | otherwise = False

                let shouldFloat
                     = case control of
                        FloatDeny       -> False
                        FloatForce      -> True
                        FloatAllow      -> isFun && isApplied

                if shouldFloat
                 then do
                        tell mempty { infoInspected = 1
                                    , infoBindings  = 1 }

                        -- Drop the binding and substitute its transformed
                        -- right-hand side at the uses in the body. The
                        -- insert replaces any older entry for the name.
                        x1'     <- down x1
                        let bindings'   = Map.insert n x1' bindings
                        forwardWith profile config bindings' x2

                 else do
                        tell mempty { infoInspected = 1 }

                        -- The binding stays, so in the body its name refers
                        -- to it and not to any outer floated binding.
                        liftM2 (XLet a') (down lts) (downUnder [b] x2)

        -- All remaining let forms are kept in place.
        XLet (_, a') lts x
         -> liftM2 (XLet a') (down lts) (downUnder (letBinds lts) x)

        XCase a x alts  -> liftM2 (XCase (snd a)) (down x) (mapM down alts)
        XCast a c x     -> liftM2 (XCast (snd a)) (down c) (down x)
        XType a t       -> return (XType (snd a) t)
        XWitness a w    -> return (XWitness (snd a) (reannotate snd w))
  where
        -- Remove the name bound by a named binder from the substitution map.
        dropBind (BName n _) m  = Map.delete n m
        dropBind _           m  = m

        -- Value binders introduced by a let group. Other let forms are
        -- treated as introducing no value binders here.
        letBinds lts
         = case lts of
                LLet b _        -> [b]
                LRec bxs        -> map fst bxs
                _               -> []
-- | Discard every 'UsedInCast' entry from a list of uses, keeping the
--   remaining entries in their original order.
filterUsedInCasts :: [Used] -> [Used]
filterUsedInCasts uses
        = [ u | u <- uses, not (isCastUse u) ]
        where   isCastUse UsedInCast    = True
                isCastUse _             = False
-- | Forwarding through a cast. Only 'CastWeakenClosure' contains
--   expressions to traverse; the other forms are rebuilt with their usage
--   annotations removed.
instance Forward Cast where
 forwardWith _ _ _ (CastWeakenEffect eff)
  = return (CastWeakenEffect eff)

 forwardWith profile config bindings (CastWeakenClosure xs)
  = do  xs'     <- mapM (forwardWith profile config bindings) xs
        return (CastWeakenClosure xs')

 forwardWith _ _ _ (CastPurify w)
  = return (CastPurify (reannotate snd w))

 forwardWith _ _ _ (CastForget w)
  = return (CastForget (reannotate snd w))

 forwardWith _ _ _ CastBox
  = return CastBox

 forwardWith _ _ _ CastRun
  = return CastRun
-- | Forward through the right-hand sides of a let group, leaving the
--   binders themselves untouched.
--
--   The right-hand sides of a recursive group are in the scope of the
--   group's own binders. Names bound by the group are therefore removed
--   from the substitution map before those right-hand sides are traversed,
--   so a recursive reference is not replaced by an outer binding that has
--   the same name.
instance Forward Lets where
 forwardWith profile config bindings lts
  = case lts of
        -- Non-recursive: the right-hand side is outside the binder's scope.
        LLet b x
         -> liftM (LLet b) (forwardWith profile config bindings x)

        LRec bxs
         -> liftM LRec (mapM (forwardRec (map fst bxs)) bxs)

        -- No expressions to traverse in these forms.
        LPrivate b mt bs -> return $ LPrivate b mt bs
        LWithRegion b    -> return $ LWithRegion b
  where
        -- Transform one right-hand side of a recursive group, with the
        -- names of all binders in the group hidden from substitution.
        forwardRec bs (b, x)
         = do   x' <- forwardWith profile config (foldr dropBind bindings bs) x
                return (b, x')

        -- Remove the name bound by a named binder from the substitution map.
        dropBind (BName n _) m  = Map.delete n m
        dropBind _           m  = m
-- | Forward through the body of a case alternative.
--
--   Variables bound by a constructor pattern scope over the body, so their
--   names are removed from the substitution map before the body is
--   traversed. Otherwise an occurrence that refers to a pattern variable
--   could be replaced by an outer binding that has the same name.
instance Forward Alt where
 forwardWith profile config bindings (AAlt p x)
  = liftM (AAlt p) (forwardWith profile config bindings' x)
  where
        -- Substitution map restricted to names the pattern does not rebind.
        bindings'
         = case p of
                PData _ bs      -> foldr dropBind bindings bs
                _               -> bindings

        -- Remove the name bound by a named binder from the substitution map.
        dropBind (BName n _) m  = Map.delete n m
        dropBind _           m  = m