{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation of index functions based on
-- linear memory access descriptors (LMADs); see the work of Zhu,
-- Hoeflinger and David.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    Shape,
    LMAD (..),
    LMADDim (..),
    index,
    mkExistential,
    iota,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    expand,
    shape,
    rank,
    isDirect,
    substituteInIxFun,
    substituteInLMAD,
    existentialize,
    closeEnough,
    disjoint,
    disjoint2,
    disjoint3,
  )
where

import Control.Category
import Control.Monad
import Control.Monad.State
import Data.Map.Strict qualified as M
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
  ( equivalent,
    flatSlice,
    index,
    iota,
    isDirect,
    mkExistential,
    permute,
    rank,
    reshape,
    shape,
    slice,
  )
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( FlatSlice (..),
    Slice (..),
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as an LMAD.
data IxFun num = IxFun
  { -- | The LMAD that maps a multidimensional index to a flat offset.
    ixfunLMAD :: LMAD num,
    -- | The shape of the support array, i.e., the original array
    -- that this index function started out from.
    base :: Shape num
  }
  deriving (Show, Eq)

-- | Pretty-prints the base shape and the LMAD as two labelled entries
-- (@base:@ and @LMAD:@) enclosed in braces.
instance (Pretty num) => Pretty (IxFun num) where
  pretty (IxFun lmad oshp) =
    braces . semistack $
      [ "base:" <+> brackets (commasep $ map pretty oshp),
        "LMAD:" <+> pretty lmad
      ]

-- | Name substitution is applied to every leaf of the index function
-- (both the LMAD and the base shape) via the 'Functor' instance.
instance (Substitute num) => Substitute (IxFun num) where
  substituteNames substs = fmap $ substituteNames substs

-- | Renaming is defined in terms of substitution.
instance (Substitute num) => Rename (IxFun num) where
  rename = substituteRename

-- | The free names of an index function are the free names of all of
-- its leaves, collected through the 'Foldable' instance.
instance (FreeIn num) => FreeIn (IxFun num) where
  freeIn' = foldMap freeIn'

-- | Derived from the 'Traversable' instance.
instance Functor IxFun where
  fmap = fmapDefault

-- | Derived from the 'Traversable' instance.
instance Foldable IxFun where
  foldMap = foldMapDefault

-- It is important that the traversal order here is the same as in
-- mkExistential: the leaves of the LMAD are visited first, followed
-- by the leaves of the base shape.
instance Traversable IxFun where
  traverse f (IxFun lmad oshp) =
    IxFun <$> traverse f lmad <*> traverse f oshp

-- | Substitute names with 'PrimExp's in an index function, according
-- to the given table.  The substitution is applied both to the LMAD
-- (via 'substituteInLMAD') and to every dimension of the base shape.
substituteInIxFun ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmad oshp) =
  IxFun
    (substituteInLMAD tab lmad)
    (map (TPrimExp . substituteInPrimExp tab' . untyped) oshp)
  where
    -- 'substituteInPrimExp' works on untyped expressions, so strip
    -- the phantom type from the table once up front.
    tab' = fmap untyped tab

-- | Is this a row-major array?  This holds when the LMAD has zero
-- offset, has as many dimensions as the base shape, and each
-- dimension has the same size as the corresponding base dimension
-- and the stride expected for a row-major layout of the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect (IxFun (LMAD offset dims) oshp) =
  -- For a base shape @[d0, d1, d2]@ the expected strides are
  -- @[d1*d2, d2, 1]@.  'drop' is used instead of the partial 'tail'
  -- so that an empty base shape is handled totally.
  let strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
   in length oshp == length dims
        && offset == 0
        && all
          (\(LMADDim s n, d, se) -> s == se && n == d)
          (zip3 dims oshp strides_expected)

-- | The index space of the index function.  This is the same as the
-- shape of arrays that the index function supports.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape = LMAD.shape . ixfunLMAD

-- | Compute the flat memory index for a complete set of array
-- indices.  This delegates to 'LMAD.index' on the underlying LMAD;
-- the base shape is not consulted.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index = LMAD.index . ixfunLMAD

-- | Like 'iota', but with an explicit offset: the LMAD is
-- @LMAD.iota o ns@ and the base shape is @ns@.
iotaOffset :: (IntegralExp num) => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (LMAD.iota o ns) ns

-- | The index function built by 'LMAD.iota' with offset zero for an
-- array of the given shape, which also becomes the base shape.
iota :: (IntegralExp num) => Shape num -> IxFun num
iota = iotaOffset 0

-- | Create a single-LMAD index function that is existential in
-- everything.  The arguments are the rank of the base shape, the
-- rank of the LMAD, and the first existential index to use.
--
-- The LMAD part is built by @LMAD.mkExistential lmad_rank start@.
-- The base shape consists of @basis_rank@ consecutive existentials
-- starting at @start + 1 + 2 * lmad_rank@ (presumably one existential
-- for the LMAD offset and two per LMAD dimension come first - confirm
-- against 'LMAD.mkExistential').
mkExistential :: Int -> Int -> Int -> IxFun (Ext a)
mkExistential basis_rank lmad_rank start =
  IxFun (LMAD.mkExistential lmad_rank start) basis
  where
    basis = take basis_rank $ map Ext [start + 1 + lmad_rank * 2 ..]

-- | Permute the dimensions of the index function.  Only the LMAD is
-- permuted; the base shape is left unchanged.
permute ::
  (IntegralExp num) =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun lmad oshp) perm_new =
  IxFun (LMAD.permute lmad perm_new) oshp

-- | Slice an index function.  If the slice equals the one obtained by
-- applying @unitSlice 0@ to every dimension of the current shape, the
-- index function is returned unchanged; otherwise the LMAD is sliced
-- with 'LMAD.slice' and the base shape is kept.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun lmad@(LMAD _ _) oshp) (Slice is)
  -- Avoid identity slicing.
  | is == map (unitSlice 0) (shape ixfun) = ixfun
  | otherwise =
      IxFun (LMAD.slice lmad (Slice is)) oshp

-- | Flat-slice an index function.  The LMAD is transformed with
-- 'LMAD.flatSlice'; the base shape is left unchanged.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice (IxFun lmad oshp) s = IxFun (LMAD.flatSlice lmad s) oshp

-- | Reshape an index function.
--
-- The result is 'Nothing' exactly when 'LMAD.reshape' fails to
-- reshape the underlying LMAD.  On success, the base shape of the
-- result is the new shape.
--
-- The following conditions must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--
-- NOTE(review): an earlier version of this comment spoke of four
-- conditions but listed only these two; confirm the actual
-- conditions against 'LMAD.reshape'.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshape (IxFun lmad _) new_shape =
  (`IxFun` new_shape) <$> LMAD.reshape lmad new_shape

-- | Coerce an index function to look like it has a new shape.
-- Dynamically the shape must be the same.
--
-- The size of each LMAD dimension is replaced by the corresponding
-- entry of the new shape, and the new shape also becomes the base
-- shape.  Offset and strides are untouched.  Note that the dimensions
-- are paired with 'zipWith', so if the new shape is shorter than the
-- LMAD's rank, the surplus LMAD dimensions are dropped.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun lmad _) new_shape =
  IxFun (onLMAD lmad) new_shape
  where
    onLMAD (LMAD offset dims) = LMAD offset $ zipWith onDim dims new_shape
    onDim ld d = ld {ldShape = d}

-- | The number of dimensions in the domain of the index function,
-- i.e. the number of dimensions of its LMAD.
rank :: (IntegralExp num) => IxFun num -> Int
rank (IxFun (LMAD _ dims) _) = length dims

-- | Conceptually expand index function to be a particular slice of
-- another by adjusting the offset and strides.  Used for memory
-- expansion.
--
-- Given @expand o p ixfun@, the resulting LMAD has offset
-- @o + p * offset@ and every stride multiplied by @p@.  The base
-- shape is unchanged.  This implementation always returns 'Just'.
expand ::
  (Eq num, IntegralExp num) => num -> num -> IxFun num -> Maybe (IxFun num)
expand o p (IxFun lmad oshp) =
  let onDim ld = ld {LMAD.ldStride = p * LMAD.ldStride ld}
      lmad' =
        LMAD
          (o + p * LMAD.offset lmad)
          (map onDim (LMAD.dims lmad))
   in Just $ IxFun lmad' oshp

-- | Turn all the leaves of the index function into 'Ext's.  Each
-- leaf is replaced by a fresh @int64@ existential, numbered
-- consecutively from zero in the order of the 'Traversable' instance
-- (LMAD first, then base shape).  The original leaf values are
-- discarded.
--
-- NOTE(review): an earlier version of this comment required a
-- contiguous index function whose base shape has only one dimension;
-- nothing in this function checks that - confirm with callers.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const mkExt) ixfun) 0
  where
    -- Produce the next existential and advance the counter.
    mkExt = do
      i <- get
      put $ i + 1
      pure $ TPrimExp $ LeafExp (Ext i) int64

-- | When comparing index functions as part of the type check in KernelsMem,
-- we may run into problems caused by the simplifier. As index functions can be
-- generalized over if-then-else expressions, the simplifier might hoist some of
-- the code from inside the if-then-else (computing the offset of an array, for
-- instance), but now the type checker cannot verify that the generalized index
-- function is valid, because some of the existentials are computed somewhere
-- else. To work around this, we've had to relax the KernelsMem type-checker
-- a bit, specifically, we've introduced this function to verify whether two
-- index functions are "close enough" that we can assume that they match. We use
-- this instead of `ixfun1 == ixfun2` and hope that it's good enough.
--
-- Concretely, two index functions are considered close enough when
-- their base shapes have the same rank and their LMADs have the same
-- number of dimensions; no leaf values are compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  (length (base ixf1) == length (base ixf2))
    && closeEnoughLMADs (ixfunLMAD ixf1) (ixfunLMAD ixf2)
  where
    closeEnoughLMADs lmad1 lmad2 =
      length (LMAD.dims lmad1)
        == length (LMAD.dims lmad2)