{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
-- | Extract limited nested parallelism for execution inside
-- individual kernel workgroups.
module Futhark.Pass.ExtractKernels.Intragroup
  (intraGroupParallelise)
where

import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S

import Prelude hiding (log)

import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.SOACS
import qualified Futhark.Representation.Kernels as Out
import Futhark.Representation.Kernels.Kernel
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Util (chunks)
import Futhark.Util.Log

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- On success the result is a tuple of:
--
--   * the pair @(intra_min_par, intra_avail_par)@ (both components
--     are currently the same value);
--
--   * the computed group size;
--
--   * the log accumulated while transforming the lambda body;
--
--   * prelude statements computing the number of groups, the
--     available parallelism and the group size;
--
--   * a single statement binding the pattern of the outermost
--     nesting to a group-level 'SegMap'.
--
-- The result is 'Nothing' if any of the collected parallelism sizes
-- mentions a name that is not bound in the enclosing scope
-- ("irregular parallelism").
intraGroupParallelise :: (MonadFreshNames m, LocalScope Out.Kernels m) =>
                         KernelNest -> Lambda
                      -> m (Maybe ((SubExp, SubExp), SubExp, Log,
                                   Out.Stms Out.Kernels, Out.Stms Out.Kernels))
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of all dimensions of the
  -- flattened nest.
  (num_groups, w_stms) <- lift $ runBinder $
    letSubExp "intra_num_groups" =<<
    foldBinOp (Mul Int32) (intConst Int32 1) (map snd ispace)

  let body = lambdaBody lam

  -- The group size is only bound later (in the prelude), once the
  -- parallelism inside the body is known; here we just reserve a name
  -- so the body can refer to it.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $ localScope (scopeOfLParams $ lambdaParams lam) $
    intraGroupParalleliseBody intra_lvl body

  -- Every name the collected sizes depend on must be bound in the
  -- enclosing scope; otherwise give up.  The scope is a map, so query
  -- it directly instead of scanning its key list once per name.
  outside_scope <- lift askScope
  unless (all (`M.member` outside_scope) $ namesToList $ freeIn $
          wss_min ++ wss_avail) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $ runBinder $ do
    -- Like 'foldBinOp', but uses the first element as the initial
    -- value, and produces the constant 0 for an empty list.
    let foldBinOp' _    []    = eSubExp $ intConst Int32 0
        foldBinOp' bop (x:xs) = foldBinOp bop x xs
    -- Each non-empty list of sizes is collapsed to its product.
    ws_min <- mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int32)) $
              filter (not . null) wss_min
    ws_avail <- mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int32)) $
                filter (not . null) wss_avail

    -- The amount of parallelism available *in the worst case* is
    -- equal to the smallest parallel loop.
    intra_avail_par <- letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int32) ws_avail

    -- The group size is either the maximum of the minimum parallelism
    -- exploited, or the desired parallelism (bounded by the max group
    -- size) in case there is no minimum.
    letBindNames_ [group_size] =<<
      if null ws_min
      then eBinOp (SMin Int32)
           (eSubExp =<< letSubExp "max_group_size" (Op $ Out.GetSizeMax Out.SizeGroup))
           (eSubExp intra_avail_par)
      else foldBinOp' (SMax Int32) ws_min

    -- Only read those kernel inputs that the body actually mentions.
    let inputIsUsed input = kernelInputName input `nameIn` freeIn body
        used_inps = filter inputIsUsed inps

    addStms w_stms
    read_input_stms <- runBinder_ $ mapM readKernelInput used_inps
    space <- mkSegSpace ispace
    return (intra_avail_par, space, read_input_stms)

  -- The input reads go in front of the transformed body statements.
  let kbody' = kbody { kernelBodyStms = read_input_stms <> kernelBodyStms kbody }

  let nested_pat = loopNestingPattern first_nest
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm = Let nested_pat (StmAux cs ()) $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  -- The minimum parallelism reported to the caller is currently just
  -- the available parallelism.
  let intra_min_par = intra_avail_par
  return ((intra_min_par, intra_avail_par), Var group_size, log,
           prelude_stms, oneStm kstm)
  where first_nest = fst knest
        cs = loopNestingCertificates first_nest

-- | Writer output collected while transforming a kernel body.
data Acc = Acc
  { accMinPar :: S.Set [SubExp]
    -- ^ Size lists recorded as minimum parallelism; each element is
    -- the list of sizes reported for one parallel construct.
  , accAvailPar :: S.Set [SubExp]
    -- ^ Size lists recorded as available parallelism, in the same
    -- format as 'accMinPar'.
  , accLog :: Log
    -- ^ Log messages emitted during the transformation.
  }

-- | Accumulators are combined field by field.
instance Semigroup Acc where
  Acc min_l avail_l log_l <> Acc min_r avail_r log_r =
    Acc { accMinPar = min_l <> min_r
        , accAvailPar = avail_l <> avail_r
        , accLog = log_l <> log_r
        }

-- | The empty accumulator: no recorded parallelism and an empty log.
instance Monoid Acc where
  mempty = Acc { accMinPar = S.empty
               , accAvailPar = S.empty
               , accLog = mempty
               }

-- | The monad in which kernel bodies are transformed: a binder for
-- the Kernels representation on top of an 'RWS' with no reader
-- environment, an 'Acc' as writer output, and a 'VNameSource' as
-- state.
type IntraGroupM =
  BinderT Out.Kernels (RWS () Acc VNameSource)

-- | Log entries are emitted as writer output, stored in the 'accLog'
-- field of an otherwise empty 'Acc'.
instance MonadLogger IntraGroupM where
  addLog entries =
    tell $ Acc { accMinPar = mempty
               , accAvailPar = mempty
               , accLog = entries
               }

-- | Run an 'IntraGroupM' action in the current scope, threading the
-- name source of the surrounding monad through it.  Returns the
-- accumulated writer output together with the statements produced by
-- the action.
runIntraGroupM :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                  IntraGroupM () -> m (Acc, Out.Stms Out.Kernels)
runIntraGroupM action = do
  kernels_scope <- fmap castScope askScope
  modifyNameSource (runFrom kernels_scope)
  where
    -- Run the action from the given name source, and hand back the
    -- result alongside the updated name source.
    runFrom sc names =
      let (((), stms), names', acc) = runRWS (runBinderT action sc) () names
      in ((acc, stms), names')

-- | Record a list of sizes as both minimum and available parallelism.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell $ Acc { accMinPar = recorded
             , accAvailPar = recorded
             , accLog = mempty
             }
  where recorded = S.singleton ws

-- | Transform every statement of a SOACS body and rebuild a Kernels
-- body from the collected statements, keeping the original result.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.Kernels)
intraGroupBody lvl body =
  rebuild <$> collectStms_ (mapM_ (intraGroupStm lvl) (bodyStms body))
  where rebuild kstms = mkBody kstms (bodyResult body)

-- | Translate a single SOACS statement into Kernels statements, which
-- are added to the current binder.  Maps, scans, single reductions,
-- sequential streams and scatters are handled specially; loops and
-- branches have their bodies transformed recursively; everything
-- else is converted directly with 'soacsStmToKernels'.  The sizes of
-- the parallel constructs encountered are recorded via 'parallelMin'.
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm lvl stm@(Let pat aux e) = do
  scope <- askScope
  -- A thread-level 'SegLevel' with the same number of groups and
  -- group size as 'lvl'.
  let lvl' = SegThread (segNumGroups lvl) (segGroupSize lvl) SegNoVirt

  case e of
    -- Loops stay loops; only the loop body is transformed, with the
    -- loop form and loop parameters brought into scope.
    DoLoop ctx val form loopbody ->
      localScope (scopeOf form') $
      localScope (scopeOfFParams $ map fst $ ctx ++ val) $ do
      loopbody' <- intraGroupBody lvl loopbody
      certifying (stmAuxCerts aux) $
        letBind_ pat $ DoLoop ctx val form' loopbody'
          -- The loop form is rebuilt constructor by constructor --
          -- presumably to retype it for the Kernels representation.
          where form' = case form of
                          ForLoop i it bound inps -> ForLoop i it bound inps
                          WhileLoop cond          -> WhileLoop cond

    -- Both branches are transformed recursively.
    If cond tbody fbody ifattr -> do
      tbody' <- intraGroupBody lvl tbody
      fbody' <- intraGroupBody lvl fbody
      certifying (stmAuxCerts aux) $
        letBind_ pat $ If cond tbody' fbody' ifattr

    -- Maps: distribute the statements of the map body, treating this
    -- map as a single nesting.
    Op (Screma w form arrs)
      | Just lam <- isMapSOAC form -> do
      let loopnest = MapNesting pat (stmAuxCerts aux) w $ zip (lambdaParams lam) arrs
          env = DistEnv { distNest =
                            singleNesting $ Nesting mempty loopnest
                        , distScope =
                            scopeOfPattern pat <>
                            scopeForKernels (scopeOf lam) <> scope
                        , distOnInnerMap =
                            distributeMap
                        , distOnTopLevelStms =
                            lift . collectStms_ . intraGroupStms lvl
                          -- Record the sizes passed to the callback,
                          -- and always answer with 'lvl'.
                        , distSegLevel = \minw _ _ -> do
                            lift $ parallelMin minw
                            return lvl
                        }
          acc = DistAcc { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam)
                        , distStms = mempty
                        }

      addStms =<<
        runDistNestT env (distributeMapBodyStms acc (bodyStms $ lambdaBody lam))

    -- Scans (possibly fused with a map) become a segmented scan at
    -- thread level, contributing 'w' to the recorded parallelism.
    Op (Screma w form arrs)
      | Just (scanfun, nes, mapfun) <- isScanomapSOAC form -> do
      let scanfun' = soacsLambdaToKernels scanfun
          mapfun' = soacsLambdaToKernels mapfun
      certifying (stmAuxCerts aux) $
        addStms =<< segScan lvl' pat w scanfun' mapfun' nes arrs [] []
      parallelMin [w]

    -- Reductions (possibly fused with a map) with a single reduction
    -- operator become a segmented reduction at thread level,
    -- contributing 'w' to the recorded parallelism.
    Op (Screma w form arrs)
      | Just (reds, map_lam) <- isRedomapSOAC form,
        Reduce comm red_lam nes <- singleReduce reds -> do
      let red_lam' = soacsLambdaToKernels red_lam
          map_lam' = soacsLambdaToKernels map_lam
      certifying (stmAuxCerts aux) $
        addStms =<< segRed lvl' pat w [SegRedOp comm red_lam' nes mempty] map_lam' arrs [] []
      parallelMin [w]

    -- Sequential streams are first expanded with
    -- 'sequentialStreamWholeArray', and the resulting statements are
    -- then transformed recursively.
    Op (Stream w (Sequential accs) lam arrs)
      | chunk_size_param : _ <- lambdaParams lam -> do
      types <- asksScope castScope
      ((), stream_bnds) <-
        runBinderT (sequentialStreamWholeArray pat w accs lam arrs) types
      -- In the sizes recorded by the recursive calls, replace the
      -- chunk size parameter of the stream lambda by the full width
      -- 'w'.
      let replace (Var v) | v == paramName chunk_size_param = w
          replace se = se
          replaceSets (Acc x y log) =
            Acc (S.map (map replace) x) (S.map (map replace) y) log
      censor replaceSets $ mapM_ (intraGroupStm lvl) stream_bnds

    -- Scatters become a thread-level 'SegMap' over a one-dimensional
    -- space of size 'w', returning its results as 'WriteReturns'.
    Op (Scatter w lam ivs dests) -> do
      write_i <- newVName "write_i"
      space <- mkSegSpace [(write_i, w)]

      let lam' = soacsLambdaToKernels lam
          (dests_ws, dests_ns, dests_vs) = unzip3 dests
          -- The lambda result is split into index results followed by
          -- value results, which are then paired up and chunked per
          -- destination array.
          (i_res, v_res) = splitAt (sum dests_ns) $ bodyResult $ lambdaBody lam'
          krets = do (a_w, a, is_vs) <- zip3 dests_ws dests_vs $ chunks dests_ns $ zip i_res v_res
                     return $ WriteReturns [a_w] a [ ([i],v) | (i,v) <- is_vs ]
          -- Each lambda parameter is read from its input array at
          -- index 'write_i'.
          inputs = do (p, p_a) <- zip (lambdaParams lam') ivs
                      return $ KernelInput (paramName p) (paramType p) p_a [Var write_i]

      kstms <- runBinder_ $ localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inputs
        addStms $ bodyStms $ lambdaBody lam'

      certifying (stmAuxCerts aux) $ do
        let ts = map rowType $ patternTypes pat
            body = KernelBody () kstms krets
        letBind_ pat $ Op $ SegOp $ SegMap lvl' space ts body

      parallelMin [w]

    -- Anything else is converted directly, without recording any
    -- parallelism.
    _ ->
      addStm $ soacsStmToKernels stm

-- | Transform a sequence of SOACS statements, one at a time and in
-- order, with 'intraGroupStm'.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl stms = forM_ stms $ intraGroupStm lvl

-- | Transform the statements of a SOACS body into a kernel body whose
-- results are plain 'Returns' of the original body results.  Also
-- yields the recorded minimum and available parallelism (as lists of
-- size lists) and the accumulated log.
intraGroupParalleliseBody :: (MonadFreshNames m, HasScope Out.Kernels m) =>
                             SegLevel -> Body
                          -> m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.Kernels)
intraGroupParalleliseBody lvl body = do
  (Acc min_par avail_par msgs, stms) <-
    runIntraGroupM (intraGroupStms lvl (bodyStms body))
  let results = map Returns (bodyResult body)
      kbody = KernelBody () stms results
  return (S.toList min_par, S.toList avail_par, msgs, kbody)