{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Segmented scan.
module Futhark.Pass.ExtractKernels.Segmented
       ( regularSegmentedScan
       )
       where

import Control.Monad
import qualified Data.Map.Strict as M

import Futhark.Transform.Rename
import Futhark.Representation.Kernels
import Futhark.Representation.SOACS.SOAC (nilFn)
import Futhark.MonadFreshNames
import Futhark.Tools hiding (true, false)
import Futhark.Pass.ExtractKernels.BlockedKernel

-- | Lift a scan operator to one that also threads a boolean flag.
--
-- The parameters of @lam@ are split after @length nes@ entries into a
-- left-hand group and a right-hand group (presumably the two operands
-- of the operator; verify against callers).  The resulting lambda takes
-- @x_flag : left ++ y_flag : right@ and returns a flag followed by the
-- results of @lam@:
--
-- * the returned flag is the logical OR of the two incoming flags;
--
-- * when @y_flag@ is set, the left-hand operands are replaced by the
--   neutral elements @nes@ before @lam@ is applied, otherwise they are
--   passed through unchanged.
addFlagToLambda :: (MonadBinder m, Lore m ~ Kernels) =>
                   [SubExp] -> Lambda InKernel -> m (Lambda InKernel)
addFlagToLambda nes lam = do
  x_flag <- newVName "x_flag"
  y_flag <- newVName "y_flag"
  let boolParam v = Param v (Prim Bool)
      paramVars = map (Var . paramName)
      (left_params, right_params) = splitAt (length nes) (lambdaParams lam)
      all_params =
        boolParam x_flag : (left_params ++ (boolParam y_flag : right_params))

  flagged_body <- runBodyBinder $ localScope (scopeOfLParams all_params) $ do
    -- The combined flag is set if either operand carried a flag.
    flag_out <- letSubExp "new_flag" $
                BasicOp $ BinOp LogOr (Var x_flag) (Var y_flag)

    -- A set right-hand flag discards the left-hand operands in favour
    -- of the neutral elements.
    left_names <- letTupExp "seg_lhs" $
                  If (Var y_flag)
                  (resultBody nes)
                  (resultBody $ paramVars left_params)
                  (ifCommon $ map paramType left_params)

    -- Rename before inlining the operator body to avoid shadowing.
    fresh_lam <- renameLambda lam
    results <- eLambda fresh_lam $
               map eSubExp $ map Var left_names ++ paramVars right_params

    return $ resultBody $ flag_out : results

  return Lambda { lambdaParams = all_params
                , lambdaBody = flagged_body
                , lambdaReturnType = Prim Bool : lambdaReturnType lam
                }

-- | Emit statements for a scan over segments of a fixed size, built on
-- top of 'blockedScan'.
--
-- The steps are:
--
-- 1. A map kernel over @w@ threads computes a boolean @flags@ array
--    whose element @i@ is true exactly when @i `srem` segment_size == 0@.
--
-- 2. The scan operator @lam@ is extended with 'addFlagToLambda', and a
--    false flag is prepended to the neutral elements @nes@.
--
-- 3. @map_lam@ is extended with a leading boolean parameter that is
--    returned unchanged as its first result.
--
-- 4. The pattern @pat@ gets an extra leading boolean array of shape
--    @[w]@ (bound to a fresh name this function does not reference again).
--
-- 5. 'blockedScan' is invoked with the extended pieces and with @flags@
--    prepended to @arrs@; @segment_size@, @ispace@ and @inps@ are
--    forwarded unchanged, and the value it returns is discarded.
regularSegmentedScan :: (MonadBinder m, Lore m ~ Kernels) =>
                        SubExp
                     -> Pattern Kernels
                     -> SubExp
                     -> Lambda InKernel
                     -> Lambda InKernel
                     -> [(VName, SubExp)] -> [KernelInput]
                     -> [SubExp] -> [VName]
                     -> m ()
regularSegmentedScan segment_size pat w lam map_lam ispace inps nes arrs = do
  flags_i <- newVName "flags_i"
  unused_flag_array <- newVName "unused_flag_array"

  -- Body of the flag kernel: thread i yields (i `srem` segment_size) == 0.
  flags_body <-
    runBodyBinder $ localScope (M.singleton flags_i $ IndexInfo Int32) $ do
      pos_in_segment <-
        letSubExp "segment_index" $
        BasicOp $ BinOp (SRem Int32) (Var flags_i) segment_size
      is_start <-
        letSubExp "start_of_segment" $
        BasicOp $ CmpOp (CmpEq int32) pos_in_segment (constant (0 :: Int32))
      return $ resultBody [is_start]

  (flags_stms, flags_kernel) <-
    mapKernelFromBody w (FlatThreadSpace [(flags_i, w)]) [] [Prim Bool] flags_body
  addStms flags_stms
  flags <- letExp "flags" $ Op flags_kernel

  flagged_lam <- addFlagToLambda nes lam

  -- The map function passes the flag straight through as its first result.
  flag_p <- newParam "flag" $ Prim Bool
  let map_body = lambdaBody map_lam
      flagged_map_lam =
        map_lam { lambdaParams = flag_p : lambdaParams map_lam
                , lambdaBody =
                    map_body { bodyResult =
                                 Var (paramName flag_p) : bodyResult map_body }
                , lambdaReturnType = Prim Bool : lambdaReturnType map_lam
                }

      -- Extra pattern element that receives the flag output of the scan.
      flags_pat_elem =
        PatElem unused_flag_array $
        arrayOf (Prim Bool) (Shape [w]) NoUniqueness
      flagged_pat =
        pat { patternValueElements = flags_pat_elem : patternValueElements pat }

  void $ blockedScan flagged_pat w
    (flagged_lam, constant False : nes)
    (Commutative, nilFn, mempty)
    flagged_map_lam segment_size ispace inps (flags : arrs)