{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TupleSections #-}
-- | Find array creations that can be set to use existing memory blocks instead
-- of new allocations.
module Futhark.Optimise.MemoryBlockMerging.Reuse.Core
  ( coreReuseFunDef
  ) where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.List as L
import Data.Maybe (catMaybes, fromMaybe, isJust)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.State
import Control.Monad.Identity

import Futhark.MonadFreshNames
import Futhark.Binder
import Futhark.Construct
import Futhark.Representation.AST
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.Representation.ExplicitMemory
       (ExplicitMemory, ExplicitMemorish)
import Futhark.Pass.ExplicitAllocations()
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Representation.Kernels.Kernel

import Futhark.Optimise.MemoryBlockMerging.PrimExps (findPrimExpsFunDef)
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous
import Futhark.Optimise.MemoryBlockMerging.Types
import Futhark.Optimise.MemoryBlockMerging.MemoryUpdater

import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizes
import Futhark.Optimise.MemoryBlockMerging.Reuse.AllocationSizeUses


-- | Read-only inputs to the core reuse analysis.  All of them are computed
-- by earlier analyses and passed in by 'coreReuseFunDef'.
data Context = Context { ctxFirstUses :: FirstUses
                         -- ^ From the module Liveness.FirstUses
                       , ctxInterferences :: Interferences
                         -- ^ Interferences between memory blocks; queried
                         -- with 'lookupEmptyable' in 'handleNewArray'.
                       , ctxPotentialKernelInterferences
                         :: PotentialKernelDataRaceInterferences
                         -- ^ From the module Liveness.Interferences
                       , ctxSizes :: Sizes
                         -- ^ maps a memory block to its size and space
                       , ctxVarToMem :: VarMemMappings MemorySrc
                         -- ^ From the module VariableMemory
                       , ctxActualVars :: M.Map VName Names
                         -- ^ From the module ActualVariables
                       , ctxExistentials :: Names
                         -- ^ From the module Existentials
                       , ctxVarPrimExps :: M.Map VName (PrimExp VName)
                         -- ^ From the module PrimExps
                       , ctxSizeVarsUsesBefore :: M.Map VName Names
                         -- ^ Computed by 'findSizeUsesFunDef'.  Keyed by a
                         -- size variable (see the lookups in
                         -- 'handleNewArray'); maps it to the size variables
                         -- that are available where that size is used for a
                         -- memory block allocation.
                       }
  deriving (Show)

-- | Mutable state of the core reuse analysis, threaded through 'FindM'.
data Current = Current { curUses :: M.Map MName MNames
                         -- ^ maps a memory block to the memory blocks that
                         -- have been merged into it so far.  A block that is
                         -- available for reuse is registered as merged into
                         -- itself (see 'lookInFParam' and 'handleNewArray').
                       , curEqAsserts :: M.Map VName Names
                         -- ^ maps a variable name to other semantically equal
                         -- variable names, as established by equality
                         -- assertions seen earlier in the program.

                       , curVarToMemRes :: VarMemMappings MemoryLoc
                         -- ^ The result of the core analysis: maps an array
                         -- name to its memory block.

                       , curVarToMaxExpRes :: M.Map MName Names
                         -- ^ Maps an allocation (by its memory name) to the
                         -- size variables whose maximum should become its new
                         -- allocation size.  Consumed by
                         -- 'transformFromVarMaxExpMappings'.

                       , curKernelMaxSizedRes :: M.Map MName (VName,
                                                              ((VName, VName),
                                                               (VName, VName)))
                         -- ^ Maps an alloc stmt (by its memory name) to
                         -- (size0,
                         --  ((array0, final_dim0),
                         --   (array1, final_dim1))).
                         --
                         -- Needed for array creations in kernel
                         -- bodies that can only reuse memory if index functions
                         -- are changed, and the allocation size is maxed.
                         --
                         -- size0 is the size variable of the entire allocation
                         -- of the key memory.  final_dim0 is *not* that size,
                         -- but only *part of* it.  This part will be replaced
                         -- by the maximum of final_dim0 and final_dim1.
                       }
  deriving (Show)

-- | The initial analysis state.  No blocks are registered for reuse yet, no
-- equalities are known, and no results have been recorded.
emptyCurrent :: Current
emptyCurrent = Current M.empty M.empty M.empty M.empty M.empty

-- | The analysis monad.  It reads a 'Context' and threads a 'Current' state;
-- the writer output is unused.  The @lore@ parameter does not occur in the
-- representation, so it can be changed freely with 'coerce'.
newtype FindM lore a = FindM { unFindM :: RWS Context () Current a }
  deriving (Monad, Functor, Applicative,
            MonadReader Context,
            MonadState Current)

-- | Constraints required by the lore-polymorphic traversal functions below.
-- The lore must be an explicit-memory lore that supports 'fullWalkExpM'.
type LoreConstraints lore =
  ( ExplicitMemorish lore
  , FullWalk lore
  )

-- | Change the (phantom) lore of a 'FindM' computation.  The wrapped RWS
-- computation is reused as-is.
coerce :: FindM flore a -> FindM tlore a
coerce (FindM m) = FindM m

-- | Look up the memory block statically associated with a variable.
--
-- This should only be called where it is certain that the variable refers to
-- a statement with an array expression.  Otherwise the lookup fails with an
-- error.
lookupVarMem :: MonadReader Context m =>
                VName -> m MemorySrc
lookupVarMem var = do
  var_to_mem <- asks ctxVarToMem
  return $ fromJust ("lookup memory block from " ++ pretty var)
         $ M.lookup var var_to_mem

-- | Find the actual variables of @var@.  The map is first expanded with
-- itself through 'expandWithAliases', so indirect connections are followed.
-- A variable without an entry is its own only actual variable.
lookupActualVars' :: ActualVariables -> VName -> Names
lookupActualVars' actual_vars var =
  M.findWithDefault (S.singleton var) var expanded
  where expanded = expandWithAliases actual_vars actual_vars

-- | 'lookupActualVars'' against the actual variables stored in the 'Context'.
lookupActualVars :: MonadReader Context m =>
                    VName -> m Names
lookupActualVars var = do
  actual_vars <- asks ctxActualVars
  return $ lookupActualVars' actual_vars var

-- | Look up the allocation size of a memory block.  The block must have an
-- entry in 'ctxSizes'; otherwise this fails with an error.
lookupSize :: MonadReader Context m =>
              VName -> m SubExp
lookupSize var = do
  sizes <- asks ctxSizes
  let (size, _space) = fromJust ("lookup size from " ++ pretty var)
                       $ M.lookup var sizes
  return size

-- | Look up the memory space of a memory block.  The block must have an
-- entry in 'ctxSizes'; otherwise this fails with an error.
lookupSpace :: MonadReader Context m =>
               MName -> m Space
lookupSpace mem = do
  sizes <- asks ctxSizes
  let (_size, space) = fromJust ("lookup space from " ++ pretty mem)
                       $ M.lookup mem sizes
  return space

-- | Record that the existing @old_mem@ now also "is the same as" @new_mem@,
-- i.e. @new_mem@ has been merged into @old_mem@ in 'curUses'.
insertUse :: VName -> VName -> FindM lore ()
insertUse old_mem new_mem = modify addUse
  where addUse cur =
          cur { curUses = insertOrUpdate old_mem new_mem (curUses cur) }

-- | Record that array @x@ should use the memory location @mem@.  Any earlier
-- mapping for @x@ is overwritten.
recordMemMapping :: VName -> MemoryLoc -> FindM lore ()
recordMemMapping x mem = modify setMapping
  where setMapping cur =
          cur { curVarToMemRes = M.insert x mem (curVarToMemRes cur) }

-- | Record that the size variable @y@ takes part in the maximum that becomes
-- the new allocation size of @mem@.
recordMaxMapping :: MName -> VName -> FindM lore ()
recordMaxMapping mem y = modify addSize
  where addSize cur =
          cur { curVarToMaxExpRes = insertOrUpdate mem y (curVarToMaxExpRes cur) }

-- | Record the kernel max-sizing information for the allocation of @mem@.
-- Any earlier entry for @mem@ is overwritten.
recordKernelMaxMapping :: MName -> (VName, ((VName, VName), (VName, VName)))
                       -> FindM lore ()
recordKernelMaxMapping mem info = modify $ \cur ->
  let updated = M.insert mem info $ curKernelMaxSizedRes cur
  in cur { curKernelMaxSizedRes = updated }

-- | Apply a function to the recorded equality assertions.
modifyCurEqAsserts :: (M.Map VName Names -> M.Map VName Names) -> FindM lore ()
modifyCurEqAsserts f = modify updateEqs
  where updateEqs c = c { curEqAsserts = f (curEqAsserts c) }

-- | Run a computation with a local copy of the uses.  New uses found in nested
-- bodies should not be available for merging into once we are back in the
-- main body.  Updates to uses that already existed should, however, be
-- propagated.
withLocalUses :: FindM lore a -> FindM lore a
withLocalUses m = do
  uses_before <- gets curUses
  res <- m
  -- 'M.intersection' is left-biased: it keeps only the keys that existed
  -- before running m, with their (possibly updated) values from after.
  modify $ \cur -> cur { curUses = curUses cur `M.intersection` uses_before }
  return res

-- | Run the core memory-reuse analysis on a function and rewrite it.
--
-- The analysis first registers unique array parameters as reusable blocks
-- ('lookInFParam') and then walks the body ('lookInBody'), building a
-- 'Current'.  Its results are then applied in three rewriting passes:
--
--   1. 'transformFromVarMemMappings' moves arrays to their new memory blocks.
--   2. 'transformFromVarMaxExpMappings' maxes allocation sizes where needed.
--   3. 'transformFromKernelMaxSizedMappings' handles the kernel-array case,
--      changing both allocation sizes and index functions.
coreReuseFunDef :: MonadFreshNames m =>
                   FunDef ExplicitMemory -> FirstUses ->
                   Interferences -> PotentialKernelDataRaceInterferences ->
                   VarMemMappings MemorySrc -> ActualVariables -> Names ->
                   m (FunDef ExplicitMemory)
coreReuseFunDef fundef first_uses interferences potential_kernel_interferences var_to_mem actual_vars existentials = do
  let sizes = memBlockSizesFunDef fundef
      size_uses = findSizeUsesFunDef fundef
      var_to_pe = findPrimExpsFunDef fundef
      context = Context
        { ctxFirstUses = first_uses
        , ctxInterferences = interferences
        , ctxPotentialKernelInterferences = potential_kernel_interferences
        , ctxSizes = sizes
        , ctxVarToMem = var_to_mem
        , ctxActualVars = actual_vars
        , ctxExistentials = existentials
        , ctxVarPrimExps = var_to_pe
        , ctxSizeVarsUsesBefore = size_uses
        }
      -- The analysis itself: parameters first, then the body.
      m = unFindM $ do
        forM_ (funDefParams fundef) lookInFParam
        lookInBody $ funDefBody fundef
      (res, ()) = execRWS m context emptyCurrent
      var_to_mem_res = curVarToMemRes res
  -- NOTE(review): the same size map is passed for both size arguments here,
  -- whereas 'transformFromKernelMaxSizedMappings' passes the new and the
  -- original sizes.  Nothing has been resized yet at this point.
  fundef' <- transformFromVarMemMappings var_to_mem_res (M.map memSrcName var_to_mem) (M.map fst sizes) (M.map fst sizes) False fundef
  -- Block sizes are recomputed from the rewritten function for the last pass.
  let sizes' = memBlockSizesFunDef fundef'
  fundef'' <- transformFromVarMaxExpMappings (curVarToMaxExpRes res) fundef'
  transformFromKernelMaxSizedMappings var_to_pe var_to_mem (M.map memLocName var_to_mem_res) sizes' actual_vars (curKernelMaxSizedRes res) fundef''

-- | Register a function parameter as a reusable memory block if it is a
-- unique array.  Such parameters also count as "allocations" whose memory can
-- be reused.  All other parameters are ignored.
lookInFParam :: LoreConstraints lore =>
                FParam lore -> FindM lore ()
lookInFParam (Param _ (ExpMem.MemArray _ _ Unique (ExpMem.ArrayIn mem _))) =
  insertUse mem mem
lookInFParam _ = return ()

-- | Analyse every statement of a body, in order.  The body result is not
-- inspected.
lookInBody :: LoreConstraints lore =>
              Body lore -> FindM lore ()
lookInBody body = forM_ (bodyStms body) lookInStm

-- | Analyse every statement of a kernel body, in order.  The kernel results
-- are not inspected.
lookInKernelBody :: LoreConstraints lore =>
                    KernelBody lore -> FindM lore ()
lookInKernelBody (KernelBody _ stms _) = forM_ stms lookInStm

-- | Analyse one statement.  In order, this:
--
--   1. records equality assertions.  If the statement asserts a variable
--      whose 'PrimExp' (from 'ctxVarPrimExps') is an equality between two
--      leaf variables, each of the two is recorded as equal to the other in
--      'curEqAsserts';
--   2. calls 'handleNewArray' for every value pattern element bound to array
--      memory when this statement is a first use of that memory, unless the
--      array, or any of its actual variables, is existential;
--   3. recurses into nested bodies, kernel bodies and kernel lambdas, each
--      under 'withLocalUses', so that blocks first registered inside them
--      are not offered for reuse afterwards.
lookInStm :: LoreConstraints lore =>
             Stm lore -> FindM lore ()
lookInStm (Let (Pattern _patctxelems patvalelems) _ e) = do
  var_to_pe <- asks ctxVarPrimExps
  -- Step 1: equality assertions.
  let eqs | BasicOp (Assert (Var v) _ _) <- e
          , Just (CmpOpExp (CmpEq _) (LeafExp v0 _) (LeafExp v1 _)) <- M.lookup v var_to_pe = do
              modifyCurEqAsserts $ insertOrUpdate v0 v1
              modifyCurEqAsserts $ insertOrUpdate v1 v0
          | otherwise = return ()
  eqs

  -- Step 2: array creations.
  forM_ patvalelems $ \(PatElem var membound) -> do
    -- For every declaration with a first memory use, check (through
    -- handleNewArray) if it can reuse some earlier memory block.
    first_uses_var <- lookupEmptyable var <$> asks ctxFirstUses
    actual_vars_var <- lookupActualVars var
    existentials <- asks ctxExistentials
    case membound of
      ExpMem.MemArray _ _ _ (ExpMem.ArrayIn mem _) ->
        when (-- We require that it must be a first use, i.e. an array creation.
              mem `S.member` first_uses_var
              -- If the array is existential or "aliases" something that is
              -- existential, we do not try to make it reuse any memory.
              && not (var `S.member` existentials)
              && not (any (`S.member` existentials) actual_vars_var))
        $ handleNewArray var mem
      _ -> return ()

  -- Step 3: nested bodies.
  fullWalkExpM walker walker_kernel e

  where walker = identityWalker
          { walkOnBody = withLocalUses . lookInBody }
        -- The kernel walker callbacks are run through 'coerce', which only
        -- changes the phantom lore of the 'FindM' computation.
        walker_kernel = identityKernelWalker
          { walkOnKernelBody = coerce . withLocalUses . lookInBody
          , walkOnKernelKernelBody = coerce . withLocalUses . lookInKernelBody
          , walkOnKernelLambda = coerce . withLocalUses . lookInBody . lambdaBody
          }

-- | Check if a new array declaration @x@ with a first use of the memory
-- @xmem@ can be set to use a previously encountered memory block.
--
-- Every entry of 'curUses' is a candidate target block @kmem@, paired with the
-- blocks already merged into it (@used_mems@).  Each candidate is tested with
-- the predicates in 'canBeUsed' below, and the first (in key order) that
-- passes all of them is chosen.  When one is found:
--
--   * @xmem@ is recorded as merged into @kmem@,
--   * every actual variable of @x@ is remapped to @kmem@, keeping its
--     original index function,
--   * any allocation-size maximisation needed is recorded.
--
-- Otherwise @xmem@ is registered as a new candidate for later arrays.
handleNewArray :: VName -> MName -> FindM lore ()
handleNewArray x xmem = do
  interferences <- asks ctxInterferences
  actual_vars <- lookupActualVars x

  -- The candidate must not be xmem itself.
  let notTheSame :: Monad m => MName -> MNames -> m Bool
      notTheSame kmem _used_mems = return (kmem /= xmem)

  let noneInterfere :: Monad m => MName -> MNames -> m Bool
      noneInterfere _kmem used_mems =
        -- A memory block can have already been reused.  We also check for
        -- interference with any previously merged blocks.
        return $ all (\used_mem -> not $ S.member xmem
                                   $ lookupEmptyable used_mem interferences)
        $ S.toList used_mems

  -- No block already merged into the candidate may data-race with xmem
  -- inside a kernel (see 'interferesInKernel').
  let noneInterfereKernelArray :: MonadReader Context m => MNames -> m Bool
      noneInterfereKernelArray used_mems =
        not <$> anyM (interferesInKernel xmem) (S.toList used_mems)

  -- The candidate and xmem must be in the same memory space.
  let sameSpace :: MonadReader Context m =>
                   MName -> MNames -> m Bool
      sameSpace kmem _used_mems = do
        kspace <- lookupSpace kmem
        xspace <- lookupSpace xmem
        return (kspace == xspace)

  -- Is the size of the new memory block (xmem) equal to any of the memory
  -- blocks (used_mems) using an already used memory block?
  let sizesMatch :: MNames -> FindM lore Bool
      sizesMatch used_mems = do
        ok_sizes <- mapM lookupSize $ S.toList used_mems
        new_size <- lookupSize xmem
        -- Check for size equality by checking for variable name equality.
        let eq_simple = new_size `L.elem` ok_sizes

        -- Check for size equality by constructing 'PrimExp's and comparing
        -- those.  Use the custom VarWithLooseEquality type to compare inner
        -- sizes: If an equality assert statement was found earlier, consider
        -- its two operands to be the same.
        var_to_pe <- asks ctxVarPrimExps
        eq_asserts <- gets curEqAsserts
        let sePrimExp se = do
              v <- subExpVar se
              pe <- M.lookup v var_to_pe
              let pe_expanded = expandPrimExp var_to_pe pe
              traverse (\v_inner -> -- Has custom Eq instance.
                                       pure $ VarWithLooseEquality v_inner
                                       $ lookupEmptyable v_inner eq_asserts
                       ) pe_expanded
        let ok_sizes_pe = map sePrimExp ok_sizes
        let new_size_pe = sePrimExp new_size

        -- If new_size_pe actually denotes a PrimExp, check if it is among the
        -- constructed 'PrimExp's of the sizes of the memory blocks that have
        -- already been set to use the target memory block.
        let eq_advanced = isJust new_size_pe && new_size_pe `L.elem` ok_sizes_pe

        return (eq_simple || eq_advanced)

  -- In case sizes do not match: Is it possible to change the size of the target
  -- memory block to be a maximum of itself and the new memory block?  This
  -- requires both sizes to be variables, and xmem's size variable to be
  -- available where kmem's size is used (per 'ctxSizeVarsUsesBefore').  The
  -- lookup of kmem's size variable there is expected to always succeed;
  -- 'fromJust' fails with an error otherwise.
  let sizesCanBeMaxed :: MName -> FindM lore Bool
      sizesCanBeMaxed kmem = do
        ksize <- lookupSize kmem
        xsize <- lookupSize xmem
        uses_before <- asks ctxSizeVarsUsesBefore
        let ok = fromMaybe False $ do
              ksize' <- subExpVar ksize
              xsize' <- subExpVar xsize
              return (xsize' `S.member` fromJust ("is recorded for all size variables "
                                                  ++ pretty ksize')
                      (M.lookup ksize' uses_before))
        return ok

  let sizesCanBeMaxedKernelArray :: MName -> MNames ->
                                    FindM lore (Maybe (VName, ((VName, VName),
                                                               (VName, VName))))
      sizesCanBeMaxedKernelArray kmem used_mems = do
        -- Let a kernel body have two indexed array creations result_0 and
        -- result_1 with the index functions
        --
        --   result_0: ixfun_start_0[indices_start_0, 0i64:+res_0*1i64]
        --   result_1: ixfun_start_1[indices_start_1, 0i64:+res_1*1i64]
        --
        -- with the additional requirements that
        --
        --   + ixfun_start_0 is equal to ixfun_start_1 except for mentions of
        --     res_0 and res_1.
        --
        --   + indices_start_0 is equal to indices_start_1.
        --
        -- Example:
        --
        --   result_0: Direct(num_groups, res_0, group_size)[0, 2, 1][group_id, local_tid, 0i64:+res_0*1i64]
        --   result_1: Direct(num_groups, res_1, group_size)[0, 2, 1][group_id, local_tid, 0i64:+res_1*1i64]
        --
        -- By default result_0 and result_1 will be set to interfere because
        -- each thread can access parts of the memory of another thread if they
        -- are merged.  We can fix this by making both index functions describe
        -- the same access pattern except for the final dimension.  We want this
        -- to happen for the example above:
        --
        --   result_0': Direct(num_groups, res_max, group_size)[0, 2, 1][group_id, local_tid, 0i64:+res_0*1i64]
        --   result_1': Direct(num_groups, res_max, group_size)[0, 2, 1][group_id, local_tid, 0i64:+res_1*1i64]
        --
        -- Where res_max = max(res_0, res_1).  Now they cover the same area in
        -- space.  The final index slices are kept as they were, since the shape
        -- of the created array should stay the same.  This means that the
        -- smallest array will not be writing to all of its available space.
        --
        -- We need to check:
        --
        --   + Is res_1 in scope at the allocation?  Allocation size hoisting
        --     has probably been helpful here.
        --
        --   + Do res_0 and res_1 have the same base type size?
        --
        -- If true, modify the program as such:
        --
        --   + Insert a res_max statement before the allocation.
        --
        --   + Change the allocation size to use res_max instead of res_0.
        --
        --   + Modify both index functions to use res_max instead of res_0 and
        --     res_1, respectively, except for at the final index slice.
        --
        -- Extension: If an array reuses an already reused array, remember to
        -- update *all* index functions.  Currently we avoid these cases for
        -- simplicity of implementation.
        --
        -- Returns Just (kmem's size variable, ((kmem array, res_0),
        -- (xmem array, res_1))) when the transformation applies.

        potentials <- asks ctxPotentialKernelInterferences
        uses_before <- asks ctxSizeVarsUsesBefore

        -- The potential-interference groups containing both kmem and xmem.
        let first_usess = filter (\p ->
                                    let pot_mems = map (\(m, _, _, _) -> m) p
                                    in kmem `elem` pot_mems && xmem `elem` pot_mems)
                          potentials
        kmem_size <- fromJust "should be a var" . subExpVar <$> lookupSize kmem

        return $ case (S.toList used_mems, first_usess) of
          -- We only support the basic case for now.  FIXME (or, at the very
          -- least, manage to create a program where this will have an effect).
          --
          -- A used_mems list of size > 1 means that kmem has already been
          -- reused.  This is okay, but a bit harder to keep track of.
          --
          -- A first_usess list of size > 1 means that xmem and kmem
          -- data-race-interfere in multiple kernels.  This will never happen in
          -- the current implementation, but could *potentially* happen in the
          -- future.
          ([_], [first_uses]) -> do
            (_, kmem_array, kmem_pt, kmem_ixfun) <-
              L.find (\(mname, _, _, _) -> mname == kmem) first_uses
            (_, xmem_array, xmem_pt, xmem_ixfun) <-
              L.find (\(mname, _, _, _) -> mname == xmem) first_uses

            if (kmem, kmem_ixfun) `ixFunsCompatible` (xmem, xmem_ixfun)
              then Nothing -- Already compatible; no special handling needed.
              else do
              (kmem_ixfun_start, kmem_indices_start, kmem_final_dim) <-
                IxFun.getInfoMaxUnification kmem_ixfun
              (xmem_ixfun_start, xmem_indices_start, xmem_final_dim) <-
                IxFun.getInfoMaxUnification xmem_ixfun

              -- Compare the ixfun starts with each final dimension treated as
              -- loosely equal to the other one (see getIxFun').
              let xmem_final_dim_before_kmem_final_dim =
                    maybe False (xmem_final_dim `S.member`) $
                    M.lookup kmem_final_dim uses_before
                  kmem_ixfun_start' = getIxFun' kmem_ixfun_start
                                      (M.singleton kmem_final_dim xmem_final_dim)
                  xmem_ixfun_start' = getIxFun' xmem_ixfun_start
                                      (M.singleton xmem_final_dim kmem_final_dim)

                  res = if kmem_indices_start == xmem_indices_start &&
                           (kmem, kmem_ixfun_start') `ixFunsCompatible`
                           (xmem, xmem_ixfun_start') &&
                           (primByteSize kmem_pt :: Int) == primByteSize xmem_pt &&
                           xmem_final_dim_before_kmem_final_dim
                        then return (kmem_size,
                                     ((kmem_array, kmem_final_dim),
                                      (xmem_array, xmem_final_dim)))
                        else Nothing

                in res
          _ -> Nothing

        -- Convert an index function to use VarWithLooseEquality leaves, where
        -- each name in 'others' is considered equal to its mapped name.
        where getIxFun' :: ExpMem.IxFun -> M.Map VName VName ->
                           IxFun.IxFun (PrimExp VarWithLooseEquality)
              getIxFun' ixfun others =
                let loose_eq_map name_inner =
                      -- Has custom Eq instance.
                      pure $ VarWithLooseEquality name_inner
                      $ maybe S.empty S.singleton $ M.lookup name_inner others
                in runIdentity $ traverse (traverse loose_eq_map) ixfun

  let sizesCanBeMaxedKernelArray' :: MName -> MNames -> FindM lore Bool
      sizesCanBeMaxedKernelArray' kmem used_mems =
        isJust <$> sizesCanBeMaxedKernelArray kmem used_mems

  let noOtherUsesOfMemory :: MName -> MNames -> FindM lore Bool
      noOtherUsesOfMemory _kmem _used_mems =
        -- If the array in question 'x' is not the only array that uses the
        -- memory (ignoring aliasing), then do not perform memory reuse.  We
        -- only want to reuse memory if it means we can remove an allocation.
        -- FIXME: If we can check that all arrays using the memory in question
        -- 'xmem' can be set to reuse some other memory, so that 'xmem' does not
        -- have to be allocated, then this restriction can go away.  It also
        -- might be the case that the ActualVariables module does not find all
        -- array connections, i.e. it concludes that two arrays are distinct
        -- when they are actually not; this can happen with streams.
        and . M.elems . M.mapWithKey (
          \v m -> (memSrcName m /= xmem)
                  || (v `L.elem` actual_vars)
          ) <$> asks ctxVarToMem

  let notCurrentlyDisabled :: FindM lore Bool
      notCurrentlyDisabled =
        -- FIXME: We currently disable reusing memory of constant size.  This is
        -- a problem in the misc/heston/heston32.fut benchmark (but not the
        -- heston64.fut one).  It would be nice to not have to disable this
        -- feature, as it works well for the most part.  Why is this a problem?
        -- Or is it maybe something else that causes heston32 to segfault?
        isJust . subExpVar <$> lookupSize xmem

  let sizesWorkOut :: MName -> MNames -> FindM lore Bool
      sizesWorkOut kmem used_mems =
        -- The size of an allocation is okay to reuse if it is the same as the
        -- current memory size, or if it can be changed to be the maximum size
        -- of the two sizes.  Alternatively, the kernel-array special case
        -- (sizesCanBeMaxedKernelArray) may apply.
        (notCurrentlyDisabled <&&> noneInterfereKernelArray used_mems <&&>
         (sizesMatch used_mems <||> sizesCanBeMaxed kmem))
        <||> sizesCanBeMaxedKernelArray' kmem used_mems

  -- A candidate (kmem, used_mems) can be used if every predicate holds.
  let canBeUsed t = and <$> mapM (($ t) . uncurry)
                    [notTheSame, noneInterfere, sameSpace, noOtherUsesOfMemory,
                     sizesWorkOut]
  cur_uses <- gets curUses
  found_use <- catMaybes <$> mapM (maybeFromBoolM canBeUsed) (M.assocs cur_uses)

  case found_use of
    (kmem, used_mems) : _ -> do
      -- There is a previous memory block that we can use.  Record the mapping.
      insertUse kmem xmem
      forM_ actual_vars $ \var -> do
        ixfun <- memSrcIxFun <$> lookupVarMem var
        recordMemMapping var $ MemoryLoc kmem ixfun -- Only change the memory block.

      -- Record any size-maximum change in case of sizesCanBeMaxed returning
      -- True.
      whenM (sizesCanBeMaxed kmem) $ do
        ksize <- lookupSize kmem
        xsize <- lookupSize xmem
        fromMaybe (return ()) $ do
          ksize' <- subExpVar ksize
          xsize' <- subExpVar xsize
          return $ do
            recordMaxMapping kmem ksize'
            recordMaxMapping kmem xsize'

      -- If we are inside a kernel body, and the current array can use the
      -- memory block of another array if its size gets maximised, record this
      -- change.  The actual program transformation will happen later.
      kernel_maxing <- sizesCanBeMaxedKernelArray kmem used_mems
      forM_ kernel_maxing $ \info ->
        recordKernelMaxMapping kmem info

    _ ->
      -- There is no previous memory block available for use.  Record that this
      -- memory block is available.
      insertUse xmem xmem

-- | A variable name paired with a set of names considered equal to it (e.g.
-- through equality assertions).  Two values compare equal when the sets
-- formed by each name plus its equal names overlap.
--
-- NOTE: this relation is reflexive and symmetric but not transitive.  Only
-- use it for the structural comparisons of sizes and index functions done in
-- this module.
data VarWithLooseEquality = VarWithLooseEquality VName Names
  deriving (Show)

instance Eq VarWithLooseEquality where
  VarWithLooseEquality name0 equals0 == VarWithLooseEquality name1 equals1 =
    any (`S.member` S.insert name1 equals1) (S.toList (S.insert name0 equals0))

-- | Can two memory blocks cause a data race inside a kernel if they are
-- merged?  This is checked for each potential-interference group that
-- contains both blocks.  There, they interfere if they are distinct and either
--
--   * at least one of their index functions uses an index and the two index
--     functions are not compatible (they range over different memory areas),
--     or
--   * their element types differ in byte size.
interferesInKernel :: MonadReader Context m => MName -> MName -> m Bool
interferesInKernel mem0 mem1 =
  asks $ any groupInterferes . ctxPotentialKernelInterferences

  where entryFor mem = L.find (\(mname, _, _, _) -> mname == mem)

        groupInterferes :: PotentialKernelDataRaceInterferenceGroup -> Bool
        groupInterferes grp =
          case (entryFor mem0 grp, entryFor mem1 grp) of
            (Just (_, _, pt0, ixfun0), Just (_, _, pt1, ixfun1)) ->
              let uses_index = ixFunHasIndex ixfun0 || ixFunHasIndex ixfun1
                  ranges_differ =
                    uses_index &&
                    not (ixFunsCompatible (mem0, ixfun0) (mem1, ixfun1))
                  -- Differently sized base types take up different amounts
                  -- of space, and will not be compatible.
                  byte_sizes_differ =
                    (primByteSize pt0 :: Int) /= primByteSize pt1
              in mem0 /= mem1 && (ranges_differ || byte_sizes_differ)
            _ -> False

-- | Does an index function contain an Index expression?
--
-- An index in the index function of a memory annotation means that the array
-- creation does not cover the entire array, only part of it.  Another part is
-- created in another loop iteration or kernel thread.  Declaring such memory
-- a first use is dangerous: later in the same iteration/thread it could be
-- reused by memory whose annotation uses a *different* index.  That can
-- affect reads in other threads.
ixFunHasIndex :: IxFun.IxFun num -> Bool
ixFunHasIndex ixfun = IxFun.ixFunHasIndex ixfun

-- | Do the two index functions describe the same range?  In other words, does
-- one array take up precisely the same location (offset) and size as another,
-- relative to the start of its own memory block?  The memory names are
-- ignored; only the index functions are compared, with
-- 'IxFun.ixFunsCompatibleRaw'.
--
-- FIXME: This could be less conservative, e.g. by recognising that different
-- reshapes of the same array can describe the same offset and space.  Do we
-- have any tests or benchmarks where that occurs?
ixFunsCompatible :: Eq v =>
                    (MName, IxFun.IxFun (PrimExp v)) -> (MName, IxFun.IxFun (PrimExp v)) ->
                    Bool
ixFunsCompatible (_, lhs) (_, rhs) = lhs `IxFun.ixFunsCompatibleRaw` rhs

-- | Replace certain allocation sizes in a program with new variables
-- describing the maximum of two or more allocation sizes.
--
-- The input maps an allocation's memory name to the size variables whose
-- maximum should become its new size.  Each entry is turned into a
-- 'Replacement' by 'maxsToReplacement' and applied with 'insertAndReplace'.
transformFromVarMaxExpMappings :: MonadFreshNames m =>
                                  M.Map VName Names
                               -> FunDef ExplicitMemory -> m (FunDef ExplicitMemory)
transformFromVarMaxExpMappings var_to_max fundef =
  -- Traversing a Map visits its keys in ascending order, the same order as
  -- 'M.assocs', so fresh names are generated in the same order.
  flip insertAndReplace fundef
  <$> traverse (maxsToReplacement . S.toList) var_to_max

-- | A replacement is a new size variable together with any new statements
-- that the new variable depends on.  For a single input variable (see
-- 'maxsToReplacement') the "new" variable is that input itself and there are
-- no statements.
data Replacement = Replacement
  { replName :: VName -- ^ The new variable
  , replStms :: [Stm ExplicitMemory] -- ^ The statements computing it, in order
  }
  deriving (Show)

-- | Take a list of size variables.  Return a replacement whose size variable
-- denotes the maximum of the input sizes.
--
-- The maxima are combined as a balanced binary tree: N inputs produce N-1
-- new @smax@ statements, nested O(lg N) deep.  Fails on an empty list.
maxsToReplacement :: MonadFreshNames m =>
                     [VName] -> m Replacement
maxsToReplacement vars = case vars of
  [] -> error "maxsToReplacements: Cannot take max of zero variables"
  [single] -> return Replacement { replName = single, replStms = [] }
  _ -> do
    let (lefts, rights) = splitAt (length vars `div` 2) vars
    left <- maxsToReplacement lefts
    right <- maxsToReplacement rights
    combined <- newVName "max"
    let max_exp = BasicOp $ BinOp (SMax Int64)
                  (Var $ replName left) (Var $ replName right)
        max_stm = Let (Pattern [] [PatElem combined
                                   (ExpMem.MemPrim (IntType Int64))])
                  (defAux ()) max_exp
    return Replacement
      { replName = combined
      , replStms = replStms left ++ replStms right ++ [max_stm]
      }

-- | Modify a function to use the new replacements.
--
-- Every allocation statement whose memory name has a 'Replacement' is
-- rewritten to allocate 'replName' instead, in both the pattern annotation and
-- the 'ExpMem.Alloc'.  The replacement's statements are inserted directly
-- before the first such allocation only; they are cleared from the state
-- afterwards.  Nested bodies are reached through 'mapExpM', where only
-- 'mapOnBody' is overridden.
-- NOTE(review): bodies reachable only via 'Op's are presumably left
-- untouched by this mapper -- confirm this is intended.
insertAndReplace :: M.Map MName Replacement -> FunDef ExplicitMemory ->
                    FunDef ExplicitMemory
insertAndReplace replaces0 fundef =
  fundef { funDefBody = evalState (onBody $ funDefBody fundef) replaces0 }

  where onBody :: Body ExplicitMemory ->
                  State (M.Map VName Replacement) (Body ExplicitMemory)
        onBody body = do
          new_stms <- fmap concat $ mapM onStm $ stmsToList $ bodyStms body
          return body { bodyStms = stmsFromList new_stms }

        onStm :: Stm ExplicitMemory ->
                 State (M.Map VName Replacement) [Stm ExplicitMemory]
        onStm stm@(Let (Pattern [] [PatElem mem_name
                                    (ExpMem.MemMem _ pat_space)]) _
                   (Op (ExpMem.Alloc _ space))) =
          gets (M.lookup mem_name) >>= maybe (return [stm]) replaceAlloc
          where replaceAlloc repl = do
                  let new_size = Var $ replName repl
                      new_alloc =
                        Let (Pattern [] [PatElem mem_name
                                         (ExpMem.MemMem new_size pat_space)])
                        (defAux ()) (Op $ ExpMem.Alloc new_size space)
                  -- Emit the size statements only before the first
                  -- occurrence of this allocation.
                  modify $ M.adjust (\r -> r { replStms = [] }) mem_name
                  return $ replStms repl ++ [new_alloc]
        onStm (Let pat attr e) =
          (\e' -> [Let pat attr e']) <$> mapExpM mapper e
          where mapper = identityMapper { mapOnBody = const onBody }


-- | Apply the kernel max-sized mappings (see 'curKernelMaxSizedRes').
--
-- Each entry maps a target memory block @mem@ to
-- @(kmem_size, ((kmem_array, kmem_final_dim), (xmem_array, xmem_final_dim)))@.
-- For each entry:
--
--   * a new variable @max_final_dim@ = smax(kmem_final_dim, xmem_final_dim)
--     is introduced;
--   * the allocation size of @mem@ is recomputed from the 'PrimExp' of
--     @kmem_size@, with @kmem_final_dim@ replaced by @max_final_dim@, and is
--     installed with 'insertAndReplace';
--   * every array connected (through the actual variables) to @kmem_array@
--     or @xmem_array@ is mapped to @mem@.  Its index function is rewritten
--     with 'IxFun.subsInIndexIxFun' to use @max_final_dim@ in place of that
--     array's own final dimension.
--
-- Finally, 'transformFromVarMemMappings' applies the new array mappings.
transformFromKernelMaxSizedMappings :: MonadFreshNames m =>
  M.Map VName (PrimExp VName) -> VarMemMappings MemorySrc -> VarMemMappings MName ->
  Sizes -> ActualVariables -> M.Map MName (VName, ((VName, VName),
                                                   (VName, VName))) ->
  FunDef ExplicitMemory -> m (FunDef ExplicitMemory)
transformFromKernelMaxSizedMappings
  var_to_pe var_to_mem var_to_mem_res sizes_orig actual_vars mem_to_info fundef = do
  (mem_to_size_var, arr_to_mem_ixfun) <-
    unzip <$> mapM (uncurry withNewMaxVar) (M.assocs mem_to_info)
  let mem_to_size_var' = M.fromList mem_to_size_var
      arr_to_memloc = M.fromList $ map (\(arr, destmem, ixfun) ->
                                          (arr, MemoryLoc destmem ixfun))
                      $ concat arr_to_mem_ixfun

      fundef' = insertAndReplace mem_to_size_var' fundef
      -- Sizes after the allocation sizes have been replaced.
      sizes = memBlockSizesFunDef fundef'
  -- Existing mappings: the results of the core analysis take precedence
  -- ('M.union' is left-biased) over the original variable-to-memory mapping.
  transformFromVarMemMappings arr_to_memloc (M.union var_to_mem_res (M.map memSrcName var_to_mem)) (M.map fst sizes) (M.map fst sizes_orig) True fundef'

  where -- Build the size replacement and the new array mappings for one
        -- target memory block.
        withNewMaxVar :: MonadFreshNames m =>
                         MName -> (VName,
                                   ((VName, VName),
                                    (VName, VName))) ->
                         m ((MName, Replacement),
                            [(VName, MName, ExpMem.IxFun)])
        withNewMaxVar mem (kmem_size,
                           ((kmem_array, kmem_final_dim),
                            (xmem_array, xmem_final_dim))) = do
          final_dim_max_v <- newVName "max_final_dim"
          let final_dim_max_e =
                BasicOp (BinOp (SMax Int32)
                         (Var kmem_final_dim) (Var xmem_final_dim))

              -- Pretend kmem_final_dim is max_final_dim when expanding the
              -- full allocation size.
              var_to_pe_extension =
                M.singleton kmem_final_dim (LeafExp final_dim_max_v (IntType Int32))
              var_to_pe' = M.union var_to_pe_extension var_to_pe
              full_size_pe = fromJust "should exist" $ M.lookup kmem_size var_to_pe
              full_size_pe_expanded = expandPrimExp var_to_pe' full_size_pe
              new_full_size_m =
                letExp "max" =<< primExpToExp (return . BasicOp . SubExp . Var)
                full_size_pe_expanded
          (alloc_size_var, alloc_size_stms) <-
            modifyNameSource $ runState $ runBinderT new_full_size_m mempty
          -- The max_final_dim statement must precede the size computation.
          let alloc_size_fd_stm =
                Let (Pattern [] [PatElem final_dim_max_v
                                 (ExpMem.MemPrim (IntType Int32))]) (defAux ()) final_dim_max_e
              alloc_size_stms' = oneStm alloc_size_fd_stm <> alloc_size_stms

              vars_kmem =
                S.insert kmem_array $ lookupActualVars' actual_vars kmem_array
              vars_xmem =
                S.insert xmem_array $ lookupActualVars' actual_vars xmem_array

              -- Map an array to 'mem' with final_dim replaced by the max.
              arrayToMapping final_dim v =
                let ixfun = memSrcIxFun $ fromJust "should exist"
                            $ M.lookup v var_to_mem
                    ixfun_new = IxFun.subsInIndexIxFun ixfun final_dim final_dim_max_v
                in (v, mem, ixfun_new)
              arr_to_mem_ixfun_kmem = map (arrayToMapping kmem_final_dim)
                                      $ S.toList vars_kmem
              arr_to_mem_ixfun_xmem = map (arrayToMapping xmem_final_dim)
                                      $ S.toList vars_xmem
              arr_to_mem_ixfun = arr_to_mem_ixfun_kmem ++ arr_to_mem_ixfun_xmem

          return ((mem, Replacement alloc_size_var $ stmsToList alloc_size_stms'),
                  arr_to_mem_ixfun)