{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
-- | Find the actual variables that need updating when a variable attribute
-- needs updating.  This is different than variable aliasing: Variable aliasing
-- is a theoretical concept, while this module has the practical purpose of
-- finding any extra variables that also need a change when a variable has a
-- change of memory block.
--
-- If and DoLoop statements have special requirements, as do some aliasing
-- expressions.  We don't want to (just) use the obvious statement variable;
-- sometimes updating the memory block of one variable actually means updating
-- the memory block of other variables as well.

module Futhark.Optimise.MemoryBlockMerging.ActualVariables
  ( findActualVariables
  ) where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (fromMaybe, mapMaybe, catMaybes)
import Control.Monad
import Control.Monad.RWS

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (
  ExplicitMemorish, ExplicitMemory, InKernel)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.AllExpVars


-- | Read-only environment shared by the whole traversal.
data Context = Context
  { ctxVarToMem :: VarMemMappings MemorySrc
    -- ^ Queried with 'M.lookup' to find the memory source of a variable.
  , ctxFirstUses :: FirstUses
    -- ^ Consulted when handling non-existential @If@ results, to detect
    -- branch variables that are first uses of the result memory.
  }
  deriving (Show)

-- | The traversal monad: reads a 'Context', writes nothing (the writer
-- component is unit), and accumulates the 'ActualVariables' as state.  The
-- @lore@ parameter does not occur in the wrapped type; it only tags the
-- computation.
newtype FindM lore a = FindM { unFindM :: RWS Context () ActualVariables a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadState ActualVariables)

-- | The constraints required of a lore by the lore-polymorphic traversal
-- functions in this module.
type LoreConstraints lore = (ExplicitMemorish lore,
                             FullWalk lore,
                             LookInKernelExp lore)

-- | Retag a computation with a different lore.  The wrapped RWS action is
-- passed through untouched.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM action) = FindM action

-- | Add @more_actuals@ to what is recorded for @stmt_var@.
--
-- A variable whose recorded set is already empty is left alone: the
-- ActualVariables system currently doubles as the way to disable memory block
-- optimisations, and a variable that resolves to the empty set must not be
-- touched.  This keeps some edge cases simple.  FIXME at some point.
recordActuals :: VName -> Names -> FindM lore ()
recordActuals stmt_var more_actuals = do
  already_disabled <- gets (maybe False S.null . M.lookup stmt_var)
  unless already_disabled $
    modify (insertOrUpdateMany stmt_var more_actuals)

-- | Find all the actual variables in a function definition.  The function
-- body is walked starting from an empty mapping, and the final state of the
-- walk is returned.
findActualVariables :: VarMemMappings MemorySrc -> FirstUses ->
                       FunDef ExplicitMemory -> ActualVariables
findActualVariables var_mem_mappings first_uses fundef =
  fst $ execRWS (unFindM walk) env M.empty
  where env = Context { ctxVarToMem = var_mem_mappings
                      , ctxFirstUses = first_uses
                      }
        walk = lookInBody (funDefBody fundef)

-- | A function parameter is recorded with itself as an actual variable.
lookInFParam :: FParam lore -> FindM lore ()
lookInFParam (Param name _) = recordActuals name self
  where self = S.singleton name

-- | A lambda parameter is recorded with itself as an actual variable.
lookInLParam :: LParam lore -> FindM lore ()
lookInLParam (Param name _) = recordActuals name self
  where self = S.singleton name

-- | Record the parameters of a lambda, then walk its body.
lookInLambda :: LoreConstraints lore => Lambda lore -> FindM lore ()
lookInLambda (Lambda lam_params lam_body _) =
  mapM_ lookInLParam lam_params >> lookInBody lam_body

-- | Visit every statement of a body, in order.  The body result is not
-- inspected here.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody (Body _ stms _) =
  forM_ stms lookInStm

-- | Visit every statement of a kernel body, in order.  The kernel results are
-- not inspected here.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) =
  forM_ stms lookInStm

-- | Record the actual variables for a single statement.
--
-- The work happens in four steps, in this order:
--
--   1. In-place updates link the result and the updated array to each other.
--
--   2. The expression is dispatched on: loops, ifs, and the aliasing
--      operations (index, reshape, rearrange, rotate, opaque) each have their
--      own rules; every other expression uses the generic "all variables in
--      the expression that share the result memory" rule.
--
--   3. 'lookInKernelExp' gets a chance to handle lore-specific operations.
--
--   4. Any sub-bodies, lambdas and parameters of the expression are walked.
--
-- Note that 'recordActuals' never extends a variable that is already recorded
-- as the empty set, so the order of the steps matters.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm stm@(Let (Pattern patctxelems patvalelems) _ e) = do
  case (patvalelems, e) of
    ([PatElem var _], BasicOp (Update orig _ _)) -> do
      let actuals = S.fromList [var, orig]
      -- When coalescing an in-place update statement, also look at the original
      -- array.
      recordActuals var actuals
      -- When reusing a previous memory block, make sure to also update related
      -- in-place updates.
      recordActuals orig actuals
    _ -> return ()

  -- Ignore the existential memory blocks.
  let bodyResult' = drop (length patctxelems) . bodyResult

  -- Special handling of loops, ifs, etc.
  case e of
    DoLoop _mergectxparams mergevalparams loopform body -> do
      -- Candidate variables: the initial values of the merge parameters, the
      -- merge parameters themselves, and every variable found in the loop
      -- expression.
      let body_vars0 = mapMaybe (subExpVar . snd) mergevalparams
          body_vars1 = map (paramName . fst) mergevalparams
          body_vars2 = S.toList $ findAllExpVars e
          body_vars = body_vars0 ++ body_vars1 ++ body_vars2
      forM_ patvalelems $ \(PatElem var membound) -> do
        case membound of
          ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) -> do
            -- If mem is existential, we need to find the return memory that it
            -- refers to.  We cannot just look at its memory aliases, since it
            -- likely aliases both the initial memory and the final memory.

            let zipped = zip patctxelems (bodyResult body)
                mem_search = case L.find ((== mem) . patElemName . fst) zipped of
                  Just (_, Var res_mem) -> res_mem
                  _ -> mem
            -- Find the ones using the same memory as the result of the loop
            -- expression.
            body_vars' <- filterM (lookupGivesMem mem_search) body_vars
            -- Not only the result variable needs to change its memory block in
            -- case of a future memory merging with it; also the variables
            -- extracted above.
            let actuals = var : body_vars'
            forM_ actuals $ \a -> recordActuals a (S.fromList actuals)
            -- Some of these can be changed later on to have an actual variable
            -- set of S.empty, e.g. if one of the variables using the memory is
            -- a rearrange operation.  This is fine, and will occur in the walk
            -- later on.

            -- If you extend this loop handling, make sure not to target existential
            -- memory blocks.  We want those to stay.
          _ -> return ()

        -- It seems wrong to change the memory of merge variables, so we disable
        -- it.  If we were to accept it, we would need to record what other
        -- variables to change as well.  Seems hard.
        recordActuals var S.empty

      case loopform of
        ForLoop _ _ _ loop_vars ->
          -- Link 'array' to 'lvar' in 'for lvar in array' loop expressions.
          forM_ loop_vars $ \(Param lvar _, array) ->
            aliasOpHandleVar array lvar
        WhileLoop _ -> return ()

    If _se body_then body_else _types ->
      -- We don't want to coalesce the existential memory block of the if.
      -- However, if a branch result has a memory block that is firstly used
      -- inside the branch, it is okay to coalesce that in a future statement.
      forM_ (zip3 patvalelems (bodyResult' body_then) (bodyResult' body_else))
        $ \(PatElem var membound, res_then, res_else) -> do
        let body_vars = S.toList $ findAllExpVars e
        case membound of
          ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
            if mem `L.elem` map patElemName patctxelems
              then
              -- If the memory block is existential, we say that the If result
              -- refers to all results in the If.
              recordActuals var
              $ S.fromList (var : catMaybes [subExpVar res_then, subExpVar res_else])

              else do
              -- If the memory block is not existential, we need to find all the
              -- variables in any sub-bodies using the same memory block (like
              -- with loops).
              body_vars' <- filterM (lookupGivesMem mem) body_vars

              first_uses <- asks ctxFirstUses
              case filter ((mem `S.member`) . (`lookupEmptyable` first_uses)) body_vars' of
                [] ->
                  -- Not just the result variable needs to change its memory
                  -- block in case of a future memory block merging with it;
                  -- also the variables extracted above.
                  recordActuals var $ S.fromList (var : body_vars')
                _ ->
                  -- If we come across a non-existential If which can be said to
                  -- create a new array *and* which has one or more bodies which
                  -- can also be said to create a new array *in the same memory*
                  -- (i.e. has first memory uses), then we disable it.  This is
                  -- not at all an impossible case to handle, but such an If is
                  -- weird, since it would make more sense if it had existential
                  -- memory, so maybe something needs to be done somewhere else
                  -- in the compiler?  If this is naively enabled, we can get an
                  -- error because the sub-body results are first uses while the
                  -- main result is not.  This can be "fixed" by stating that
                  -- the If as a whole is also a first use of the memory, but
                  -- this seems too conservative.  FIXME.
                  forM_ (var : body_vars') $ \v -> recordActuals v S.empty

          _ -> return ()

    BasicOp (Index orig _) -> do
      -- NOTE(review): 'head' is partial; this relies on an Index statement
      -- always binding at least one value pattern element.
      let ielem = head patvalelems -- Should be okay.
          var = patElemName ielem
      case patElemAttr ielem of
        ExpMem.MemArray{} ->
          -- Disable merging for index expressions that return arrays.  Maybe
          -- too restrictive.  Make sure the source also updates the memory of
          -- the index when updated.  The array might be an aliasing operation,
          -- in which case we try to find the original array.
          aliasOpHandleVar orig var
        _ -> return ()

    -- Support reusing the memory of reshape operations by recording the origin
    -- array that is being reshaped.  Only partial support for reshape
    -- operations: If the shape is more than one-dimensional, mark the statement
    -- as disabled for memory merging operations.
    BasicOp (Reshape shapechange_var orig) ->
      forM_ (map patElemName patvalelems) $ \var -> do
        orig' <- aliasOpRoot' orig
        mem_orig <- M.lookup orig' <$> asks ctxVarToMem
        case (shapechange_var, mem_orig) of
          ([_], Just (MemorySrc _ _ (Shape [_]))) ->
            recordActuals var $ S.fromList [var, orig]
            -- Works, but only in limited cases where the reshape is not even
            -- that useful to begin with; mostly cases where a reshape was
            -- inserted by the compiler in an assert-like manner.
          _ ->
            recordActuals var S.empty
            -- FIXME: The problem with these more complex cases with more than
            -- one dimension is that a slice is relative to the shape of the
            -- reshaped array, and not the original array.  Disabled for now.
        -- In both cases, a change to the root array must also reach the
        -- reshaped variable.
        recordActuals orig' $ S.fromList [orig', var]

    -- For the other aliasing operations, disable their use for now.  If the
    -- source has a change of memory block, make sure to change this as well.
    BasicOp (Rearrange _ orig) ->
      aliasOpHandle orig patvalelems

    BasicOp (Rotate _ orig) ->
      aliasOpHandle orig patvalelems

    BasicOp (Opaque (Var orig)) ->
      aliasOpHandle orig patvalelems

    -- Generic case: the result variable is linked to every variable in the
    -- expression that is mapped to the same memory block.
    _ -> forM_ patvalelems $ \(PatElem var membound) -> do
      let body_vars = S.toList $ findAllExpVars e
      case membound of
        ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) -> do
          body_vars' <- filterM (lookupGivesMem mem) body_vars
          recordActuals var $ S.fromList (var : body_vars')
        _ -> return ()

  -- If we are inside a kernel, check for actual variables in the KernelExp of
  -- the statement.
  lookInKernelExp stm

  -- Recurse over any sub-bodies.
  fullWalkExpM walker walker_kernel e
  where walker = identityWalker
          { walkOnBody = lookInBody
          , walkOnFParam = lookInFParam
          , walkOnLParam = lookInLParam
          }
        -- The kernel walker visits constructs of another lore, so the body,
        -- kernel-body and lambda handlers are retagged with 'coerce'.
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . lookInBody
          , walkOnKernelKernelBody = coerce . lookInKernelBody
          , walkOnKernelLambda = coerce . lookInLambda
          , walkOnKernelLParam = lookInLParam
          }

-- | Handle an aliasing operation such as a rotate: every variable bound by the
-- statement is tied to the original array (see 'aliasOpHandleVar'), so that
-- changes to the original array will affect the aliasing arrays as well.
aliasOpHandle :: VName -> [PatElem lore] -> FindM lore ()
aliasOpHandle orig patvalelems =
  mapM_ (aliasOpHandleVar orig . patElemName) patvalelems

-- | Disable @var@ (record the empty set for it), then make the root of @orig@
-- list both itself and @var@ as actual variables.  The disabling happens
-- first, before the root is looked up.
aliasOpHandleVar :: VName -> VName -> FindM lore ()
aliasOpHandleVar orig var =
  recordActuals var S.empty >> aliasOpRoot' orig >>= linkToRoot
  where linkToRoot root = recordActuals root (S.fromList [root, var])

-- | Find the array that @orig@ ultimately stands for.
--
-- If @orig@ is recorded with the empty set it is itself an aliasing operation,
-- and we return a variable whose recorded set contains it.  There can be more
-- than one such variable; the first one in key order is picked -- any one
-- should do, since there is a transitive closure calculation later on.
-- 'Nothing' means that no such variable has been recorded.  In every other
-- case @orig@ is its own root.
aliasOpRoot :: VName -> FindM lore (Maybe VName)
aliasOpRoot orig = gets rootIn
  where rootIn recorded
          | maybe False S.null (M.lookup orig recorded) =
              fst <$> L.find ((orig `S.member`) . snd) (M.toList recorded)
          | otherwise = Just orig

-- | Like 'aliasOpRoot', but the 'Maybe' is unwrapped with 'fromJust', passing
-- along a message that names the offending variable.
aliasOpRoot' :: VName -> FindM lore VName
aliasOpRoot' orig = fromJust failure_msg <$> aliasOpRoot orig
  where failure_msg =
          "at some point there will have been a proper statement: "
          ++ pretty orig

-- | Is the memory block of @v@ the same as @mem@?  A variable without an
-- entry in the variable-to-memory mapping never matches.
lookupGivesMem :: MName -> VName -> FindM lore Bool
lookupGivesMem mem v =
  asks (maybe False ((mem ==) . memSrcName) . M.lookup v . ctxVarToMem)

-- | Lore-specific part of the traversal.  'lookInStm' calls
-- 'lookInKernelExp' on every statement it visits, after its own handling of
-- the statement and before recursing into sub-bodies.
class LookInKernelExp lore where
  -- | Find actual vars in 'KernelExp's.
  lookInKernelExp :: Stm lore -> FindM lore ()

-- | For a kernel statement, pair each value pattern element with the
-- corresponding kernel result.  For a 'WriteReturn' result, the pattern
-- variable is recorded as an actual variable of the array named in the
-- result.  All other results and all other statements are ignored.
instance LookInKernelExp ExplicitMemory where
  lookInKernelExp
    (Let (Pattern _ patvalelems) _
         (Op (ExpMem.Inner (Kernel _ _ _ (KernelBody _ _ ress))))) =
    zipWithM_ linkWriteReturn patvalelems ress
    where linkWriteReturn (PatElem var _) (WriteReturn _ arr _) =
            recordActuals arr (S.singleton var)
          linkWriteReturn _ _ = return ()
  lookInKernelExp _ = return ()

-- | Inside a kernel, the input arrays of group reductions, group scans and
-- group streams are passed to 'extendActualVarsInKernel'.  Every other
-- statement is ignored.
instance LookInKernelExp InKernel where
  lookInKernelExp (Let _ _ e) = case e of
    Op (ExpMem.Inner (ExpMem.GroupReduce _ _ input)) -> extend (map snd input)
    Op (ExpMem.Inner (ExpMem.GroupScan _ _ input)) -> extend (map snd input)
    Op (ExpMem.Inner (ExpMem.GroupStream _ _ _ _ arrs)) -> extend arrs
    _ -> return ()
    where extend = extendActualVarsInKernel e

-- | Record actual variables for input arrays to 'KernelExp's.
--
-- For each input array with a known memory block, the root of the array (the
-- array might be an aliasing operation, in which case we try to find the
-- original array) is linked to itself and to every variable in the expression
-- that lives in that memory block.  Arrays without a known memory block are
-- skipped.
extendActualVarsInKernel :: Exp InKernel -> [VName] -> FindM InKernel ()
extendActualVarsInKernel e arrs = mapM_ extendOne arrs
  where exp_vars = findAllExpVars e
        extendOne arr = do
          root <- fromMaybe arr <$> aliasOpRoot arr
          -- The memory lookup is on the array itself, not on its root.
          arr_mem <- asks (M.lookup arr . ctxVarToMem)
          forM_ arr_mem $ \mem -> do
            same_mem_vars <-
              filterSetM (lookupGivesMem (memSrcName mem)) exp_vars
            recordActuals root (S.insert root same_mem_vars)