-- | This module contains a representation of linear-memory accessor
-- descriptors (LMAD); see work by Zhu, Hoeflinger and David.
--
-- This module is designed to be used as a qualified import, as the
-- exported names are quite generic.
module Futhark.IR.Mem.LMAD
  ( Shape,
    Indices,
    LMAD (..),
    LMADDim (..),
    Permutation,
    index,
    slice,
    flatSlice,
    reshape,
    permute,
    shape,
    rank,
    substituteInLMAD,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
    iota,
    mkExistential,
    equivalent,
    isDirect,
  )
where

import Control.Category
import Control.Monad
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sortBy)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    Ext (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    Type,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | The shape of an index function: a list with one extent per
-- dimension.
type Shape num = [num]

-- | Indices passed to an LMAD, one per dimension.  Must always match
-- the rank of the LMAD.
type Indices num = [num]

-- | A complete permutation, given as a list of dimension positions.
type Permutation = [Int]

-- | A single dimension in an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Distance in the flat space between consecutive elements
    -- along this dimension.
    ldStride :: num,
    -- | Number of elements along this dimension.
    ldShape :: num
  }
  deriving (Show, Eq, Ord)

-- | An LMAD consists of a general offset and, for each dimension, a
-- stride and a number of elements (the shape).  No separate
-- permutation is stored: 'permute' rearranges the list of dimensions
-- directly.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra,
-- which is why 'reshape' returns a 'Maybe'.
--
-- NOTE(review): earlier documentation of this type also described an
-- @IxFun@ wrapper holding a list of LMADs (each following LMAD
-- standing for an irregular reshape), the shape of the original
-- array, and a contiguity bit.  Nothing in this module defines such a
-- wrapper; confirm against "Futhark.IR.Mem.IxFun" before relying on
-- that description.
--
-- By definition, the LMAD \( \sigma + \{ (n_1, s_1), \ldots, (n_k, s_k) \} \),
-- where \(n\) and \(s\) denote the shape and stride of each dimension, denotes
-- the set of points:
--
-- \[
--    \{ ~ \sigma + i_1 * s_1 + \ldots + i_k * s_k ~ | ~ 0 \leq i_1 < n_1, \ldots, 0 \leq i_k < n_k ~ \}
-- \]
data LMAD num = LMAD
  { -- | The flat offset \( \sigma \) added to every point.
    offset :: num,
    -- | The dimensions (stride and shape), outermost first.
    dims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)

-- | Pretty-prints an LMAD as a brace-enclosed, semicolon-separated
-- record with the offset, the list of strides and the list of shapes.
instance (Pretty num) => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces . semistack $
      [ "offset:" <+> group (pretty offset),
        "strides:" <+> p ldStride,
        "shape:" <+> p ldShape
      ]
    where
      -- Render one field of every dimension as a bracketed,
      -- comma-separated list.
      p f = group $ brackets $ align $ commasep $ map (pretty . f) dims

-- | Name substitution is applied to every component of the LMAD (the
-- offset and each stride and shape) via its 'Functor' instance.
instance (Substitute num) => Substitute (LMAD num) where
  substituteNames substs = fmap $ substituteNames substs

-- | Renaming is delegated to 'substituteRename', so it relies on the
-- 'Substitute' instance of the LMAD.
instance (Substitute num) => Rename (LMAD num) where
  rename = substituteRename

-- | The free names of an LMAD are those of all its components,
-- collected through the 'Foldable' instance.
instance (FreeIn num) => FreeIn (LMAD num) where
  freeIn' = foldMap freeIn'

-- | The free names of a dimension are those of its stride followed by
-- those of its shape.
instance (FreeIn num) => FreeIn (LMADDim num) where
  freeIn' (LMADDim s n) = freeIn' s <> freeIn' n

-- | Derived from the 'Traversable' instance.
instance Functor LMAD where
  fmap = fmapDefault

-- | Derived from the 'Traversable' instance.
instance Foldable LMAD where
  foldMap = foldMapDefault

-- | Visits the offset first, then each dimension in order; within a
-- dimension the stride is visited before the shape.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse f' dims
    where
      f' (LMADDim s n) = LMADDim <$> f s <*> f n

-- | The contribution of one index to a flat offset: the index
-- multiplied by the stride.  A stride that compares equal to zero
-- short-circuits to the literal @0@, so the index expression is
-- dropped from the result altogether.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  -- | Stride
  num ->
  -- | Index
  num ->
  num
flatOneDim s i
  | s == 0 = 0
  | otherwise = i * s

-- | Apply an LMAD to a list of indices, yielding the flat offset: the
-- LMAD offset plus, for every dimension, the index scaled by that
-- dimension's stride (see 'flatOneDim').
--
-- Strides and indices are paired with 'zipWith', so if the number of
-- indices does not match the rank, the surplus on either side is
-- silently ignored.
index :: (IntegralExp num, Eq num) => LMAD num -> Indices num -> num
index (LMAD off dims) inds =
  off + sum prods
  where
    prods = zipWith flatOneDim (map ldStride dims) inds

-- | Handle the case where a slice can stay within a single LMAD.
--
-- The slice entries are paired with the LMAD dimensions using 'zip'
-- and processed left to right, starting from an LMAD with the
-- original offset and no dimensions.  Fixed indices ('DimFix') only
-- move the offset; every 'DimSlice' appends one dimension to the
-- result.  Because of the 'zip', a slice with fewer entries than the
-- LMAD has dimensions drops the trailing dimensions.
slice ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  LMAD num
slice lmad@(LMAD _ ldims) (Slice is) =
  foldl sliceOne (LMAD (offset lmad) []) $ zip is ldims
  where
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    -- Fixing an index removes the dimension and advances the offset.
    sliceOne (LMAD off dims) (DimFix i, LMADDim s _x) =
      LMAD (off + flatOneDim s i) dims
    -- Slicing a zero-stride dimension: only the element count changes.
    sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _) =
      LMAD off (dims ++ [LMADDim 0 ne])
    -- A full unit-stride slice keeps the dimension untouched.  When
    -- the guard fails, matching falls through to the equations below.
    sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ n))
      | dmind == unitSlice 0 n = LMAD off (dims ++ [dim])
    -- A full reversal: start at the last element and negate the stride.
    sliceOne (LMAD off dims) (dmind, LMADDim s n)
      | dmind == DimSlice (n - 1) n (-1) =
          let off' = off + flatOneDim s (n - 1)
           in LMAD off' (dims ++ [LMADDim (s * (-1)) n])
    -- A slice with stride zero repeats the element at the start index.
    sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s _) =
      LMAD (off + flatOneDim s b) (dims ++ [LMADDim 0 ne])
    -- The general case: advance the offset and scale the stride.
    sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s _) =
      LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) ns])

-- | Flat-slice an LMAD.
--
-- The first dimension of the LMAD is replaced by one dimension per
-- entry of the flat slice: each entry's stride is multiplied by the
-- stride of the replaced dimension, and its element count is used
-- as-is.  The offset is advanced by the flat slice's offset scaled by
-- that same stride.  The remaining dimensions are kept after the new
-- ones.  An LMAD without dimensions is returned unchanged.
flatSlice ::
  (IntegralExp num) =>
  LMAD num ->
  FlatSlice num ->
  LMAD num
flatSlice (LMAD offset (dim : dims)) (FlatSlice new_offset is) =
  LMAD
    (offset + new_offset * ldStride dim)
    (map (helper $ ldStride dim) is <> dims)
  where
    helper s0 (FlatDimIndex n s) = LMADDim (s0 * s) n
flatSlice (LMAD offset []) _ = LMAD offset []

-- | Handle the case where a reshape operation can stay inside a
-- single LMAD.  See "Futhark.IR.Mem.IxFun.reshape" for
-- conditions.
--
-- Returns 'Nothing' when the reshape cannot be expressed as a single
-- LMAD by either of the two strategies below.
reshape ::
  (Eq num, IntegralExp num) => LMAD num -> Shape num -> Maybe (LMAD num)
--
-- First a special case for when we are merely injecting unit
-- dimensions into an LMAD.
reshape (LMAD off dims) newshape
  | Just dims' <- addingVacuous newshape dims =
      Just $ LMAD off dims'
  where
    -- Walk both lists in lockstep: a new extent equal to the next old
    -- extent reuses that dimension; otherwise a new extent of 1
    -- becomes a fresh zero-stride dimension.  Anything else fails.
    addingVacuous (dnew : dnews) (dold : dolds)
      | dnew == ldShape dold =
          (dold :) <$> addingVacuous dnews dolds
    addingVacuous (1 : dnews) dolds =
      (LMADDim 0 1 :) <$> addingVacuous dnews dolds
    addingVacuous [] [] = Just []
    addingVacuous _ _ = Nothing

-- Then the general case.
reshape lmad@(LMAD off dims) newshape = do
  let -- The stride of the innermost dimension.  An LMAD without
      -- dimensions denotes the single point at its offset, so any
      -- stride describes it; 1 is used to keep this total instead of
      -- calling 'last' on an empty list.
      base_stride = case dims of
        [] -> 1
        _ -> ldStride (last dims)
      no_zero_stride = all (\ld -> ldStride ld /= 0) dims
      -- The LMAD must be exactly the row-major layout of its own
      -- shape over the base stride.
      strides_as_expected = lmad == iotaStrided off base_stride (shape lmad)

  guard $ no_zero_stride && strides_as_expected

  Just $ iotaStrided off base_stride newshape
{-# NOINLINE reshape #-}

-- | Substitute names with 'PrimExp's in an LMAD.  The substitution
-- table is applied to the offset and to the stride and shape of every
-- dimension.
substituteInLMAD ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab (LMAD offset dims) =
  LMAD (sub offset) $ map (\(LMADDim s n) -> LMADDim (sub s) (sub n)) dims
  where
    -- The substitution itself works on untyped expressions.
    tab' = fmap untyped tab
    sub = TPrimExp . substituteInPrimExp tab' . untyped

-- | Shape of an LMAD: the element count of each dimension, in order.
shape :: LMAD num -> Shape num
shape = map ldShape . dims

-- | Rank of an LMAD: its number of dimensions.
rank :: LMAD num -> Int
rank = length . shape

-- | Build the row-major LMAD for a given shape on top of a base
-- stride: the innermost dimension gets the base stride, and every
-- other dimension gets the base stride multiplied by the extents of
-- all dimensions inside it.
iotaStrided ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Base Stride
  num ->
  -- | Shape
  [num] ->
  LMAD num
iotaStrided off s ns =
  let -- 'scanl' always yields at least its seed, so dropping the first
      -- element after reversing discards the total size and leaves
      -- exactly one stride per dimension.
      ss = drop 1 $ reverse $ scanl (*) s $ reverse ns
   in LMAD off $ zipWith LMADDim ss ns

-- | Generalised iota with user-specified offset: the row-major LMAD
-- of the given shape with a base stride of 1.
iota ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Shape
  [num] ->
  LMAD num
iota off = iotaStrided off 1
{-# NOINLINE iota #-}

-- | Create an LMAD that is existential in everything.
--
-- The first argument is the rank and the second is the first
-- existential index to use.  The offset becomes @Ext start@, and
-- dimension @i@ (counting from zero) gets stride
-- @Ext (start + 1 + 2 * i)@ and shape @Ext (start + 2 + 2 * i)@.
mkExistential :: Int -> Int -> LMAD (Ext a)
mkExistential r start = LMAD (Ext start) $ map onDim [0 .. r - 1]
  where
    onDim i = LMADDim (Ext (start + 1 + i * 2)) (Ext (start + 2 + i * 2))

-- | Permute dimensions.  The offset is unchanged; the list of
-- dimensions is rearranged with 'rearrangeShape'.
permute :: LMAD num -> Permutation -> LMAD num
permute lmad perm =
  lmad {dims = rearrangeShape perm $ dims lmad}

-- | Computes the maximum span of an 'LMAD': the sum, over all
-- dimensions, of the stride multiplied by one less than the shape.
-- The offset of the LMAD is not included in the result.
--
-- NOTE(review): this is the distance from the offset to the furthest
-- reachable flat position only if no stride is negative; nothing here
-- checks the sign of the strides.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) =
  foldr
    ( \dim upper ->
        let spn = ldStride dim * (ldShape dim - 1)
         in -- If you've gotten this far, you've already lost
            spn + upper
    )
    0
    dims

-- | Conservatively flatten a list of LMAD dimensions
--
-- Since not all LMADs can actually be flattened, we try to overestimate the
-- flattened array instead. This means that any "holes" in between dimensions
-- will get filled out.
--
-- An LMAD without dimensions becomes a single dimension of stride 1
-- and shape 1, and a one-dimensional LMAD is returned as-is.
-- Otherwise the result has one dimension whose stride is the 'gcd' of
-- all strides and whose shape is the 'flatSpan' plus one.  Returns
-- 'Nothing' when that 'gcd' cannot be determined.
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
conservativeFlatten l@(LMAD offset dims@(dim0 : _)) = do
  -- The fold is seeded with the first stride and then runs over all
  -- strides, the first one included.
  strd <-
    foldM
      gcd
      (ldStride dim0)
      $ map ldStride dims
  pure $ LMAD offset [LMADDim strd (shp + 1)]
  where
    shp = flatSpan l

-- | Very conservative GCD calculation. Returns 'Nothing' if the result cannot
-- be immediately determined. Does not recurse at all.
--
-- Works on the absolute values of the arguments and only recognises
-- these cases: both operands equal, either operand equal to 1, or the
-- second operand equal to 0.
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    gcd' a 0 = Just a
    gcd' _ _ = Nothing -- gcd' b (a `Futhark.Util.IntegralExp.rem` b)

-- | Returns @True@ if the two 'LMAD's could be proven disjoint.
-- A result of @False@ means only that no proof was found.
--
-- Uses some best-approximation heuristics to determine disjointness.
-- Two 1-dimensional LMADs are compared directly, but as soon as more
-- than one dimension is involved, things get more tricky. Currently,
-- any LMAD that is not 1-dimensional is first passed through
-- 'conservativeFlatten'; if that fails for either LMAD, the result is
-- @False@.
--
-- The first two arguments are passed unchanged to
-- 'AlgSimplify.lessThanish'.
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  -- Either the offsets differ by something the common stride does not
  -- divide, ...
  doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2)
    -- ... or the last element of the second LMAD lies before the
    -- start of the first, ...
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset2 + (ldShape dim2 - 1) * ldStride dim2)
      offset1
    -- ... or the last element of the first LMAD lies before the start
    -- of the second.
    || AlgSimplify.lessThanish
      less_thans
      non_negatives
      (offset1 + (ldShape dim1 - 1) * ldStride dim1)
      offset2
  where
    -- True only if @y `mod` x@ constant-folds to a boolean-decidable
    -- comparison with zero that comes out false.  An unknown GCD or an
    -- undecidable comparison yields False.
    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      Futhark.Util.IntegralExp.mod y x
        & untyped
        & constFoldPrimExp
        & TPrimExp
        & (.==.) (0 :: TPrimExp Int64 VName)
        & primBool
        & maybe False not
    doesNotDivide _ _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  case (conservativeFlatten lmad1, conservativeFlatten lmad2) of
    (Just lmad1', Just lmad2') -> disjoint less_thans non_negatives lmad1' lmad2'
    _ -> False

-- | Try to establish that two LMADs are disjoint, using an
-- interval representation of their dimensions.
--
-- The first two arguments are ignored.  The result is 'True' only
-- when all of the following hold:
--
-- * the terms of the offset difference can be distributed into the
--   paired-up intervals ('distributeOffset' succeeds for both sides),
--
-- * 'selfOverlap' reports no overlap for either side, and
--
-- * not every pair of corresponding intervals overlaps according to
--   'intervalOverlap'.
--
-- In every other case the result is 'False'.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      -- Split the terms of the offset difference into the negated
      -- ones and the rest.
      (neg_offset, pos_offset) =
        partition AlgSimplify.negated $
          offset1 `AlgSimplify.sub` offset2
      -- Pair up the intervals of the two LMADs and order the pairs by
      -- the complexity of the first component's stride (comparison
      -- flipped).
      (interval1', interval2') =
        unzip $
          sortBy (flip AlgSimplify.compareComplexity `on` (strideKey . fst)) $
            intervalPairs interval1 interval2
      -- 'selfOverlap' wants the non-negative names as expressions.
      non_negative_exps =
        map (flip LeafExp $ IntType Int64) $ namesToList non_negatives
   in case ( distributeOffset pos_offset interval1',
             distributeOffset (map AlgSimplify.negate neg_offset) interval2'
           ) of
        (Just interval1'', Just interval2'') ->
          isNothing (selfOverlap () () less_thans non_negative_exps interval1'')
            && isNothing (selfOverlap () () less_thans non_negative_exps interval2'')
            && not
              ( all
                  (uncurry (intervalOverlap less_thans non_negatives))
                  (zip interval1'' interval2'')
              )
        _ ->
          False
  where
    -- Sort key: the simplified stride of an interval.
    strideKey :: Interval -> AlgSimplify.SofP
    strideKey iv = AlgSimplify.simplify0 (untyped (stride iv))

-- | Try to establish that two LMADs are disjoint.
--
-- Both LMADs are converted to intervals with 'lmadToIntervals'.  Each
-- interval list is sorted by stride complexity and then repeatedly
-- simplified with 'joinDims' followed by 'mergeDims' until a fixed
-- point is reached.  The simplified lists are paired up and handed,
-- together with the offset difference, to a helper that recurses at
-- most four levels deep.  'False' is returned when the budget runs
-- out or when any intermediate step fails.
disjoint3 :: M.Map VName Type -> [PrimExp VName] -> [(VName, PrimExp VName)] -> [PrimExp VName] -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      interval1' = normalise interval1
      interval2' = normalise interval2
      (interval1'', interval2'') = sortedPairs interval1' interval2'
   in disjointHelper 4 interval1'' interval2'' $ offset1 `AlgSimplify.sub` offset2
  where
    -- Sort key: the simplified stride of an interval.
    strideKey :: Interval -> AlgSimplify.SofP
    strideKey iv = AlgSimplify.simplify0 (untyped (stride iv))

    -- Sort by stride complexity (comparison flipped), then join and
    -- merge dimensions until nothing changes.
    normalise :: [Interval] -> [Interval]
    normalise =
      fixPoint (mergeDims . joinDims)
        . sortBy (flip AlgSimplify.compareComplexity `on` strideKey)

    -- Pair up two interval lists and order the pairs by the stride
    -- complexity of the first component (comparison flipped).
    sortedPairs :: [Interval] -> [Interval] -> ([Interval], [Interval])
    sortedPairs xs ys =
      unzip $
        sortBy (flip AlgSimplify.compareComplexity `on` (strideKey . fst)) $
          intervalPairs xs ys

    -- The first argument is the remaining recursion budget; once it
    -- reaches zero the answer is 'False'.
    disjointHelper :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    disjointHelper 0 _ _ _ = False
    disjointHelper i is10 is20 offset =
      let (is1, is2) = sortedPairs is10 is20
          (neg_offset, pos_offset) = partition AlgSimplify.negated offset
       in case ( distributeOffset pos_offset is1,
                 distributeOffset (map AlgSimplify.negate neg_offset) is2
               ) of
            (Just is1', Just is2') ->
              let overlap1 = selfOverlap scope asserts less_thans non_negatives is1'
                  overlap2 = selfOverlap scope asserts less_thans non_negatives is2'
               in case (overlap1, overlap2) of
                    (Nothing, Nothing) ->
                      -- 'intervalOverlap' needs names, so this only
                      -- works if every non-negative expression is a
                      -- leaf.
                      case namesFromList <$> mapM justLeafExp non_negatives of
                        Just non_negatives' ->
                          -- NOTE(review): this tests the intervals
                          -- from before 'distributeOffset' (is1/is2),
                          -- whereas 'disjoint2' tests the ones after
                          -- it -- confirm that this is intended.
                          not $
                            all
                              (uncurry (intervalOverlap less_thans non_negatives'))
                              (zip is1 is2)
                        _ -> False
                    (Just overlapping_dim, _) ->
                      -- The first side overlaps itself: either every
                      -- split of the offending dimension is disjoint
                      -- from the second side, or we retry with the
                      -- expanded offset.
                      let expanded_offset = AlgSimplify.simplifySofP' <$> expandOffset offset is1
                          splits = splitDim overlapping_dim is1'
                       in all
                            ( \(new_offset, new_is1) ->
                                disjointHelper (i - 1) (joinDims new_is1) (joinDims is2') new_offset
                            )
                            splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
                    (_, Just overlapping_dim) ->
                      -- Symmetric case for the second side; the offset
                      -- produced by the split is negated.
                      let expanded_offset = AlgSimplify.simplifySofP' <$> expandOffset offset is2
                          splits = splitDim overlapping_dim is2'
                       in all
                            ( \(new_offset, new_is2) ->
                                disjointHelper (i - 1) (joinDims is1') (joinDims new_is2) $
                                  map AlgSimplify.negate new_offset
                            )
                            splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
            _ -> False

-- | Fuse neighbouring intervals that have equal strides and both a
-- lower bound of zero.  A fused pair is replaced by the first
-- interval with its element count multiplied by that of the second;
-- the result is put back in front of the remaining input, so it can
-- fuse again with the next interval.  The relative order of the
-- intervals is preserved.
joinDims :: [Interval] -> [Interval]
joinDims = helper []
  where
    -- The accumulator holds the already-processed intervals in
    -- reverse order.
    helper acc [] = reverse acc
    helper acc [x] = reverse $ x : acc
    helper acc (x : y : rest)
      | stride x == stride y,
        lowerBound x == 0,
        lowerBound y == 0 =
          helper acc $ x {numElements = numElements x * numElements y} : rest
      | otherwise =
          helper (x : acc) (y : rest)

-- | Merge neighbouring intervals where one exactly spans the stride
-- of the other.  The list is traversed from its last element towards
-- its first.  When an interval @x@ and the interval @y@ preceding it
-- in the input satisfy @stride x * numElements x == stride y@ and
-- both have a lower bound of zero, the two are replaced by @x@ with
-- its element count multiplied by that of @y@.  The result is in the
-- same orientation as the input.
mergeDims :: [Interval] -> [Interval]
mergeDims = helper [] . reverse
  where
    -- The input is reversed up front and results are consed onto the
    -- accumulator, so no final reversal is needed.
    helper acc [] = acc
    helper acc [x] = x : acc
    helper acc (x : y : rest)
      | stride x * numElements x == stride y,
        lowerBound x == 0,
        lowerBound y == 0 =
          helper acc $ x {numElements = numElements x * numElements y} : rest
      | otherwise =
          helper (x : acc) (y : rest)

-- | Split up a list of intervals around a given dimension, producing
-- a list of alternatives, each consisting of an offset and a new list
-- of intervals.
--
-- The dimension actually inspected, @overlapping_dim@, is the one
-- selected by 'focusNth' at the position one past the index of the
-- given interval in the list.
--
-- If the strides involved each simplify to a single product term, the
-- lower bound of @overlapping_dim@ is zero, and the divisions by
-- 'AlgSimplify.maybeDivide' succeed, a single alternative with an
-- empty offset is produced, in which new intervals are built from the
-- computed quotients.  Otherwise two alternatives are produced: one
-- where @overlapping_dim@ is removed and the offset is that of its
-- last element, and one where @overlapping_dim@ has one element fewer
-- and the offset is empty.
--
-- This function is partial: 'fromJust' fails if the given interval is
-- not an element of the list or if 'focusNth' returns 'Nothing'
-- (presumably when no element follows it -- TODO confirm against
-- 'focusNth').
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is
  | [st] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim0,
    [st1] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim,
    [spn] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim * numElements overlapping_dim,
    lowerBound overlapping_dim == 0,
    Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
    Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
      [ ( [],
          -- Presumably 'before' ends with overlapping_dim0, which
          -- 'init' then drops -- TODO confirm against 'focusNth'.
          init before
            <> [ Interval 0 (isInt64 $ AlgSimplify.prodToExp big_dim_elems) (stride overlapping_dim0),
                 Interval 0 (isInt64 $ AlgSimplify.prodToExp small_dim_elems) (stride overlapping_dim)
               ]
            <> after
        )
      ]
  | otherwise =
      let shrunk_dim = overlapping_dim {numElements = numElements overlapping_dim - 1}
          -- Offset of the last element of overlapping_dim.
          point_offset =
            AlgSimplify.simplify0 $
              untyped $
                (numElements overlapping_dim - 1 + lowerBound overlapping_dim) * stride overlapping_dim
       in [ (point_offset, before <> after),
            ([], before <> [shrunk_dim] <> after)
          ]
  where
    (before, overlapping_dim, after) =
      fromJust $
        elemIndex overlapping_dim0 is
          >>= (flip focusNth is . (+ 1))

-- | Convert an LMAD to its simplified offset and one 'Interval' per
-- dimension.  Every interval gets a lower bound of zero, with the
-- simplified shape and stride of the dimension as its remaining
-- fields.  An LMAD without dimensions yields the single interval
-- @Interval 0 1 1@.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals (LMAD offset []) =
  (AlgSimplify.simplify0 $ untyped offset, [Interval 0 1 1])
lmadToIntervals (LMAD offset dims0) =
  (offset', map helper dims0)
  where
    offset' = AlgSimplify.simplify0 $ untyped offset

    helper :: LMADDim (TPrimExp Int64 VName) -> Interval
    helper (LMADDim strd shp) =
      Interval 0 (AlgSimplify.simplify' shp) (AlgSimplify.simplify' strd)

-- | Dynamically determine if two 'LMADDim' are equal.
--
-- The resulting expression is true if both the strides and the
-- shapes of the two dimensions are dynamically equal.
dynamicEqualsLMADDim :: (Eq num) => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  (ldStride dim1 .==. ldStride dim2)
    .&&. (ldShape dim1 .==. ldShape dim2)

-- | Dynamically determine if two 'LMAD' are equal.
--
-- The resulting expression is true if the offsets are equal and the
-- constituent 'LMADDim's are pairwise equal.  The dimensions are
-- paired up with 'zip', so if the two LMADs differ in rank, the
-- surplus dimensions of the longer one are not compared.
dynamicEqualsLMAD :: (Eq num) => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2 =
  (offset lmad1 .==. offset lmad2)
    .&&. foldr
      ((.&&.) . uncurry dynamicEqualsLMADDim)
      true
      (zip (dims lmad1) (dims lmad2))
{-# NOINLINE dynamicEqualsLMAD #-}

-- | Returns true if two 'LMAD's are equivalent.
--
-- Equivalence in this case means that the two have the same number
-- of dimensions, the same offset, and the same strides.  The shapes
-- of the dimensions are not compared.
equivalent :: (Eq num) => LMAD num -> LMAD num -> Bool
equivalent lmad1 lmad2 =
  length (dims lmad1) == length (dims lmad2)
    && offset lmad1 == offset lmad2
    && map ldStride (dims lmad1) == map ldStride (dims lmad2)
{-# NOINLINE equivalent #-}

-- | Is this a row-major array with zero offset?
--
-- This is determined by comparing the LMAD with the one produced by
-- 'iota' for offset zero and the LMAD's own shape.
isDirect :: (Eq num, IntegralExp num) => LMAD num -> Bool
isDirect lmad = lmad == iota 0 (map ldShape $ dims lmad)
{-# NOINLINE isDirect #-}