module DDC.Core.Flow.Transform.Schedule.Scalar
        (scheduleScalar)
where
import DDC.Core.Flow.Transform.Schedule.Nest
import DDC.Core.Flow.Transform.Schedule.Error
import DDC.Core.Flow.Transform.Schedule.Base
import DDC.Core.Flow.Procedure
import DDC.Core.Flow.Process
import DDC.Core.Flow.Compounds
import DDC.Core.Flow.Prim
import DDC.Core.Flow.Exp
import Control.Monad


-- | Schedule a process into a procedure, producing scalar code.
--
--   Every series parameter of the process must have the same rate, and the
--   first type parameter must be a rate-kinded variable naming that rate.
--   Fails with 'ErrorNoRateParameters' when there are no type parameters,
--   and with 'ErrorPrimaryRateMismatch' when the first one does not match.
--
--   The result is built from a single loop at the process rate, whose body
--   initially reads the current element of every series parameter. The
--   process contexts are then inserted, and the operators scheduled in order.
scheduleScalar :: Process -> Either Error Procedure
scheduleScalar
       (Process { processName           = name
                , processParamTypes     = bsParamTypes
                , processParamValues    = bsParamValues
                , processOperators      = operators
                , processContexts       = contexts})
  = do
        -- The rate shared by all series parameters.
        tK      <- slurpRateOfParamTypes
                $  filter isSeriesType
                $  map typeOfBind bsParamValues

        -- The primary rate parameter must name that same rate.
        checkPrimaryRate tK bsParamTypes

        -- One body statement per series parameter, pulling out the element
        -- at the current loop index.
        let bsSeries    = filter (isSeriesType . typeOfBind) bsParamValues
        let ssBody      = concatMap (nextOfSeries tK) bsSeries

        -- The initial loop nest: one loop at the process rate that holds
        -- only the element reads.
        let nest0       = NestLoop
                        { nestRate      = tK
                        , nestStart     = []
                        , nestBody      = ssBody
                        , nestInner     = NestEmpty
                        , nestEnd       = []
                        , nestResult    = xUnit }

        -- Insert the nested contexts.
        -- NOTE(review): this lazy 'Just' pattern means a failed context
        -- insertion surfaces as an irrefutable-pattern crash when nest1 is
        -- later forced, not as a 'Left'.
        let Just nest1  = foldM insertContext nest0 contexts

        -- Schedule the series operators into the nest, left to right.
        nest2           <- foldM scheduleOperator nest1 operators

        return  $ Procedure
                { procedureName         = name
                , procedureParamTypes   = bsParamTypes
                , procedureParamValues  = bsParamValues
                , procedureNest         = nest2 }

 where
        -- Check that the first type parameter is a rate variable equal
        -- to the rate of the series parameters.
        checkPrimaryRate tKey bsTypes
         = case bsTypes of
                []      -> Left ErrorNoRateParameters

                BName n k : _
                 | k == kRate
                 , TVar (UName n) == tKey
                        -> Right ()

                _       -> Left ErrorPrimaryRateMismatch

        -- Body statement that binds the current element of a named series
        -- parameter. Binders without a name contribute nothing.
        nextOfSeries tKey bSeries@(BName nSeries tSeries)
         = [ BodyStmt bElem
                (xNext tKey tElem (XVar (UName nSeries)) (XVar (UIx 0))) ]
         where  Just tElem = elemTypeOfSeriesType tSeries
                Just bElem = elemBindOfSeriesBind bSeries

        nextOfSeries _ _
         = []


-------------------------------------------------------------------------------
-- | Schedule a single series operator into a loop nest.
--
--   Each supported operator adds statements to the nest at its rate:
--   start statements (inserted with 'insertStarts'), per-element body
--   statements ('insertBody'), and end statements ('insertEnds').
--
--   Returns 'ErrorUnsupported' for operators this scheduler does not handle.
--   It also returns 'ErrorUnsupported' when an element binder cannot be
--   derived, or a nest insertion yields 'Nothing'.
scheduleOperator 
        :: Nest         -- ^ The current loop nest.
        -> Operator     -- ^ Operator to schedule.
        -> Either Error Nest

scheduleOperator nest0 op

 -- Id -------------------------------------------
 | OpId{}     <- op
 = do   let tK          = opInputRate op

        -- Get binders for the input and result elements.
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries  op)

        -- The result element is the input element.
        orFail  $ insertBody nest0 tK
                $ [ BodyStmt bResult (XVar uInput) ]

 -- Rep -----------------------------------------
 | OpRep{}      <- op
 = do   let tK          = opOutputRate op

        -- Make a binder for the replicated element.
        let BName nResult _ = opResultSeries op
        let nVal        = NameVarMod nResult "val"
        let uVal        = UName nVal
        let bVal        = BName nVal (opElemType op)

        -- Get the binder for the use of it in the replicated context.
        bResult <- orFail $ elemBindOfSeriesBind (opResultSeries op)

        -- Evaluate the expression to be replicated once, 
        -- before the main loop.
        nest1   <- orFail
                $  insertStarts nest0 tK
                $  [ StartStmt bVal (opInputExp op) ]

        -- Use the expression for each iteration of the loop.
        orFail  $ insertBody nest1 tK
                $ [ BodyStmt bResult (XVar uVal) ]

 -- Reps ----------------------------------------
 | OpReps{}     <- op
 = do   -- Lookup binder for the input element.
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)

        -- Set the result to point to the input element.
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)

        orFail  $ insertBody nest0 (opOutputRate op)
                $ [ BodyStmt    bResult
                                (XVar uInput)]

 -- Indices --------------------------------------
 | OpIndices{}  <- op
 = do   -- In a segment context the variable ^1 is the index into
        -- the current segment.
        bResult <- orFail $ elemBindOfSeriesBind   (opResultSeries op)

        orFail  $ insertBody nest0 (opOutputRate op)
                $ [ BodyStmt    bResult
                                (XVar (UIx 1)) ]

 -- Fill -----------------------------------------
 | OpFill{} <- op
 = do   let tK          = opInputRate op

        -- Get bound of the input element.
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)

        -- Write the current element to the vector.
        let UName nVec  = opTargetVector op
        nest1   <- orFail
                $  insertBody nest0 tK 
                $  [ BodyVecWrite 
                        nVec                    -- destination vector
                        (opElemType op)         -- series elem type
                        (XVar (UIx 0))          -- index
                        (XVar uInput) ]         -- value

        -- If the length of the vector corresponds to a guarded rate then it
        -- was constructed in a filter context. After the process completes, 
        -- we know how many elements were written so we can truncate the
        -- vector down to its final length.
        if nestContainsGuardedRate nest1 tK
         then orFail
                $ insertEnds nest1 tK
                $ [ EndVecTrunc 
                        nVec                    -- destination vector
                        (opElemType op)         -- series element type
                        tK ]                    -- rate of source series
         else return nest1

 -- Gather ---------------------------------------
 | OpGather{} <- op
 = do   let tK          = opInputRate op

        -- Bind for result element.
        bResult <- orFail $ elemBindOfSeriesBind   (opResultBind op)

        -- Bound of source index.
        uIndex  <- orFail $ elemBoundOfSeriesBound (opSourceIndices op)

        -- Read from the vector.
        orFail  $ insertBody nest0 tK
                $ [ BodyStmt bResult
                        (xReadVector 
                                (opElemType op)
                                (XVar $ opSourceVector op)
                                (XVar $ uIndex)) ]
 
 -- Scatter --------------------------------------
 | OpScatter{} <- op
 = do   let tK          = opInputRate op

        -- Bound of source index.
        uIndex  <- orFail $ elemBoundOfSeriesBound (opSourceIndices op)

        -- Bound of source elements.
        uElem   <- orFail $ elemBoundOfSeriesBound (opSourceElems op)

        -- Write each element into the target vector at its index.
        nest1   <- orFail
                $  insertBody nest0 tK
                $  [ BodyStmt (BNone tUnit)
                        (xWriteVector
                                (opElemType op)
                                (XVar $ opTargetVector op)
                                (XVar $ uIndex) (XVar $ uElem)) ]

        -- Bind final unit value.
        orFail  $ insertEnds nest1 tK
                $ [ EndStmt     (opResultBind op)
                                xUnit ]

 -- Maps -----------------------------------------
 | OpMap{} <- op
 = do   let tK          = opInputRate op

        -- Bind for the result element.
        bResult <- orFail $ elemBindOfSeriesBind (opResultSeries op)

        -- Bounds for all the input elements.
        usInput <- orFail $ mapM elemBoundOfSeriesBound (opInputSeriess op)

        -- Apply input element vars into the worker body, one lambda and
        -- application per worker parameter.
        -- NOTE(review): 'zip' stops at the shorter list; the number of
        -- worker params is assumed to equal the number of inputs.
        let xBody       
                = foldl (\x (b, u) -> XApp (XLam b x) (XVar u))
                        (opWorkerBody op)
                        (zip (opWorkerParams op) usInput)

        orFail  $ insertBody nest0 tK
                $ [ BodyStmt bResult xBody ]

 -- Pack ----------------------------------------
 | OpPack{}     <- op
 = do   -- Lookup binder for the input element.
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)

        -- Set the result to point to the input element.
        bResult <- orFail $ elemBindOfSeriesBind  (opResultSeries op)

        orFail  $ insertBody nest0 (opOutputRate op)
                $ [ BodyStmt    bResult
                                (XVar uInput)]

 -- Reduce --------------------------------------
 | OpReduce{} <- op
 = do   let tK          = opInputRate op

        -- Initialize the accumulator from the current value of the
        -- target reference.
        let UName nResult = opTargetRef op
        let nAcc          = NameVarMod nResult "acc"
        let tAcc          = typeOfBind (opWorkerParamAcc op)
        let nAccInit      = NameVarMod nResult "init"

        nest1   <- orFail
                $  insertStarts nest0 tK
                $  [ StartStmt (BName nAccInit tAcc)
                               (xRead tAcc (XVar $ opTargetRef op))
                   , StartAcc   nAcc tAcc (XVar (UName nAccInit)) ]

        -- Lookup binders for the input elements.
        uInput  <- orFail $ elemBoundOfSeriesBound (opInputSeries op)

        -- Bind for intermediate accumulator value.
        let nAccVal     = NameVarMod nResult "val"
        let uAccVal     = UName nAccVal
        let bAccVal     = BName nAccVal tAcc

        -- Substitute input and accumulator vars into worker body.
        let xBody x1 x2
                = XApp  (XApp   ( XLam (opWorkerParamAcc   op)
                                      $ XLam (opWorkerParamElem  op)
                                             (opWorkerBody op))
                                x1)
                        x2

        -- Update the accumulator in the loop body.
        nest2   <- orFail
                $  insertBody nest1 tK
                $  [ BodyAccRead  nAcc tAcc bAccVal
                   , BodyAccWrite nAcc tAcc 
                        (xBody  (XVar uAccVal) 
                                (XVar uInput)) ]

        -- Read back the final value after the loop has finished and
        -- write it to the destination.
        let nAccRes     = NameVarMod nResult "res"
        orFail  $ insertEnds nest2 tK
                $ [ EndAcc   nAccRes tAcc nAcc 
                  , EndStmt  (BNone tUnit)
                             (xWrite tAcc (XVar $ opTargetRef op)
                                          (XVar $ UName nAccRes)) ]

 -- Unsupported ----------------------------------
 | otherwise
 = Left $ ErrorUnsupported op

 where
        -- Lift a 'Maybe' result (an element binder/bound lookup, or a nest
        -- insertion) into the error monad. A 'Nothing' is reported here as
        -- a 'Left'. A lazy @let Just x = ...@ binding would instead crash
        -- with an irrefutable-pattern failure when the nest is later forced.
        -- NOTE(review): 'ErrorUnsupported' is the only operator-carrying
        -- constructor used in this module; a dedicated constructor for
        -- "could not place operator in nest" may be clearer.
        orFail :: Maybe a -> Either Error a
        orFail  = maybe (Left $ ErrorUnsupported op) Right