{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module SimplStg ( stg2stg ) where
#include "HsVersions.h"
import GhcPrelude
import StgSyn
import StgLint          ( lintStgTopBindings )
import StgStats         ( showStgStats )
import UnariseStg       ( unarise )
import StgCse           ( stgCse )
import StgLiftLams      ( stgLiftLams )
import Module           ( Module )
import DynFlags
import ErrUtils
import UniqSupply
import Outputable
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.State.Strict
-- | Monad in which the STG-to-STG passes run: 'IO' with a 'UniqSupply'
-- threaded through as state. The 'MonadIO' instance lets passes perform
-- dumping and linting via 'liftIO'.
newtype StgM a = StgM { _unStgM :: StateT UniqSupply IO a }
  deriving (Functor, Applicative, Monad, MonadIO)
-- | Uniques are drawn from the 'UniqSupply' held in the underlying state.
-- In both methods 'state' hands the first component of the function's
-- result to the caller and stores the second component as the new state.
instance MonadUnique StgM where
  -- Split the stored supply with 'splitUniqSupply'.
  getUniqueSupplyM = StgM $ state splitUniqSupply

  -- Take a single unique off the stored supply with 'takeUniqFromSupply'.
  getUniqueM = StgM $ state takeUniqFromSupply
-- | Run an 'StgM' computation in 'IO', starting from the given unique
-- supply. Only the computation's result is returned; the final supply is
-- discarded ('evalStateT').
runStgM :: UniqSupply -> StgM a -> IO a
runStgM initial_supply (StgM action) = action `evalStateT` initial_supply
-- | Run the STG-to-STG pipeline over a list of top-level bindings.
--
-- The passes chosen by 'getStgToDo' are applied from left to right, each one
-- receiving the output of the previous one. The final bindings are dumped
-- under @Opt_D_dump_stg@ and then returned.
stg2stg :: DynFlags                  -- ^ flags consulted for dumping, linting and pass selection
        -> Module                    -- ^ handed to the STG linter (presumably the module being compiled)
        -> [StgTopBinding]           -- ^ bindings to transform
        -> IO [StgTopBinding]        -- ^ bindings after every selected pass has run
stg2stg dflags this_mod binds = do
    showPass dflags "Stg2Stg"
    -- One supply for the whole pipeline; passes that need uniques split
    -- their share off it through 'getUniqueSupplyM'.
    supply <- mkSplitUniqSupply 'g'
    final_binds <- runStgM supply (foldM do_stg_pass binds (getStgToDo dflags))
    dump_when Opt_D_dump_stg "STG syntax:" final_binds
    return final_binds
  where
    -- Lint the given bindings when Opt_DoStgLinting is set; otherwise do
    -- nothing. The Bool is forwarded to 'lintStgTopBindings': callers pass
    -- False before unarisation and True right after it, so it presumably
    -- tells the linter whether unarisation has run -- confirm in StgLint.
    stg_linter :: Bool -> String -> [StgTopBinding] -> IO ()
    stg_linter unarised whodunnit stg =
      when (gopt Opt_DoStgLinting dflags) $
        lintStgTopBindings dflags this_mod unarised whodunnit stg

    -- Run a single pass over the bindings.
    do_stg_pass :: [StgTopBinding] -> StgToDo -> StgM [StgTopBinding]
    do_stg_pass stg StgDoNothing =
      return stg
    do_stg_pass stg StgStats =
      -- The statistics are reported through 'trace'; the bindings pass
      -- through unchanged.
      trace (showStgStats stg) (return stg)
    do_stg_pass stg StgCSE =
      let cse_binds = {-# SCC "StgCse" #-} stgCse stg
      in end_pass "StgCse" cse_binds
    do_stg_pass stg StgLiftLams = do
      lift_us <- getUniqueSupplyM
      let lifted_binds = {-# SCC "StgLiftLams" #-} stgLiftLams dflags lift_us stg
      end_pass "StgLiftLams" lifted_binds
    do_stg_pass stg StgUnarise = do
      liftIO $ dump_when Opt_D_dump_stg "Pre unarise:" stg
      unarise_us <- getUniqueSupplyM
      -- Lint both the input and the output of unarisation.
      liftIO $ stg_linter False "Pre-unarise" stg
      let unarised_binds = unarise unarise_us stg
      liftIO $ stg_linter True "Unarise" unarised_binds
      return unarised_binds

    -- Pretty-print the bindings under the given header if the dump flag
    -- is set.
    dump_when flag header stg =
      dumpIfSet_dyn dflags flag header (pprStgTopBindings stg)

    -- Shared epilogue for the optional passes: verbose dump, lint, and hand
    -- the bindings back.
    -- NOTE(review): 'getStgToDo' schedules StgUnarise first, so the callers
    -- of end_pass run on unarised code, yet the linter is given False here
    -- while the post-unarise lint above is given True -- confirm intended.
    end_pass :: String -> [StgTopBinding] -> StgM [StgTopBinding]
    end_pass pass_name result = do
      liftIO $ dumpIfSet_dyn dflags Opt_D_verbose_stg2stg pass_name
                 (vcat (map ppr result))
      liftIO $ stg_linter False pass_name result
      return result
-- | The passes that 'stg2stg' knows how to run; 'getStgToDo' selects and
-- orders them.
data StgToDo
  = StgCSE
  -- ^ Run 'stgCse' over the bindings.
  | StgLiftLams
  -- ^ Run 'stgLiftLams' over the bindings, using a fresh unique supply.
  --
  | StgStats
  -- ^ Report 'showStgStats' for the bindings; the bindings are unchanged.
  | StgUnarise
  -- ^ Run 'unarise' over the bindings; always scheduled by 'getStgToDo'.
  | StgDoNothing
  -- ^ Leave the bindings alone. Produced by 'runWhen' for a disabled pass
  -- and filtered out again by 'getStgToDo' (hence the 'Eq' instance).
  deriving Eq
-- | Work out which STG-to-STG passes to run for the given flags, in the
-- order they are to be executed.
--
-- 'StgUnarise' is always present and listed first, so it runs before any
-- of the flag-controlled passes. Passes whose flag is off come back from
-- 'runWhen' as 'StgDoNothing' and are dropped from the result.
getStgToDo :: DynFlags -> [StgToDo]
getStgToDo dflags = [ pass | pass <- candidates, pass /= StgDoNothing ]
  where
    candidates =
      [ StgUnarise
      , whenEnabled Opt_StgCSE StgCSE
      , whenEnabled Opt_StgLiftLams StgLiftLams
      , whenEnabled Opt_StgStats StgStats
      ]

    -- Keep the pass only if its general flag is set.
    whenEnabled flag = runWhen (gopt flag dflags)
-- | Keep the given pass when the condition holds; otherwise replace it with
-- 'StgDoNothing'.
runWhen :: Bool -> StgToDo -> StgToDo
runWhen enabled todo
  | enabled   = todo
  | otherwise = StgDoNothing