module DDC.Core.Flow.Lower
        ( Config        (..)
        , defaultConfigScalar
        , defaultConfigKernel
        , defaultConfigVector
        , Method        (..)
        , Lifting       (..)
        , lowerModule)
where
import DDC.Core.Flow.Transform.Slurp
import DDC.Core.Flow.Transform.Schedule
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Transform.Extract
import DDC.Core.Flow.Process
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Profile
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import DDC.Core.Module

import DDC.Core.Transform.TransformUpX
import DDC.Core.Transform.Annotate
import DDC.Core.Transform.Deannotate

import qualified DDC.Core.Simplifier                    as C
import qualified DDC.Core.Simplifier.Recipe             as C
import qualified DDC.Core.Transform.Namify              as C
import qualified DDC.Core.Transform.Snip                as Snip
import qualified DDC.Type.Env                           as Env
import qualified Control.Monad.State.Strict             as S
import qualified Data.Monoid                            as M
import Control.Monad


-- | Configuration for the lower transform.
--
--   The only setting is 'configMethod', which selects the lowering 'Method'
--   that 'lowerModule' applies to each series process it finds.
--   Ready-made values are 'defaultConfigScalar', 'defaultConfigKernel'
--   and 'defaultConfigVector'.
data Config
        = Config
        { configMethod          :: Method }
        deriving (Eq, Show)


-- | What lowering method to use.
--
--   Note that the 'methodLifting' field only exists on 'MethodKernel' and
--   'MethodVector'; 'MethodScalar' carries no lifting information, so
--   applying 'methodLifting' to it is a runtime error.
data Method
        -- | Produce sequential scalar code with nested loops.
        = MethodScalar

        -- | Produce vector kernel code that only processes an even multiple
        --   of the vector width. The 'Lifting' value is handed to the kernel
        --   scheduler.
        | MethodKernel
        { methodLifting         :: Lifting }

        -- | Try to produce sequential vector code,
        --   falling back to scalar code if this is not possible.
        --   The 'Lifting' value is handed to the kernel scheduler when
        --   building the vector part of the code.
        | MethodVector
        { methodLifting         :: Lifting }
        deriving (Eq, Show)


-- | Default configuration that selects 'MethodScalar',
--   so the generated code uses scalar operations only.
defaultConfigScalar :: Config
defaultConfigScalar = Config MethodScalar


-- | Default configuration that selects 'MethodKernel' with a lifting of 8.
--   The generated code uses vector operations, and its loops only handle
--   an amount of data that is an even multiple of the vector width.
defaultConfigKernel :: Config
defaultConfigKernel = Config (MethodKernel (Lifting 8))


-- | Default configuration that selects 'MethodVector' with a lifting of 8.
--   The generated code uses vector operations, and its loops handle
--   arbitrary data sizes, of any number of elements.
defaultConfigVector :: Config
defaultConfigVector = Config (MethodVector (Lifting 8))


-- Lower ----------------------------------------------------------------------
-- | Lower the well formed series processes defined at the top level of a
--   module into procedures.
--
--   The processes are first slurped out of the module. If that fails the
--   slurp error is wrapped in 'ErrorSlurpError' and returned. Otherwise every
--   slurped binding is lowered with 'lowerEither', the results become the
--   bindings of a single recursive let that replaces the module body, and
--   the resulting module is passed through 'cleanModule'.
lowerModule :: Config -> ModuleF -> Either Error ModuleF
lowerModule config mm
 = either (Left . ErrorSlurpError) lowerSlurped (slurpProcesses mm)
 where
  -- Lower everything that was slurped and rebuild the module around it.
  lowerSlurped procs
   = do -- Names of the bindings that are series processes.
        let procnames   = [ processName p | Left p <- procs ]

        -- Schedule each process into a procedure;
        -- ordinary bindings are passed through by 'lowerEither'.
        lets            <- mapM (lowerEither config procnames) procs

        -- Wrap the procedures into a new module body, then tidy the code.
        return  $ cleanModule
                $ mm { moduleBody = annotate () $ XLet (LRec lets) xUnit }


-- | Lower one slurped top-level binding.
--
--   A 'Left' carries a series process, which is handed to 'lowerProcess'.
--   A 'Right' carries an ordinary binding: its binder is kept as it is, and
--   its expression is rewritten so that uses of @runProcess#@ refer to
--   @runProcessUnit#@ instead, and @Process#@ types applied to two arguments
--   become the unit type.
--
--   The list of process names is accepted but not consulted.
lowerEither  :: Config -> [Name] -> (Either Process (Bind Name, Exp () Name)) -> Either Error (BindF, ExpF)
lowerEither config _procnames slurped
 = case slurped of
        Left process
         -> lowerProcess config process

        Right (b, xx)
         -> return ( b
                   , deannotate (const Nothing)
                        (transformSimpleUpX' replaceRunProc (annotate () xx)) )
 where

  -- Rewrite a single expression node,
  -- or return Nothing to leave the node as it is.
  replaceRunProc xx
   = case xx of
        -- Calls to runProcess# become calls to runProcessUnit#.
        XVar (UPrim (NameOpSeries OpSeriesRunProcess) _)
         -> Just (XVar (UPrim (NameOpSeries OpSeriesRunProcessUnit)
                              (typeOpSeries OpSeriesRunProcessUnit)))

        -- Type arguments may mention Process#.
        XType t
         -> Just (XType (replaceProcTy t))

        -- So may the types attached to let-bound binders.
        XLet (LLet bind x) e
         -> Just (XLet (LLet (replaceProcTyB bind) x) e)

        XLet (LRec bxs) e
         -> Just (XLet (LRec [ (replaceProcTyB b, x) | (b, x) <- bxs ]) e)

        _ -> Nothing

  -- Apply the type rewrite to the type carried by a binder.
  replaceProcTyB bb
   = case bb of
        BName n t       -> BName n (replaceProcTy t)
        BAnon   t       -> BAnon   (replaceProcTy t)
        BNone   t       -> BNone   (replaceProcTy t)

  -- Replace Process# a b with Unit, everywhere except inside type sums
  -- and the binders of quantifiers.
  replaceProcTy tt@TVar{}       = tt
  replaceProcTy tt@TCon{}       = tt
  replaceProcTy (TForall b t)   = TForall b (replaceProcTy t)
  replaceProcTy tt@(TApp l r)
   | Just (NameTyConFlow TyConFlowProcess, [_,_]) <- takePrimTyConApps tt
   = tUnit
   | otherwise
   = TApp (replaceProcTy l) (replaceProcTy r)
  replaceProcTy (TSum ts)       = TSum ts
 

-- | Lower a single series process into fused code.
--
--   The lowering method is taken from 'configMethod':
--
--   * 'MethodScalar' schedules the process with 'scheduleScalar'.
--
--   * 'MethodKernel' schedules the process with 'scheduleKernel'.
--
--   * 'MethodVector' builds a vector version and a scalar tail version of
--     the process and joins them with 'xSplit'. To use this method the
--     process must have exactly one parameter with an unset flag whose type
--     satisfies 'isRateNatType' (the RateNat witness for the loop rate).
--     If it does not, we fall back to scalar lowering, as promised by the
--     documentation of 'MethodVector', instead of aborting.
--
--   Errors reported by the schedulers are returned in the 'Left'.
lowerProcess :: Config -> Process -> Either Error (BindF, ExpF)
lowerProcess config process
 = case configMethod config of
        MethodScalar
         -> lowerProcessScalar process

        MethodKernel lifting
         -> do  -- Schedule process into kernel code, then extract it.
                proc    <- scheduleKernel lifting process
                return  $ extractProcedure proc

        MethodVector lifting
         -- The vector method needs a unique RateNat witness parameter.
         | [nRN] <- [ nRN | (flag, BName nRN tRN) <- processParamFlags process
                          , not flag
                          , isRateNatType tRN ]
         -> lowerProcessVector lifting nRN process

         -- No unique witness, so vector lowering is not possible.
         | otherwise
         -> lowerProcessScalar process


-- | Schedule a process into scalar code and extract the procedure.
lowerProcessScalar :: Process -> Either Error (BindF, ExpF)
lowerProcessScalar process
 = do   proc    <- scheduleScalar process
        return  $ extractProcedure proc


-- | Lower a process to a vector procedure followed by a scalar tail
--   procedure, joined together with 'xSplit'.
--
--   The 'Name' is the variable that holds the RateNat witness for the loop
--   rate of the process.
lowerProcessVector :: Lifting -> Name -> Process -> Either Error (BindF, ExpF)
lowerProcessVector lifting nRN process
 = do   let c           = liftingFactor lifting

        -- The RateNat witness
        let xRN         = XVar (UName nRN)

        let tProc       = processProcType process
        let tK          = processLoopRate process

        -- Binders of the parameters whose flag is not set.
        -- NOTE(review): these look like the value (non-type) parameters,
        -- confirm against 'processParamFlags'.
        let bsValue     = [ b | (flag, b) <- processParamFlags process
                              , not flag ]

        -----------------------------------------
        -- Create the vector version of the kernel.
        --  Vector code processes several elements per loop iteration.
        procVec         <- scheduleKernel lifting process
        let (_, xProcVec) = extractProcedure procVec

        -- For each named series parameter, a binding of its "down" version,
        -- keyed by the original binder.
        -- The 'Just' patterns are only forced for binders that passed the
        -- 'isSeriesType' test.
        let bxsDownSeries
                = [ ( bS
                    , ( BName (NameVarMod n "down")
                              (tSeries tProc (tDown c tK) tE)
                      , xDown c tProc tK tE (XVar (UIx 0)) xS))
                  | bS@(BName n tS) <- bsValue
                  , isSeriesType tS
                  , let Just tE = elemTypeOfSeriesType tS
                  , let Just uS = takeSubstBoundOfBind bS
                  , let xS      = XVar uS ]

        -- Get a value arg to give to the vector procedure.
        let getDownValArg b
                | Just (b', _)  <- lookup b bxsDownSeries
                = liftM XVar $ takeSubstBoundOfBind b'

                | otherwise
                = liftM XVar $ takeSubstBoundOfBind b

        -- NOTE(review): this lazy pattern fails at runtime if any of the
        -- binders has no substitutable bound.
        let Just xsVecValArgs
                = mapM getDownValArg bsValue

        let bRateDown
                = BAnon (tRateNat (tDown c tK))

        let xProcVec'
                = XLam bRateDown
                $ xLets [LLet b x | (_, (b, x)) <- bxsDownSeries]
                $ xApps xProcVec
                $ [XType tProc, XType tK] ++ xsVecValArgs

        -----------------------------------------
        -- Create tail version.
        --  Scalar code processes the final elements of the loop.
        procTail        <- scheduleScalar process
        let (bProcTail, xProcTail) = extractProcedure procTail

        -- Window the input series to select the tails.
        let bxsTailSeries
                = [ ( bS
                    , ( BName (NameVarMod n "tail")
                              (tSeries tProc (tTail c tK) tE)
                      , xTail c tProc tK tE (XVar (UIx 0)) xS))
                  | bS@(BName n tS) <- bsValue
                  , isSeriesType tS
                  , let Just tE = elemTypeOfSeriesType tS
                  , let Just uS = takeSubstBoundOfBind bS
                  , let xS      = XVar uS ]

        -- Window the output vectors to select the tails.
        let bxsTailVector
                = [ ( bV
                    , ( BName (NameVarMod n "tail") (tVector tE)
                      , xTailVector c tK tE (XVar (UIx 0)) xV))
                  | bV@(BName n tV) <- bsValue
                  , isVectorType tV
                  , let Just tE = elemTypeOfVectorType tV
                  , let Just uV = takeSubstBoundOfBind bV
                  , let xV      = XVar uV ]

        -- Get a value arg to give to the scalar procedure.
        let getTailValArg b
                | Just (b', _)  <- lookup b bxsTailSeries
                = liftM XVar $ takeSubstBoundOfBind b'

                | Just (b', _)  <- lookup b bxsTailVector
                = liftM XVar $ takeSubstBoundOfBind b'

                | otherwise
                = liftM XVar $ takeSubstBoundOfBind b

        -- Unlike the vector version, the arguments here follow the
        -- parameters of the scheduled tail procedure.
        let Just xsTailValArgs
                = mapM getTailValArg
                        [ b | (flag, b) <- procedureParamFlags procTail
                            , not flag ]

        let bRateTail
                = BAnon (tRateNat (tTail c tK))

        let xProcTail'
                = XLam bRateTail
                $ xLets [LLet b x | (_, (b, x)) <- bxsTailSeries]
                $ xLets [LLet b x | (_, (b, x)) <- bxsTailVector]
                $ xApps xProcTail
                $ [XType tProc, XType (tTail c tK)] ++ xsTailValArgs

        ------------------------------------------
        -- Stitch the vector and scalar versions together.
        let xBody
                = XLet (LLet   (BNone tUnit)
                               (xSplit c tK xRN xProcVec' xProcTail'))
                       xUnit

        let xProc
                = makeXLamFlags (processParamFlags process) xBody

        -- Reconstruct a binder for the whole procedure / process.
        -- It is given the type of the scalar tail procedure.
        let bProc
                = BName (processName process)
                        (typeOfBind bProcTail)

        return (bProc, xProc)
         


-- Clean ----------------------------------------------------------------------
-- | Tidy up the code produced by lowering.
--
--   Runs a simplifier that namifies binders, forwards bindings, does
--   beta-reduction so that arguments to worker functions are inlined, and
--   then snips and flattens to normalize nested applications. The sequence
--   is applied under @Fix 4@.
--   When snipping we leave lambda abstractions in place, so the worker
--   functions applied to our loop combinators aren't moved.
cleanModule :: ModuleF -> ModuleF
cleanModule mm
 = C.result
 $ S.evalState
        (C.applySimplifier profile Env.empty Env.empty (C.Fix 4 clean) mm)
        0
 where
        -- Give names to anonymous type and value binders.
        namify
         = C.Trans (C.Namify (C.makeNamifier freshT)
                             (C.makeNamifier freshX))

        -- Snip out nested applications, but preserve lambdas.
        snip
         = C.Trans (C.Snip (Snip.configZero { Snip.configPreserveLambdas = True }))

        -- The whole clean up pass, in the order it is applied.
        clean
         =    namify
         M.<> C.Trans C.Forward
         M.<> C.beta
         M.<> snip
         M.<> C.Trans C.Flatten