-- | Float let-bindings with a single use forward into their use-sites.
module DDC.Core.Transform.Forward
        ( ForwardInfo   (..)
        , FloatControl  (..)
        , Config(..)
        , forwardModule
        , forwardX)
where
import DDC.Base.Pretty
import DDC.Core.Analysis.Usage
import DDC.Core.Exp
import DDC.Core.Module
import DDC.Core.Simplifier.Base
import DDC.Core.Transform.Reannotate
import DDC.Core.Fragment
import DDC.Core.Predicates
import DDC.Core.Compounds
import Data.Map                 (Map)
import Control.Monad
import Control.Monad.Writer     (Writer, runWriter, tell)
import Data.Monoid              (Monoid, mempty, mappend)
import Data.Typeable
import qualified Data.Map                               as Map
import qualified DDC.Core.Transform.SubstituteXX        as S

-------------------------------------------------------------------------------
-- | Counters describing what one run of the forwarding transform did.
--
--   These are accumulated in a 'Writer' while traversing the program.
--   A run counts as having made progress when 'infoSubsts' or
--   'infoBindings' is non-zero.
data ForwardInfo
        = ForwardInfo
        { -- | Number of let-bindings examined as candidates for forwarding.
          infoInspected :: !Int

          -- | Number of variable occurrences that were replaced by the
          --   right-hand side of a binding being carried forward.
        , infoSubsts    :: !Int

          -- | Number of let-bindings removed because their right-hand side
          --   was substituted or moved forward.
        , infoBindings  :: !Int }
        deriving Typeable


-- | Render the counters as a short indented report.
instance Pretty ForwardInfo where
 ppr (ForwardInfo { infoInspected = nInspected
                  , infoSubsts    = nSubsts
                  , infoBindings  = nMoved })
  = text "Forward:" <$> indent 4 (vcat rows)
  where rows
         = [ text "Total bindings inspected:      " <> int nInspected
           , text "  Trivial substitutions made:  " <> int nSubsts
           , text "  Bindings moved forward:      " <> int nMoved ]


instance Monoid ForwardInfo where
 -- The empty summary: nothing inspected, substituted or moved.
 mempty
  = ForwardInfo { infoInspected = 0, infoSubsts = 0, infoBindings = 0 }

 -- Combine two summaries by adding their counters field-wise.
 mappend lhs rhs
  = ForwardInfo
        { infoInspected = infoInspected lhs + infoInspected rhs
        , infoSubsts    = infoSubsts    lhs + infoSubsts    rhs
        , infoBindings  = infoBindings  lhs + infoBindings  rhs }


-------------------------------------------------------------------------------
-- | Fine control over what should be floated.
--
--   A value of this type is chosen per let-binding by 'configFloatControl'.
--   It is only consulted for non-recursive lets with a named binder whose
--   right-hand side is not atomic. Atomic right-hand sides (variables,
--   constructors) are always substituted. With 'configFloatLetBody' set,
--   @let x = e in x@ is always rewritten to @e@. Neither of these two
--   rewrites looks at this setting.
data FloatControl
        = FloatAllow    -- ^ Allow binding to be floated, but don't require it.
                        --   It is floated when its right-hand side is a
                        --   (type) lambda and its recorded usage, ignoring
                        --   uses in casts, is a single 'UsedFunction'.
        | FloatDeny     -- ^ Prevent a binding being floated (see above for
                        --   the forms that are rewritten regardless).
        | FloatForce    -- ^ Force   a binding to be floated, at all times.
        deriving (Eq, Show)

-- | Configuration for the forwarding transform.
--
--   * 'configFloatControl' is asked, for each candidate let-binding,
--     whether it may, must, or must not be floated (see 'FloatControl').
--
--   * 'configFloatLetBody', when 'True', rewrites @let x = e in x@
--     to just @e@.
data Config a n
        = Config
        { configFloatControl    :: Lets a n -> FloatControl
        , configFloatLetBody    :: Bool }

-------------------------------------------------------------------------------
-- | Float let-bindings in a module forward into their use sites.
--
--   Usage information is computed first with 'usageModule'. The traversal
--   then starts with no bindings being carried. The result reports progress
--   when at least one substitution or removed binding was recorded.
forwardModule 
        :: Ord n
        => Profile n    -- ^ Language profile.
        -> Config a n   -- ^ Decides which bindings get floated.
        -> Module a n   -- ^ Module to transform.
        -> TransformResult (Module a n)

forwardModule profile config mm
 = TransformResult
        { result         = mm'
        , resultProgress = infoSubsts info + infoBindings info > 0
        , resultAgain    = False
        , resultInfo     = TransformInfo info }
 where  (mm', info)
         = runWriter
         $ forwardWith profile config Map.empty
         $ usageModule mm


-- | Float let-bindings in an expression forward into their use sites.
--
--   Usage information is computed first with 'usageX'. The traversal then
--   starts with no bindings being carried. The result reports progress when
--   at least one substitution or removed binding was recorded.
forwardX :: Ord n
         => Profile n   -- ^ Language profile.
         -> Config a n  -- ^ Decides which bindings get floated.
         -> Exp a n     -- ^ Expression to transform.
         -> TransformResult (Exp a n)

forwardX profile config xx
 = TransformResult
        { result         = x'
        , resultProgress = infoSubsts info + infoBindings info > 0
        , resultAgain    = False
        , resultInfo     = TransformInfo info }
 where  (x', info)
         = runWriter
         $ forwardWith profile config Map.empty
         $ usageX xx


-------------------------------------------------------------------------------
-- | Syntactic forms the forwarding transform can traverse.
class Forward (c :: * -> * -> *) where
 -- | Carry bindings forward and downward into their use-sites.
 --
 --   The input is annotated with usage information paired with the
 --   original annotation, as produced by 'usageModule' / 'usageX'. The
 --   result keeps only the original annotation. Counters describing the
 --   work done are written to the 'ForwardInfo' log.
 forwardWith 
        :: Ord n
        => Profile n            -- ^ Language profile.
        -> Config a n           -- ^ Decides which bindings get floated.
        -> Map n (Exp a n)      -- ^ Bindings currently being carried forward.
        -> c (UsedMap n, a) n   -- ^ Term annotated with usage information.
        -> Writer ForwardInfo (c a n)

instance Forward Module where
 -- Only the module body holds expressions. Every other field is copied
 -- across unchanged.
 forwardWith profile config bindings mm@ModuleCore{}
  = do  body'   <- forwardWith profile config bindings (moduleBody mm)
        return ModuleCore
                { moduleName            = moduleName            mm
                , moduleExportTypes     = moduleExportTypes     mm
                , moduleExportValues    = moduleExportValues    mm
                , moduleImportTypes     = moduleImportTypes     mm
                , moduleImportValues    = moduleImportValues    mm
                , moduleDataDefsLocal   = moduleDataDefsLocal   mm
                , moduleBody            = body' }


-- | Forwarding on expressions.
--
--   NOTE(review): carried right-hand sides are substituted by name. Nothing
--   here checks that their free variables are not captured by binders
--   crossed on the way to the use site. This presumably relies on binder
--   names being distinct -- confirm.
instance Forward Exp where
 forwardWith profile config bindings xx
  = {-# SCC forwardWith #-}
    let down    = forwardWith profile config bindings

        -- Like 'down', but for a subterm in the scope of the given value
        -- binders. A carried binding whose name is rebound there is
        -- shadowed, so it must not be substituted inside that subterm.
        downUnder bs
                = forwardWith profile config (foldr unbind bindings bs)

        unbind (BName n _) m    = Map.delete n m
        unbind _           m    = m

    in case xx of
        -- Replace a use of a carried binding by its right-hand side.
        XVar a u@(UName n)
         -> case Map.lookup n bindings of
                Just xx'        -> do
                    tell mempty { infoSubsts = 1 }
                    return xx'
                Nothing         ->
                    return $ XVar (snd a) u

        XVar a u        -> return $ XVar (snd a) u
        XCon a u        -> return $ XCon (snd a) u

        -- Type binders are not value variables, so they cannot shadow a
        -- carried binding; value lambdas can.
        XLAM a b x      -> liftM    (XLAM (snd a) b) (down x)
        XLam a b x      -> liftM    (XLam (snd a) b) (downUnder [b] x)
        XApp a x1 x2    -> liftM2   (XApp (snd a))   (down x1) (down x2)

        -- Always float last let-binding into its use.
        --   let x = exp in x => exp
        -- NOTE(review): unlike the let cases below, this rewrite records
        -- nothing in 'ForwardInfo', so on its own it is not reported as
        -- progress. Confirm whether that is intended.
        XLet _ (LLet b x1) (XVar _ u)
         |  boundMatchesBind u b
         ,  configFloatLetBody config
         -> down x1

        -- Always float atomic bindings (variables, constructors)
        XLet _ (LLet b x1) x2
         | isAtomX x1
         -> do 
                -- Record that we've moved this binding.
                tell mempty { infoInspected = 1
                            , infoBindings  = 1 }

                -- Slow, but handles anonymous binders and shadowing
                down $ S.substituteXX b x1 x2

        -- A named, non-atomic binding: float it if the control function
        -- and its usage allow, otherwise keep the let.
        XLet (UsedMap um, a') lts@(LLet b@(BName n _) x1) x2
         -> do  
                let control    = configFloatControl config 
                               $ reannotate snd lts

                let isFun      = isXLam x1 || isXLAM x1

                -- Exactly one recorded use, as a function, once uses in
                -- casts are ignored.
                let isApplied
                     | Just usage       <- Map.lookup n um
                     , [UsedFunction]   <- filterUsedInCasts usage
                                        = True
                     | otherwise        = False

                let shouldFloat
                     = case control of
                        FloatDeny       -> False
                        FloatForce      -> True
                        FloatAllow      -> isFun && isApplied

                if shouldFloat 
                 then do
                        -- Record that we've moved this binding.
                        tell mempty { infoInspected = 1
                                    , infoBindings  = 1 }

                        -- Drop the let and carry the processed right-hand
                        -- side into the body, where uses of n get replaced.
                        x1'             <- down x1
                        let bindings'   = Map.insert n x1' bindings
                        forwardWith profile config bindings' x2

                 else do        
                        tell mempty { infoInspected = 1}
                        -- The kept binder scopes over the body and shadows
                        -- any carried binding with the same name.
                        liftM2 (XLet a') (down lts) (downUnder [b] x2)

        -- Remaining lets. A non-recursive let binds its name in the body,
        -- and a recursive group binds all of its names there.
        -- NOTE(review): LPrivate / LWithRegion are assumed not to bind
        -- names that an 'XVar' can refer to -- confirm.
        XLet (_, a') lts x     
         -> let bs = case lts of
                        LLet b _        -> [b]
                        LRec bxs        -> map fst bxs
                        _               -> []
            in  liftM2 (XLet a') (down lts) (downUnder bs x)

        XCase a x alts  -> liftM2 (XCase    (snd a)) (down x) (mapM down alts)
        XCast a c x     -> liftM2 (XCast    (snd a)) (down c) (down x)
        XType a t       -> return (XType    (snd a) t)
        XWitness a w    -> return (XWitness (snd a) (reannotate snd w))


-- | Drop every 'UsedInCast' entry from a usage list, keeping the other
--   entries in their original order.
filterUsedInCasts :: [Used] -> [Used]
filterUsedInCasts uses
        = [ u | u <- uses, keep u ]
 where  keep UsedInCast = False
        keep _          = True


instance Forward Cast where
 forwardWith profile config bindings cc
  = case cc of
        -- Closure weakenings are the only casts traversed for expressions.
        CastWeakenClosure xs
         -> liftM CastWeakenClosure
                  (mapM (forwardWith profile config bindings) xs)

        -- The other casts are rebuilt at the stripped annotation type.
        -- Witnesses are reannotated with 'snd'.
        CastWeakenEffect eff    -> return (CastWeakenEffect eff)
        CastPurify w            -> return (CastPurify (reannotate snd w))
        CastForget w            -> return (CastForget (reannotate snd w))
        CastBox                 -> return CastBox
        CastRun                 -> return CastRun


-- | Forwarding on let-binding groups.
--
--   Only the right-hand sides are traversed. The binders themselves are
--   kept as they are.
instance Forward Lets where
 forwardWith profile config bindings lts
  = let down    = forwardWith profile config bindings
    in case lts of
        -- A non-recursive right-hand side is not in scope of its own binder.
        LLet b x   
         -> liftM (LLet b) (down x)

        -- Every binder of a recursive group scopes over all right-hand
        -- sides. Carried bindings with the same names are therefore
        -- shadowed there and must not be substituted.
        LRec bxs        
         -> let unbind (BName n _) m    = Map.delete n m
                unbind _           m    = m
                bindingsRec     = foldr (unbind . fst) bindings bxs
                downRec         = forwardWith profile config bindingsRec
            in  liftM LRec
                 $ mapM (\(b, x) -> liftM ((,) b) (downRec x)) bxs

        LPrivate b mt bs -> return $ LPrivate b mt bs
        LWithRegion b    -> return $ LWithRegion b


-- | Forwarding on case alternatives: only the body is traversed.
--
--   NOTE(review): names bound by the pattern are not removed from the
--   carried bindings, so a carried binding with the same name would still
--   be substituted in the body -- confirm names are distinct here.
instance Forward Alt where
 forwardWith profile config bindings alt
  = case alt of
        AAlt pat body
         -> do  body'   <- forwardWith profile config bindings body
                return (AAlt pat body')