{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Futhark.Pass.ExtractKernels.Distribution
       (
         Target
       , Targets
       , ppTargets
       , singleTarget
       , outerTarget
       , innerTarget
       , pushInnerTarget
       , popInnerTarget
       , targetsScope

       , LoopNesting (..)
       , ppLoopNesting

       , Nesting (..)
       , Nestings
       , ppNestings
       , letBindInInnerNesting
       , singleNesting
       , pushInnerNesting

       , KernelNest
       , ppKernelNest
       , newKernel
       , pushKernelNesting
       , pushInnerKernelNesting
       , removeArraysFromNest
       , kernelNestLoops
       , kernelNestWidths
       , boundInKernelNest
       , boundInKernelNests
       , flatKernel
       , constructKernel

       , tryDistribute
       , tryDistributeStm
       )
       where

import Control.Monad.RWS.Strict
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Foldable
import Data.Maybe
import Data.List

import Futhark.Representation.Kernels
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Util
import Futhark.Transform.Rename
import Futhark.Util.Log
import Futhark.Pass.ExtractKernels.BlockedKernel (mapKernel, KernelInput(..))

-- | A pattern paired with the result that is to be bound to it.
type Target = (Pattern Kernels, Result)

-- | A stack of targets.  The first field is the very innermost
-- ("current") target.  In the list, the outermost target comes first.
-- Invariant: Every element of a pattern must be present as the result
-- of the immediately enclosing target.  This is ensured by
-- 'pushInnerTarget' by removing unused pattern elements.
data Targets = Targets { _innerTarget :: Target
                       , _outerTargets :: [Target]
                       }

-- | Render every target on its own line as @pattern <- result@, with
-- the outermost target first and the innermost (current) one last.
ppTargets :: Targets -> String
ppTargets (Targets inner outers) =
  concatMap showTarget (outers ++ [inner])
  where showTarget (pat, res) = pretty pat ++ " <- " ++ pretty res ++ "\n"

-- | A target stack consisting of just one (innermost) target and no
-- enclosing targets.
singleTarget :: Target -> Targets
singleTarget t = Targets t []

-- | The outermost target.  If there are no enclosing targets, this is
-- the innermost target.
outerTarget :: Targets -> Target
outerTarget (Targets inner outers) = fromMaybe inner (listToMaybe outers)

-- | The innermost ("current") target.
innerTarget :: Targets -> Target
innerTarget = _innerTarget

-- | Add a new outermost target.  The innermost target is left as is.
pushOuterTarget :: Target -> Targets -> Targets
pushOuterTarget target targets =
  targets { _outerTargets = target : _outerTargets targets }

-- | Make the given target the new innermost target.  The previous
-- innermost target becomes the last (innermost) element of the outer
-- list.  Pattern elements of the new target whose names are not free
-- in the previous innermost target's result are dropped, along with
-- their results.  This maintains the invariant documented on 'Targets'.
pushInnerTarget :: Target -> Targets -> Targets
pushInnerTarget (pat, res) (Targets old_inner outers) =
  Targets (Pattern [] kept_pes, kept_res) (outers ++ [old_inner])
  where needed = freeIn (snd old_inner)
        (kept_pes, kept_res) =
          unzip [ (pe, se)
                | (pe, se) <- zip (patternElements pat) res
                , patElemName pe `S.member` needed ]

-- | Split off the innermost target.  The innermost of the outer
-- targets is promoted to be the new innermost target.  Yields
-- 'Nothing' when there are no outer targets.
popInnerTarget :: Targets -> Maybe (Target, Targets)
popInnerTarget (Targets inner outers) =
  case splitAt (length outers - 1) outers of
    (rest, [new_inner]) -> Just (inner, Targets new_inner rest)
    _                   -> Nothing

-- | The scope of the names bound by a target's pattern.
targetScope :: Target -> Scope Kernels
targetScope target = scopeOfPattern (fst target)

-- | The combined scope of the patterns of all targets in the stack.
targetsScope :: Targets -> Scope Kernels
targetsScope (Targets inner outers) = foldMap targetScope (inner : outers)

-- | One level of a map nest.
data LoopNesting = MapNesting { loopNestingPattern :: Pattern Kernels
                                -- ^ The pattern that the result of this
                                -- nesting is bound to.
                              , loopNestingCertificates :: Certificates
                                -- ^ Certificates attached to the statement
                                -- built from this nesting.
                              , loopNestingWidth :: SubExp
                                -- ^ The width of the map.
                              , loopNestingParamsAndArrs :: [(Param Type, VName)]
                                -- ^ Each map parameter paired with the
                                -- array it takes its rows from.
                              }
                 deriving (Show)

-- | A loop nesting brings its map parameters into scope.
instance Scoped Kernels LoopNesting where
  scopeOf nest =
    scopeOfLParams [ param | (param, _) <- loopNestingParamsAndArrs nest ]

-- | Render the map parameters of a nesting followed by the arrays they
-- are drawn from, as @params <- arrays@.
ppLoopNesting :: LoopNesting -> String
ppLoopNesting nest = concat [pretty params, " <- ", pretty arrs]
  where (params, arrs) = unzip (loopNestingParamsAndArrs nest)

-- | The map parameters of a nesting, in order.
loopNestingParams :: LoopNesting -> [LParam Kernels]
loopNestingParams nest = [ param | (param, _) <- loopNestingParamsAndArrs nest ]

-- | The free names of a nesting are those of its pattern, certificates,
-- width, and parameter/array pairs.
instance FreeIn LoopNesting where
  freeIn nest =
    mconcat [ freeIn (loopNestingPattern nest)
            , freeIn (loopNestingCertificates nest)
            , freeIn (loopNestingWidth nest)
            , freeIn (loopNestingParamsAndArrs nest)
            ]

-- | A map nesting together with the names recorded as let-bound in it.
data Nesting = Nesting { nestingLetBound :: Names
                         -- ^ Names recorded as let-bound in this nesting
                         -- (extended by 'letBindInNesting').
                       , nestingLoop :: LoopNesting
                       }
             deriving (Show)

-- | Record additional names as let-bound in a nesting.
letBindInNesting :: Names -> Nesting -> Nesting
letBindInNesting newnames nesting =
  nesting { nestingLetBound = nestingLetBound nesting <> newnames }

-- | A stack of nestings.  The first pair element is the very innermost
-- ("current") nest.  In the list, the outermost nest comes first.
type Nestings = (Nesting, [Nesting])

-- | Render the loop of each nesting on its own line, with the outermost
-- first and the innermost last.
ppNestings :: Nestings -> String
ppNestings (innermost, outers) =
  unlines [ ppLoopNesting (nestingLoop n) | n <- outers ++ [innermost] ]

-- | A nesting stack containing only the given (innermost) nesting.
singleNesting :: Nesting -> Nestings
singleNesting nesting = (nesting, [])

-- | Make the given nesting the new innermost one.  The previous
-- innermost nesting becomes the last (innermost) of the outer list.
pushInnerNesting :: Nesting -> Nestings -> Nestings
pushInnerNesting new (old, outers) = (new, outers ++ [old])

-- | All names bound by a nesting: both its map parameters and the names
-- recorded as let-bound in it.
boundInNesting :: Nesting -> Names
boundInNesting (Nesting let_bound loop) =
  S.fromList [ paramName param | param <- loopNestingParams loop ] <> let_bound

-- | Record names as let-bound in the innermost nesting only.  The outer
-- nestings are left unchanged.
letBindInInnerNesting :: Names -> Nestings -> Nestings
letBindInInnerNesting names (innermost, outers) = (innermost', outers)
  where innermost' = letBindInNesting names innermost


-- | A kernel nest: a nonempty sequence of loop nestings.  Note: the
-- first element is the *outermost* nesting, and the list continues
-- inwards.  This is different from the similar types elsewhere
-- ('Targets', 'Nestings'), where the innermost comes first!
type KernelNest = (LoopNesting, [LoopNesting])

-- | Render each loop of a kernel nest on its own line, outermost first.
ppKernelNest :: KernelNest -> String
ppKernelNest = unlines . map ppLoopNesting . kernelNestLoops

-- | Add a new outermost nesting.  The previous outermost nesting becomes
-- the first element of the inner list.  The new nesting's pattern is
-- reordered with 'fixNestingPatternOrder', using the given target and
-- the pattern of the previous outermost nesting.
pushKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushKernelNesting target newnest (nest, nests) = (outer', nest : nests)
  where outer' = fixNestingPatternOrder newnest target $ loopNestingPattern nest

-- | Add a new innermost nesting at the end of the list.  The outermost
-- nesting stays where it is.  The new nesting's pattern is reordered
-- with 'fixNestingPatternOrder' against the pattern of the current
-- innermost nesting.  It is important that the 'Target' has the right
-- order (non-permuted compared to what is expected by the outer nests).
pushInnerKernelNesting :: Target -> LoopNesting -> KernelNest -> KernelNest
pushInnerKernelNesting target newnest (nest, nests) =
  (nest, nests ++ [fixed])
  where current_innermost = fromMaybe nest $ listToMaybe $ reverse nests
        fixed = fixNestingPatternOrder newnest target $
                loopNestingPattern current_innermost

-- | Reorder the value part of a nesting's pattern.  Each element is
-- paired positionally with the target's result, and the elements are
-- sorted by where that result variable appears in the given inner
-- pattern.  Results that are not variables, or are variables absent from
-- the inner pattern, get position 0.  The sort is stable.  The rebuilt
-- pattern has an empty context.
fixNestingPatternOrder :: LoopNesting -> Target -> Pattern Kernels -> LoopNesting
fixNestingPatternOrder nest (_, res) inner_pat =
  nest { loopNestingPattern = basicPattern [] reordered }
  where inner_names = patternNames inner_pat
        rank (Var v) = fromMaybe 0 $ elemIndex v inner_names
        rank _       = 0
        reordered = map fst $ sortOn (rank . snd) $
                    zip (patternValueIdents (loopNestingPattern nest)) res

-- | Remove the given arrays from the outermost nesting.  Then, in every
-- inner nesting (going inwards), remove the parameters whose arrays are
-- the parameters removed from an enclosing nesting.
removeArraysFromNest :: [VName] -> KernelNest -> KernelNest
removeArraysFromNest orig_arrs (outer, inners) = (outer', inners')
  where (removed, outer') = dropArrs (S.fromList orig_arrs) outer
        (_, inners') = mapAccumL dropArrs removed inners
        -- Drop parameters bound to arrays in 'gone', and add the dropped
        -- parameter names to 'gone' for the next nesting inwards.
        dropArrs gone nest =
          let (discarded, kept) =
                partition ((`S.member` gone) . snd) $ loopNestingParamsAndArrs nest
              gone' = S.fromList [ paramName p | (p, _) <- discarded ] <> gone
          in (gone', nest { loopNestingParamsAndArrs = kept })

-- | A kernel nest with a single (outermost) nesting.
newKernel :: LoopNesting -> KernelNest
newKernel = (, [])

-- | All loops of a kernel nest, outermost first.
kernelNestLoops :: KernelNest -> [LoopNesting]
kernelNestLoops = uncurry (:)

-- | Every parameter name bound anywhere in the kernel nest.
boundInKernelNest :: KernelNest -> Names
boundInKernelNest nest = S.unions (boundInKernelNests nest)

-- | The parameter names bound by each loop of the kernel nest,
-- outermost first.
boundInKernelNests :: KernelNest -> [Names]
boundInKernelNests nest =
  [ S.fromList [ paramName param | (param, _) <- loopNestingParamsAndArrs loop ]
  | loop <- kernelNestLoops nest ]

-- | The width of each loop of the kernel nest, outermost first.
kernelNestWidths :: KernelNest -> [SubExp]
kernelNestWidths nest = [ loopNestingWidth loop | loop <- kernelNestLoops nest ]

-- | Build a kernel statement from a kernel nest and the per-thread body.
-- Returns three things:
--
--  * the prologue statements: those produced by 'flatKernel', followed
--    by those produced by 'mapKernel';
--
--  * the total number of threads;
--
--  * the kernel statement itself, which binds the pattern of the
--    outermost nesting and carries its certificates.
--
-- Kernel inputs whose names are not free in the body are dropped.
constructKernel :: (MonadFreshNames m, LocalScope Kernels m) =>
                   KernelNest -> KernelBody InKernel
                -> m (Stms Kernels, SubExp, Stm Kernels)
constructKernel kernel_nest inner_body = do
  (w_bnds, w, ispace, inps, rts) <- flatKernel kernel_nest
  (ksize_bnds, k) <- inScopeOf w_bnds $
    mapKernel w (FlatThreadSpace ispace) (filter isUsed inps) rts inner_body
  let outermost = fst kernel_nest
      aux = StmAux (loopNestingCertificates outermost) ()
  return (w_bnds <> ksize_bnds, w, Let (loopNestingPattern outermost) aux (Op k))
  where body_free = freeIn inner_body
        isUsed inp = kernelInputName inp `S.member` body_free

-- | Flatten a kernel nesting to:
--
--  (0) Ancillary prologue bindings.
--
--  (1) The total number of threads, equal to the product of all
--  nesting widths, and equal to the product of the index space.
--
--  (2) The index space.
--
--  (3) The kernel inputs - note that some of these may be unused.
--
--  (4) The per-thread return type.
--
-- Each nesting level gets a fresh @gtid@ index.  The index space lists
-- these outermost first.
flatKernel :: MonadFreshNames m =>
              KernelNest
           -> m (Stms Kernels,
                 SubExp,
                 [(VName, SubExp)],
                 [KernelInput],
                 [Type])
-- Innermost (or only) nesting: one index.  Each parameter becomes an
-- input reading the row of its array at that index.
flatKernel (MapNesting pat _ nesting_w params_and_arrs, []) = do
  i <- newVName "gtid"
  let inps = [ KernelInput pname ptype arr [Var i] |
               (Param pname ptype, arr) <- params_and_arrs ]
  -- Each thread returns one row of each array in the pattern.
  return (mempty, nesting_w, [(i,nesting_w)], inps,
          map rowType $ patternTypes pat)

-- Outer nesting: flatten the inner nests first, then multiply this
-- level's width into the thread count.
flatKernel (MapNesting _ _ nesting_w params_and_arrs, nest : nests) = do
  i <- newVName "gtid"
  (w_bnds, w, ispace, inps, returns) <- flatKernel (nest, nests)

  w' <- newVName "nesting_size"
  -- Thread count so far: the inner product times this level's width.
  let w_bnd = mkLet [] [Ident w' $ Prim int32] $
              BasicOp $ BinOp (Mul Int32) w nesting_w

  -- Some inner inputs read from an array that is one of this level's
  -- parameters.  Each of those is rewritten to read from the array that
  -- the parameter is drawn from, with this level's index prepended.
  let inps' = map fixupInput inps
      isParam inp =
        snd <$> find ((==kernelInputArray inp) . paramName . fst) params_and_arrs
      fixupInput inp
        | Just arr <- isParam inp =
            inp { kernelInputArray = arr
                , kernelInputIndices = Var i : kernelInputIndices inp }
        | otherwise =
            inp

  return (w_bnds <> oneStm w_bnd, Var w', (i, nesting_w) : ispace,
          extra_inps i <> inps', returns)
  -- This level's own parameters are also inputs, indexed by this level's
  -- index.
  where extra_inps i =
          [ KernelInput pname ptype arr [Var i] |
            (Param pname ptype, arr) <- params_and_arrs ]

-- | Description of distribution to do.
data DistributionBody = DistributionBody {
    distributionTarget :: Targets
    -- ^ The targets, with identity mappings removed from the innermost one.
  , distributionFreeInBody :: Names
    -- ^ Names free in the statements being distributed.
  , distributionIdentityMap :: M.Map VName Ident
    -- ^ Maps each identity-mapped variable to the identifier of the
    -- pattern element that was removed from the innermost target.
  , distributionExpandTarget :: Target -> Target
    -- ^ Also related to avoiding identity mapping: adds the removed
    -- identity-mapped pattern elements (and their variables as results)
    -- back to a target.
  }

-- | The pattern of the innermost target of a distribution body.
distributionInnerPattern :: DistributionBody -> Pattern Kernels
distributionInnerPattern body =
  case innerTarget (distributionTarget body) of
    (pat, _) -> pat

-- | Describe the distribution of some statements into the given targets.
-- Identity mappings are removed from the innermost target; a variable
-- counts as an identity mapping when the statements do not bind it.
-- Also returns the innermost result after identity mappings have been
-- removed.
distributionBodyFromStms :: Attributes lore =>
                            Targets -> Stms lore -> (DistributionBody, Result)
distributionBodyFromStms (Targets (inner_pat, inner_res) targets) stms =
  (body, inner_res')
  where bound_by_stms = M.keysSet $ scopeOf stms
        (inner_pat', inner_res', inner_identity_map, inner_expand_target) =
          removeIdentityMappingGeneral bound_by_stms inner_pat inner_res
        body = DistributionBody
          { distributionTarget = Targets (inner_pat', inner_res') targets
          , distributionFreeInBody =
              foldMap freeInStm stms `S.difference` bound_by_stms
          , distributionIdentityMap = inner_identity_map
          , distributionExpandTarget = inner_expand_target
          }

-- | 'distributionBodyFromStms' for a single statement.
distributionBodyFromStm :: Attributes lore =>
                           Targets -> Stm lore -> (DistributionBody, Result)
distributionBodyFromStm targets = distributionBodyFromStms targets . oneStm

-- | Try to construct a kernel nest from the given nestings and
-- distribution body.  On success, returns the new targets together
-- with the kernel nest.  Returns 'Nothing' if distribution is not
-- possible, e.g. when a nesting parameter's type would have a dimension
-- bound inside the nest ("Would induce irregular array").  Calls 'fail'
-- in the underlying monad if the number of outer nestings differs from
-- the number of outer targets.
createKernelNest :: (MonadFreshNames m, HasScope t m) =>
                    Nestings
                 -> DistributionBody
                 -> m (Maybe (Targets, KernelNest))
createKernelNest (inner_nest, nests) distrib_body = do
  let Targets target targets = distributionTarget distrib_body
  -- Each outer nesting must be paired with exactly one outer target.
  unless (length nests == length targets) $
    fail $ "Nests and targets do not match!\n" ++
    "nests: " ++ ppNestings (inner_nest, nests) ++
    "\ntargets:" ++ ppTargets (Targets target targets)
  runMaybeT $ fmap prepare $ recurse $ zip nests targets

  -- Reorder to (targets, kernel nest), discarding the free names.
  where prepare (x, _, z) = (z, x)
        -- Every name bound (as a parameter or let-bound) in any nesting.
        bound_in_nest =
          mconcat $ map boundInNesting $ inner_nest : nests
        -- Can something of this type be taken outside the nest?
        -- I.e. are none of its dimensions bound inside the nest.
        distributableType =
          S.null . S.intersection bound_in_nest . freeIn . arrayDims

        -- Wrap the kernel built so far ('add_to_kernel') in the given
        -- nesting, and bind its result to 'pat'.  Two kinds of names
        -- become extra parameters of the nesting: arrays returned from
        -- inside, and names let-bound in this nesting that the kernel
        -- needs.  Unless such a name is identity-mapped, it gets a fresh
        -- result array, which the new target binds.
        distributeAtNesting :: (HasScope t m, MonadFreshNames m) =>
                               Nesting
                            -> Pattern Kernels
                            -> (LoopNesting -> KernelNest, Names)
                            -> M.Map VName Ident
                            -> [Ident]
                            -> (Target -> Targets)
                            -> MaybeT m (KernelNest, Names, Targets)
        distributeAtNesting
          (Nesting nest_let_bound nest)
          pat
          (add_to_kernel, free_in_kernel)
          identity_map
          inner_returned_arrs
          addTarget = do
          -- Keep only the nesting parameters that the kernel uses.
          let nest'@(MapNesting _ cs w params_and_arrs) =
                removeUnusedNestingParts free_in_kernel nest
              (params,arrs) = unzip params_and_arrs
              param_names = S.fromList $ map paramName params
              free_in_kernel' =
                (freeIn nest' <> free_in_kernel) `S.difference` param_names
              -- Names let-bound in this nesting that the kernel refers
              -- to; these must be passed out of the nesting.
              required_from_nest =
                free_in_kernel' `S.intersection` nest_let_bound

          required_from_nest_idents <-
            forM (S.toList required_from_nest) $ \name -> do
              t <- lift $ lookupType name
              return $ Ident name t

          -- Identity-mapped names reuse the array they map over and are
          -- not bound by the target (False).  Any other name gets a fresh
          -- "_r" array that the new target must bind (True).
          (free_params, free_arrs, bind_in_target) <-
            fmap unzip3 $
            forM (inner_returned_arrs++required_from_nest_idents) $
            \(Ident pname ptype) ->
              case M.lookup pname identity_map of
                Nothing -> do
                  arr <- newIdent (baseString pname ++ "_r") $
                         arrayOfRow ptype w
                  return (Param pname ptype,
                          arr,
                          True)
                Just arr ->
                  return (Param pname ptype,
                          arr,
                          False)

          let free_arrs_pat =
                basicPattern [] $ map snd $
                filter fst $ zip bind_in_target free_arrs
              free_params_pat =
                map snd $ filter fst $ zip bind_in_target free_params

              (actual_params, actual_arrs) =
                (params++free_params,
                 arrs++map identName free_arrs)
              actual_param_names =
                S.fromList $ map paramName actual_params

              nest'' =
                removeUnusedNestingParts free_in_kernel $
                MapNesting pat cs w $ zip actual_params actual_arrs

              free_in_kernel'' =
                (freeIn nest'' <> free_in_kernel) `S.difference` actual_param_names

          -- Reject if some parameter's type has a dimension bound inside
          -- the nest.  'fail' in 'MaybeT' yields 'Nothing'.
          unless (all (distributableType . paramType) $
                  loopNestingParams nest'') $
            fail "Would induce irregular array"
          return (add_to_kernel nest'',

                  free_in_kernel'',

                  addTarget (free_arrs_pat, map (Var . paramName) free_params_pat))

        recurse :: (HasScope t m, MonadFreshNames m) =>
                   [(Nesting,Target)]
                -> MaybeT m (KernelNest, Names, Targets)
        -- Innermost nesting: start a new kernel nest.  The body's free
        -- names are restricted to those bound somewhere in the nest.
        recurse [] =
          distributeAtNesting
          inner_nest
          (distributionInnerPattern distrib_body)
          (newKernel,
           distributionFreeInBody distrib_body `S.intersection` bound_in_nest)
          (distributionIdentityMap distrib_body)
          [] $
          singleTarget . distributionExpandTarget distrib_body

        -- Outer nesting: distribute everything inside it first.  Then
        -- remove identity mappings from its target, and wrap the inner
        -- kernel in this nesting.
        recurse ((nest, (pat,res)) : nests') = do
          (kernel@(outer, _), kernel_free, kernel_targets) <- recurse nests'

          let (pat', res', identity_map, expand_target) =
                removeIdentityMappingFromNesting
                (S.fromList $ patternNames $ loopNestingPattern outer) pat res

          distributeAtNesting
            nest
            pat'
            (\k -> pushKernelNesting (pat',res') k kernel,
             kernel_free)
            identity_map
            (patternIdents $ fst $ outerTarget kernel_targets)
            ((`pushOuterTarget` kernel_targets) . expand_target)

-- | Drop the parameters of a nesting, along with their arrays, whose
-- names are not in the given set.
removeUnusedNestingParts :: Names -> LoopNesting -> LoopNesting
removeUnusedNestingParts used nest =
  nest { loopNestingParamsAndArrs = filter isUsed $ loopNestingParamsAndArrs nest }
  where isUsed (param, _) = paramName param `S.member` used

-- | Split a pattern and its result into identity mappings and the rest.
-- A pattern element is an identity mapping if its result is a variable
-- that is not in the given set of bound names.  Returns four things:
--
--  * the pattern without the identity mappings;
--
--  * the result without the identity mappings;
--
--  * a map from each identity-mapped variable to the identifier of its
--    pattern element;
--
--  * a function that appends the removed pattern elements, and their
--    variables as results, to a target.
--
-- All returned patterns have an empty context.
removeIdentityMappingGeneral :: Names -> Pattern Kernels -> Result
                             -> (Pattern Kernels,
                                 Result,
                                 M.Map VName Ident,
                                 Target -> Target)
removeIdentityMappingGeneral bound pat res =
  (Pattern [] kept_pes, kept_res, identity_map, expandTarget)
  where (identities, kept) = mapEither classify $ zip (patternElements pat) res
        (kept_pes, kept_res) = unzip kept
        (identity_pes, identity_vs) = unzip identities
        identity_map = M.fromList $ zip identity_vs $ map patElemIdent identity_pes
        expandTarget (tpat, tres) =
          (Pattern [] $ patternElements tpat ++ identity_pes,
           tres ++ map Var identity_vs)
        classify (pe, Var v)
          | not (v `S.member` bound) = Left (pe, v)
        classify x                   = Right x

-- | Identity-mapping removal for the pattern of a nesting.  Variables in
-- the given set are not treated as identity mappings.  This is
-- 'removeIdentityMappingGeneral' under another name.
removeIdentityMappingFromNesting :: Names -> Pattern Kernels -> Result
                                 -> (Pattern Kernels,
                                     Result,
                                     M.Map VName Ident,
                                     Target -> Target)
removeIdentityMappingFromNesting = removeIdentityMappingGeneral

-- | Try to distribute the given statements as a kernel over the given
-- nestings.  An empty statement sequence trivially succeeds: the targets
-- are unchanged and no statements are produced.  On success, returns the
-- new targets and the statements to emit.  These are the prologue
-- statements from 'constructKernel', followed by the renamed kernel
-- statement.  Returns 'Nothing' if no kernel nest could be created.
tryDistribute :: (MonadFreshNames m, LocalScope Kernels m, MonadLogger m) =>
                 Nestings -> Targets -> Stms InKernel
              -> m (Maybe (Targets, Stms Kernels))
tryDistribute _ targets stms
  | null stms =
      -- An empty kernel would do no work, so there is nothing to build.
      return $ Just (targets, mempty)
tryDistribute nest targets stms = do
  maybe_nest <- createKernelNest nest dist_body
  case maybe_nest of
    Nothing ->
      return Nothing
    Just (targets', distributed) -> do
      (w_bnds, _, kernel_bnd) <- localScope (targetsScope targets') $
        constructKernel distributed inner_body
      distributed' <- renameStm kernel_bnd
      logMsg $ concat
        [ "distributing\n"
        , unlines (map pretty $ stmsToList stms)
        , pretty (snd $ innerTarget targets)
        , "\nas\n", pretty distributed'
        , "\ndue to targets\n", ppTargets targets
        , "\nand with new targets\n", ppTargets targets'
        ]
      return $ Just (targets', w_bnds <> oneStm distributed')
  where (dist_body, inner_body_res) = distributionBodyFromStms targets stms
        inner_body = KernelBody () stms $
                     map (ThreadsReturn ThreadsInSpace) inner_body_res

-- | Try to create a kernel nest for a single statement; no kernel
-- statement is constructed.  On success, also returns the innermost
-- result after identity mappings have been removed from it.
tryDistributeStm :: (MonadFreshNames m, HasScope t m, Attributes lore) =>
                    Nestings -> Targets -> Stm lore
                 -> m (Maybe (Result, Targets, KernelNest))
tryDistributeStm nest targets bnd = do
  created <- createKernelNest nest dist_body
  return $ fmap (\(targets', kernel_nest) -> (res, targets', kernel_nest)) created
  where (dist_body, res) = distributionBodyFromStm targets bnd