-- | This module contains a representation of linear-memory accessor
-- descriptors (LMAD); see work by Zhu, Hoeflinger and David.
--
-- This module is designed to be used as a qualified import, as the
-- exported names are quite generic.
module Futhark.IR.Mem.LMAD
  ( Shape,
    Indices,
    LMAD (..),
    LMADDim (..),
    Permutation,
    index,
    slice,
    flatSlice,
    reshape,
    permute,
    shape,
    rank,
    substituteInLMAD,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
    iota,
    mkExistential,
    equivalent,
    isDirect,
  )
where

import Control.Category
import Control.Monad
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sortBy)
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    Ext (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    Type,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | The shape of an LMAD: one size per dimension, in the same order as
-- the LMAD's dimensions.
type Shape num = [num]

-- | Indices passed to an LMAD.  Must always match the rank of the LMAD.
type Indices num = [num]

-- | A complete permutation of dimension positions, as taken by
-- 'permute'.
type Permutation = [Int]

-- | A single dimension in an 'LMAD': the step taken in the flat index
-- space per unit of this dimension's index, and the number of indices
-- the dimension admits.
data LMADDim num = LMADDim
  { -- | Flat-index distance between consecutive elements along this
    -- dimension (may be zero or negative).
    ldStride :: num,
    -- | Number of elements along this dimension.
    ldShape :: num
  }
  deriving (Eq, Ord, Show)

-- | An LMAD consists of a general offset and, for each dimension, a
-- stride and a number of elements (the shape); see 'LMADDim'.  There
-- is no separate permutation component: 'permute' rearranges the
-- dimensions directly.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra;
-- 'reshape' returns 'Nothing' in that case.
--
-- By definition, the LMAD \( \sigma + \{ (n_1, s_1), \ldots, (n_k, s_k) \} \),
-- where \(n\) and \(s\) denote the shape and stride of each dimension, denotes
-- the set of points:
--
-- \[
--    \{ ~ \sigma + i_1 * s_1 + \ldots + i_k * s_k ~ | ~ 0 \leq i_1 < n_1, \ldots, 0 \leq i_k < n_k ~ \}
-- \]
data LMAD num = LMAD
  { -- | The flat index of the element at index zero in every dimension
    -- (the \( \sigma \) above).
    offset :: num,
    -- | The dimensions of the LMAD, in index order.
    dims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)

-- | Renders an LMAD as a brace-enclosed, semicolon-stacked record of
-- its offset, its strides and its shape.
instance (Pretty num) => Pretty (LMAD num) where
  pretty lmad =
    braces $
      semistack
        [ "offset:" <+> group (pretty (offset lmad)),
          field "strides:" ldStride,
          field "shape:" ldShape
        ]
    where
      -- One labelled, bracketed, comma-separated list per per-dimension
      -- component.
      field label sel =
        label <+> group (brackets (align (commasep [pretty (sel d) | d <- dims lmad])))

-- | Name substitution is applied to the offset and to every stride and
-- shape.
instance (Substitute num) => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad

-- | Renaming is implemented in terms of the 'Substitute' instance.
instance (Substitute num) => Rename (LMAD num) where
  rename = substituteRename

-- | The free variables of an LMAD are those of its offset together with
-- those of all its dimensions.
instance (FreeIn num) => FreeIn (LMAD num) where
  freeIn' lmad = freeIn' (offset lmad) <> foldMap freeIn' (dims lmad)

-- | The free variables of a dimension are those of its stride together
-- with those of its shape.
instance (FreeIn num) => FreeIn (LMADDim num) where
  freeIn' d = freeIn' (ldStride d) <> freeIn' (ldShape d)

-- | Maps over the offset and over every stride and shape.
instance Functor LMAD where
  fmap f (LMAD off ds) =
    LMAD (f off) (map (\(LMADDim s n) -> LMADDim (f s) (f n)) ds)

-- | Visits the offset first, then each dimension's stride followed by
-- its shape, in dimension order.
instance Foldable LMAD where
  foldMap f (LMAD off ds) =
    f off <> foldMap (\(LMADDim s n) -> f s <> f n) ds

-- | Traverses in the same order as the 'Foldable' instance: offset, then
-- stride and shape of each dimension.
instance Traversable LMAD where
  traverse f lmad =
    LMAD <$> f (offset lmad) <*> for (dims lmad) onDim
    where
      onDim d = LMADDim <$> f (ldStride d) <*> f (ldShape d)

-- | The flat contribution of index @i@ along a dimension with stride
-- @s@, i.e. @i * s@.  A stride that compares equal to @0@ yields a
-- literal @0@ instead of a product (NOTE(review): presumably to avoid
-- building @i * 0@ for symbolic @num@ — confirm).
flatOneDim ::
  (Eq num, IntegralExp num) =>
  num ->
  num ->
  num
flatOneDim s i =
  if s == 0 then 0 else i * s

-- | The flat index that an LMAD maps the given indices to: the offset
-- plus, for every dimension, that dimension's index times its stride.
-- Indices beyond the rank (or dimensions beyond the number of indices)
-- are ignored, since the two lists are zipped.
index :: (IntegralExp num, Eq num) => LMAD num -> Indices num -> num
index lmad inds =
  offset lmad + sum (zipWith contribution (dims lmad) inds)
  where
    contribution d i = flatOneDim (ldStride d) i

-- | Slice an LMAD.  The result is always a single LMAD.
--
-- The slice is expected to have one entry per dimension; because the
-- entries are zipped with the dimensions, any surplus on either side is
-- ignored.  A 'DimFix' entry removes its dimension and moves the offset.
-- A 'DimSlice' entry keeps the dimension with an adjusted offset,
-- stride and shape.
slice ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  LMAD num
slice (LMAD off0 ldims) (Slice is) =
  go (LMAD off0 []) (zip is ldims)
  where
    -- Left-to-right accumulation over (slice entry, dimension) pairs.
    go acc [] = acc
    go acc (x : xs) = go (sliceOne acc x) xs

    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    sliceOne (LMAD off ds) (DimFix i, LMADDim s _) =
      LMAD (off + flatOneDim s i) ds
    sliceOne (LMAD off ds) (di@(DimSlice start len step), ld@(LMADDim s n))
      -- Every index along a zero-stride dimension lands on the same
      -- place, so only the new length matters.
      | s == 0 =
          LMAD off (ds ++ [LMADDim 0 len])
      -- A full, forward, unit-step slice leaves the dimension untouched.
      | di == unitSlice 0 n =
          LMAD off (ds ++ [ld])
      -- A full reversal starts at the last element and negates the stride.
      | di == DimSlice (n - 1) n (-1) =
          LMAD (off + flatOneDim s (n - 1)) (ds ++ [LMADDim (s * (-1)) n])
      -- A zero-step slice repeats the element at @start@.
      | step == 0 =
          LMAD (off + flatOneDim s start) (ds ++ [LMADDim 0 len])
      | otherwise =
          LMAD (off + s * start) (ds ++ [LMADDim (step * s) len])

-- | Flat-slice an LMAD.  The outermost dimension is replaced by the
-- dimensions described by the 'FlatSlice', whose offset and strides are
-- scaled by that dimension's stride.  The remaining dimensions are kept
-- as they are.  An LMAD without dimensions is returned unchanged.
flatSlice ::
  (IntegralExp num) =>
  LMAD num ->
  FlatSlice num ->
  LMAD num
flatSlice (LMAD off []) _ = LMAD off []
flatSlice (LMAD off (outer : inner)) (FlatSlice start fdims) =
  let base = ldStride outer
      toDim (FlatDimIndex n s) = LMADDim (base * s) n
   in LMAD (off + start * base) ([toDim fd | fd <- fdims] ++ inner)

-- | Reshape an LMAD, if the result can be represented as a single LMAD;
-- returns 'Nothing' otherwise.  Two cases are handled:
--
-- * The new shape is the old shape with unit (size 1) dimensions
--   inserted.  The inserted dimensions get stride 0, and the existing
--   dimensions are kept as they are.
--
-- * The LMAD is exactly the row-major layout ('iotaStrided') of its own
--   shape, with no zero strides and at least one dimension.  The result
--   is then the row-major layout of the new shape, with the same offset
--   and innermost stride.
--
-- Neither case checks that the old and new shapes contain the same
-- number of elements; that is left to the caller.
reshape ::
  (Eq num, IntegralExp num) => LMAD num -> Shape num -> Maybe (LMAD num)
--
-- First a special case for when we are merely injecting unit
-- dimensions into an LMAD.
reshape (LMAD off dims) newshape
  | Just dims' <- addingVacuous newshape dims =
      Just $ LMAD off dims'
  where
    addingVacuous (dnew : dnews) (dold : dolds)
      | dnew == ldShape dold =
          (dold :) <$> addingVacuous dnews dolds
    addingVacuous (1 : dnews) dolds =
      (LMADDim 0 1 :) <$> addingVacuous dnews dolds
    addingVacuous [] [] = Just []
    addingVacuous _ _ = Nothing
-- Then the general case.
reshape lmad@(LMAD off dims) newshape = do
  -- A rank-0 LMAD has no innermost stride to base the new layout on.
  -- Reshapes that only add unit dimensions were handled above, so give
  -- up here rather than taking 'last' of an empty list.
  base_stride <- case reverse dims of
    innermost : _ -> Just $ ldStride innermost
    [] -> Nothing
  let no_zero_stride = all (\ld -> ldStride ld /= 0) dims
      strides_as_expected = lmad == iotaStrided off base_stride (shape lmad)

  guard $ no_zero_stride && strides_as_expected

  Just $ iotaStrided off base_stride newshape
{-# NOINLINE reshape #-}

-- | Substitute names with 'PrimExp's throughout an LMAD.  Every name
-- that is a key of the table is replaced in the offset and in every
-- stride and shape.
substituteInLMAD ::
  (Ord a) =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab = fmap substOne
  where
    -- 'substituteInPrimExp' works on untyped expressions, so strip the
    -- type tags from the table once and re-wrap each result.
    untypedTab = M.map untyped tab
    substOne e = TPrimExp $ substituteInPrimExp untypedTab $ untyped e

-- | The shape of an LMAD: the 'ldShape' of each dimension, in the order
-- of 'dims'.
shape :: LMAD num -> Shape num
shape lmad = [ldShape d | d <- dims lmad]

-- | The rank of an LMAD, i.e. its number of dimensions.
rank :: LMAD num -> Int
rank (LMAD _ ds) = length ds

-- | A row-major LMAD of the given shape.  The innermost (last)
-- dimension gets the base stride.  Each outer dimension's stride is the
-- base stride multiplied by the sizes of all dimensions inside it.
iotaStrided ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Base Stride
  num ->
  -- | Shape
  [num] ->
  LMAD num
iotaStrided off s ns =
  LMAD off (zipWith LMADDim strides ns)
  where
    -- Right-to-left running products, with the accumulator on the left
    -- of each multiplication: [(s * n_k) * n_(k-1) ..., s * n_k, s].
    strides = scanr (\n acc -> acc * n) s (drop 1 ns)

-- | Generalised iota with user-specified offset: the row-major layout
-- of the given shape whose innermost stride is 1.
iota ::
  (IntegralExp num) =>
  -- | Offset
  num ->
  -- | Shape
  [num] ->
  LMAD num
iota off ns = iotaStrided off 1 ns
{-# NOINLINE iota #-}

-- | Create an LMAD of rank @r@ that is existential in everything.  The
-- offset is @Ext start@.  Dimension @i@ (counting from 0) has stride
-- @Ext (start + 1 + 2*i)@ and shape @Ext (start + 2 + 2*i)@, so the
-- existentials @start@ through @start + 2*r@ are used.
mkExistential :: Int -> Int -> LMAD (Ext a)
mkExistential r start =
  LMAD (Ext start) [dimAt i | i <- [0 .. r - 1]]
  where
    dimAt i =
      let base = start + 2 * i
       in LMADDim (Ext (base + 1)) (Ext (base + 2))

-- | Permute the dimensions of an LMAD using 'rearrangeShape'.  The
-- offset is unchanged.
permute :: LMAD num -> Permutation -> LMAD num
permute (LMAD off ds) perm = LMAD off (rearrangeShape perm ds)

-- | The sum over all dimensions of @stride * (shape - 1)@.  This is the
-- distance from the LMAD's offset to its highest reachable flat index
-- when all strides are non-negative.  The offset itself is not included,
-- and a negative stride contributes a negative term.  (NOTE(review):
-- callers presumably rely on non-negative strides — confirm.)
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) = foldr addDim 0 dims
  where
    -- 'foldr' keeps the symbolic expression right-nested:
    -- span_1 + (span_2 + (... + 0)).
    addDim dim upper = ldStride dim * (ldShape dim - 1) + upper

-- | Conservatively flatten an 'LMAD' into a one-dimensional 'LMAD'.
--
-- Since not all LMADs can actually be flattened, we overestimate the
-- flattened array instead. This means that any "holes" in between dimensions
-- will get filled out.
--
-- * A rank-0 'LMAD' becomes a single element (stride 1, shape 1) at the same
--   offset.
--
-- * A rank-1 'LMAD' is returned unchanged.
--
-- * Otherwise the stride is the conservative 'gcd' of all strides, and the
--   shape is @'flatSpan' + 1@. The result is 'Nothing' if that gcd cannot be
--   determined.
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten l@(LMAD off ds) =
  case ds of
    [] -> Just $ LMAD off [LMADDim 1 1]
    [_] -> Just l
    d0 : _ -> do
      -- The first stride is both the seed and the first list element; its
      -- gcd with itself is just its absolute value.
      common_stride <- foldM gcd (ldStride d0) (map ldStride ds)
      Just $ LMAD off [LMADDim common_stride (flatSpan l + 1)]

-- | Very conservative GCD calculation. Returns 'Nothing' if the result cannot
-- be immediately determined. Does not recurse at all.
--
-- Both arguments are passed through 'abs' first. The only cases recognised
-- are:
--
-- * two arguments equal according to '=='
-- * an argument equal to 1
-- * an argument equal to 0, in either position
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    -- gcd(a, 0) = gcd(0, a) = a. Handle both orders so that the result does
    -- not depend on argument order (e.g. a zero stride paired with a
    -- non-zero one).
    gcd' a 0 = Just a
    gcd' 0 b = Just b
    gcd' _ _ = Nothing -- gcd' b (a `Futhark.Util.IntegralExp.rem` b)

-- | Returns @True@ if the two 'LMAD's could be proven disjoint; @False@ means
-- only that disjointness could not be established.
--
-- Uses some best-approximation heuristics to determine disjointness.
--
-- For two one-dimensional 'LMAD's, they are reported disjoint if either:
--
-- * the difference of their offsets is provably not a multiple of the
--   (conservative) 'gcd' of their strides, or
-- * one of them provably ends before the other begins.
--
-- Any other pair is first passed through 'conservativeFlatten'. If either
-- cannot be flattened, the result is @False@.
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD off1 [d1]) (LMAD off2 [d2])
  | offsetNotDivisible (gcd (ldStride d1) (ldStride d2)) (off1 - off2) = True
  | endsBefore off2 d2 off1 = True
  | otherwise = endsBefore off1 d1 off2
  where
    -- The last element of the 1-D LMAD with offset @o@ and dimension @d@
    -- is provably smaller than @o'@.
    endsBefore o d o' =
      AlgSimplify.lessThanish
        less_thans
        non_negatives
        (o + (ldShape d - 1) * ldStride d)
        o'

    -- True only if @diff mod g@ constant-folds to a value known to be
    -- non-zero; an unknown gcd or an undecidable remainder gives False.
    offsetNotDivisible :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    offsetNotDivisible Nothing _ = False
    offsetNotDivisible (Just g) diff =
      let remainder =
            TPrimExp $ constFoldPrimExp $ untyped $ diff `Futhark.Util.IntegralExp.mod` g
       in case primBool ((0 :: TPrimExp Int64 VName) .==. remainder) of
            Just is_zero -> not is_zero
            Nothing -> False
disjoint less_thans non_negatives lmad1 lmad2 =
  maybe False (uncurry (disjoint less_thans non_negatives)) $
    (,) <$> conservativeFlatten lmad1 <*> conservativeFlatten lmad2

-- | Interval-based disjointness check. Returns @True@ if the two 'LMAD's can
-- be proven disjoint; @False@ means disjointness could not be established.
--
-- Steps:
--
-- 1. Convert both 'LMAD's to intervals and pair the intervals up by stride
--    (most complex stride first).
-- 2. Distribute the positive part of the offset difference into the first
--    side and the negated negative part into the second.
-- 3. Report disjoint if neither side overlaps itself and at least one pair
--    of intervals does not overlap.
--
-- The @scope@ and @asserts@ arguments are ignored.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  case ( distributeOffset pos_part sorted1,
         distributeOffset (map AlgSimplify.negate neg_part) sorted2
       ) of
    (Just placed1, Just placed2) ->
      noSelfOverlap placed1
        && noSelfOverlap placed2
        && any
          (not . uncurry (intervalOverlap less_thans non_negatives))
          (zip placed1 placed2)
    _ -> False
  where
    (off1, ivs1) = lmadToIntervals lmad1
    (off2, ivs2) = lmadToIntervals lmad2
    (neg_part, pos_part) =
      partition AlgSimplify.negated $ off1 `AlgSimplify.sub` off2
    (sorted1, sorted2) =
      unzip . sortBy (flip AlgSimplify.compareComplexity `on` strideKey) $
        intervalPairs ivs1 ivs2
    strideKey = AlgSimplify.simplify0 . untyped . stride . fst
    non_negative_leaves =
      map (\v -> LeafExp v (IntType Int64)) $ namesToList non_negatives
    noSelfOverlap = isNothing . selfOverlap () () less_thans non_negative_leaves

-- | Returns @True@ if the two 'LMAD's can be proven disjoint; @False@ means
-- disjointness could not be established.
--
-- Steps:
--
-- 1. Convert both 'LMAD's to intervals.
-- 2. Normalise each side: sort by stride complexity, then repeatedly
--    'joinDims' and 'mergeDims' until a fixed point.
-- 3. Pair the intervals of the two sides up by stride.
-- 4. Distribute the offset difference into the intervals and compare them.
--
-- If one side overlaps itself, the overlapping dimension is split (or the
-- offset expanded) and the check is retried, at most four levels deep.
-- @scope@ and @asserts@ are only forwarded to 'selfOverlap'.
disjoint3 :: M.Map VName Type -> [PrimExp VName] -> [(VName, PrimExp VName)] -> [PrimExp VName] -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      (interval1', interval2') =
        pairByStride (normalise interval1) (normalise interval2)
   in disjointHelper 4 interval1' interval2' $ offset1 `AlgSimplify.sub` offset2
  where
    -- Sorting key: the simplified stride of an interval.
    strideComplexity = AlgSimplify.simplify0 . untyped . stride

    -- Sort by descending stride complexity, then join/merge dimensions
    -- until nothing changes.
    normalise :: [Interval] -> [Interval]
    normalise =
      fixPoint (mergeDims . joinDims)
        . sortBy (flip AlgSimplify.compareComplexity `on` strideComplexity)

    -- Pair up intervals of equal stride, most complex stride first.
    pairByStride :: [Interval] -> [Interval] -> ([Interval], [Interval])
    pairByStride xs ys =
      unzip $
        sortBy (flip AlgSimplify.compareComplexity `on` (strideComplexity . fst)) $
          intervalPairs xs ys

    -- 'intervalOverlap' wants 'Names'; this only succeeds if every
    -- non-negative expression is a plain variable.
    non_negative_names = namesFromList <$> mapM justLeafExp non_negatives

    disjointHelper :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    disjointHelper 0 _ _ _ = False
    disjointHelper i is10 is20 offset =
      let (is1, is2) = pairByStride is10 is20
          (neg_offset, pos_offset) = partition AlgSimplify.negated offset
       in case ( distributeOffset pos_offset is1,
                 distributeOffset (map AlgSimplify.negate neg_offset) is2
               ) of
            (Just is1', Just is2') ->
              let overlap1 = selfOverlap scope asserts less_thans non_negatives is1'
                  overlap2 = selfOverlap scope asserts less_thans non_negatives is2'
               in case (overlap1, overlap2) of
                    (Nothing, Nothing) ->
                      case non_negative_names of
                        Just non_negatives' ->
                          -- Compare the intervals *after* the offset has been
                          -- distributed into them. The undistributed 'is1' and
                          -- 'is2' do not account for 'offset' at all, so
                          -- comparing them ignores where the two LMADs start.
                          not $
                            all
                              (uncurry (intervalOverlap less_thans non_negatives'))
                              (zip is1' is2')
                        Nothing -> False
                    (Just overlapping_dim, _) ->
                      let expanded_offset =
                            AlgSimplify.simplifySofP' <$> expandOffset offset is1
                          splits = splitDim overlapping_dim is1'
                       in all
                            ( \(new_offset, new_is1) ->
                                disjointHelper (i - 1) (joinDims new_is1) (joinDims is2') new_offset
                            )
                            splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
                    (_, Just overlapping_dim) ->
                      let expanded_offset =
                            AlgSimplify.simplifySofP' <$> expandOffset offset is2
                          splits = splitDim overlapping_dim is2'
                       in all
                            ( \(new_offset, new_is2) ->
                                disjointHelper (i - 1) (joinDims is1') (joinDims new_is2) $
                                  map AlgSimplify.negate new_offset
                            )
                            splits
                            || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
            _ -> False

-- | Fuse neighbouring intervals that share a stride.
--
-- The list is walked from front to back.  Two adjacent intervals @a@ and @b@
-- are fused into @a@ with @numElements a * numElements b@ elements whenever
-- @stride a == stride b@ and both have a lower bound of zero.  The fused
-- interval is put back at the head of the remaining input, so it may absorb
-- further neighbours.  Intervals that are not fused keep their input order.
joinDims :: [Interval] -> [Interval]
joinDims = go []
  where
    go seen [] = reverse seen
    go seen [a] = reverse (a : seen)
    go seen (a : b : more)
      | fusable = go seen (a {numElements = numElements a * numElements b} : more)
      | otherwise = go (a : seen) (b : more)
      where
        fusable =
          stride a == stride b
            && lowerBound a == 0
            && lowerBound b == 0

-- | Collapse neighbouring intervals that line up end to end.
--
-- The list is walked from its last element towards its first.  Let @p@ be
-- the element under consideration and @q@ the element preceding it in the
-- original list.  They are merged when @stride p * numElements p == stride q@
-- and both have a lower bound of zero.  The merged interval is @p@ with its
-- element count multiplied by @numElements q@, and it may then absorb further
-- predecessors.  Intervals that are not merged keep their original order.
mergeDims :: [Interval] -> [Interval]
mergeDims ivs = go [] (reverse ivs)
  where
    go kept [] = kept
    go kept [p] = p : kept
    go kept (p : q : more)
      | mergeable = go kept (p {numElements = numElements p * numElements q} : more)
      | otherwise = go (p : kept) (q : more)
      where
        mergeable =
          stride p * numElements p == stride q
            && lowerBound p == 0
            && lowerBound q == 0

-- | Split the interval that follows @overlapping_dim0@ in @is@, producing
-- alternative @(extra offset, intervals)@ candidates.
--
-- @before@ holds every interval up to and including @overlapping_dim0@.
-- @overlapping_dim@ is the interval directly after it, and @after@ is the
-- rest of the list.
--
-- * Fast path: the strides of both intervals, and the span
--   @stride overlapping_dim * numElements overlapping_dim@, must each simplify
--   to a single product.  The lower bound of @overlapping_dim@ must be zero,
--   and both divisions below must succeed symbolically.  The result is then
--   a single candidate with an empty offset.  In it, @overlapping_dim0@ and
--   @overlapping_dim@ are replaced by two intervals:
--   @span / stride overlapping_dim0@ elements at stride @overlapping_dim0@, and
--   @stride overlapping_dim0 / stride overlapping_dim@ elements at stride
--   @overlapping_dim@.
--
-- * Otherwise there are two candidates.  The first is the last point of
--   @overlapping_dim@, expressed as an offset with that interval removed.  The
--   second is the list with @overlapping_dim@ shortened by one element.
--
-- Calls 'error' if @overlapping_dim0@ is not in @is@, or if it is the last
-- element (so it has no successor).
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is
  | [st] <- simp $ stride overlapping_dim0,
    [st1] <- simp $ stride overlapping_dim,
    [spn] <- simp $ stride overlapping_dim * numElements overlapping_dim,
    lowerBound overlapping_dim == 0,
    Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
    Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
      [ ( [],
          -- 'before' always ends with overlapping_dim0, so 'init' is safe here.
          init before
            <> [ Interval 0 (isInt64 $ AlgSimplify.prodToExp big_dim_elems) (stride overlapping_dim0),
                 Interval 0 (isInt64 $ AlgSimplify.prodToExp small_dim_elems) (stride overlapping_dim)
               ]
            <> after
        )
      ]
  | otherwise =
      let shrunk_dim = overlapping_dim {numElements = numElements overlapping_dim - 1}
          point_offset =
            simp $
              (numElements overlapping_dim - 1 + lowerBound overlapping_dim)
                * stride overlapping_dim
       in [ (point_offset, before <> after),
            ([], before <> [shrunk_dim] <> after)
          ]
  where
    simp e = AlgSimplify.simplify0 (untyped e)

    -- Focus on the element *after* overlapping_dim0.  Failure means the caller
    -- passed an interval that is not in 'is', or one with no successor.  An
    -- empty result would be unsound for the caller's disjointness check, so
    -- fail loudly with context instead of through 'fromJust'.
    (before, overlapping_dim, after) =
      case elemIndex overlapping_dim0 is >>= (flip focusNth is . (+ 1)) of
        Just parts -> parts
        Nothing ->
          error "splitDim: overlapping dimension not found in interval list, or it has no successor"

-- | Convert an 'LMAD' into a simplified offset and a list of intervals.
--
-- Each dimension becomes an interval with lower bound 0, its simplified shape
-- as the element count, and its simplified stride.  An LMAD without
-- dimensions is represented by the single interval @Interval 0 1 1@.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals (LMAD off ds) = (AlgSimplify.simplify0 (untyped off), intervals)
  where
    intervals
      | null ds = [Interval 0 1 1]
      | otherwise = map dimToInterval ds

    dimToInterval :: LMADDim (TPrimExp Int64 VName) -> Interval
    dimToInterval (LMADDim strd shp) =
      Interval 0 (AlgSimplify.simplify' shp) (AlgSimplify.simplify' strd)

-- | Dynamically determine if two 'LMADDim' are equal.
--
-- Builds a run-time boolean expression that holds exactly when both the
-- strides and the shapes of the two dimensions are equal.
dynamicEqualsLMADDim :: (Eq num) => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  sameBy ldStride .&&. sameBy ldShape
  where
    sameBy field = field dim1 .==. field dim2

-- | Dynamically determine if two 'LMAD' are equal.
--
-- True if offset and constituent 'LMADDim' are equal.  Two LMADs of
-- different rank are never equal.  Rank is known statically, so a mismatch
-- yields 'false' without building any per-dimension comparisons.
dynamicEqualsLMAD :: (Eq num) => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  -- 'zip' below would silently drop the extra dimensions of the longer LMAD
  -- and compare only a common prefix, so reject rank mismatches up front.
  | length (dims lmad1) /= length (dims lmad2) = false
  | otherwise =
      offset lmad1 .==. offset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip (dims lmad1) (dims lmad2))
{-# NOINLINE dynamicEqualsLMAD #-}

-- | Returns true if two 'LMAD's are equivalent.
--
-- Equivalence in this case is matching in offsets and strides; shapes are
-- not compared.  The leading rank comparison is implied by the stride-list
-- comparison, but it lets a rank mismatch short-circuit before the other
-- checks.
equivalent :: (Eq num) => LMAD num -> LMAD num -> Bool
equivalent lmad1 lmad2 =
  and
    [ ((==) `on` (length . dims)) lmad1 lmad2,
      ((==) `on` offset) lmad1 lmad2,
      ((==) `on` (map ldStride . dims)) lmad1 lmad2
    ]
{-# NOINLINE equivalent #-}

-- | Is this a row-major array with zero offset?
--
-- Holds when the LMAD is exactly @'iota' 0@ applied to its own shape.
isDirect :: (Eq num, IntegralExp num) => LMAD num -> Bool
isDirect lmad = lmad == canonical
  where
    canonical = iota 0 [ldShape d | d <- dims lmad]
{-# NOINLINE isDirect #-}