{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation for the index function based on
-- linear-memory access descriptors (LMADs); see the work of Zhu, Hoeflinger
-- and Padua.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    Shape,
    LMAD (..),
    LMADDim (..),
    Monotonicity (..),
    index,
    mkExistential,
    iota,
    iotaOffset,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    rebase,
    shape,
    lmadShape,
    rank,
    linearWithOffset,
    rearrangeWithOffset,
    isDirect,
    isLinear,
    substituteInIxFun,
    substituteInLMAD,
    existentialize,
    closeEnough,
    equivalent,
    hasOneLmad,
    permuteInv,
    conservativeFlatten,
    disjoint,
    disjoint2,
    disjoint3,
    dynamicEqualsLMAD,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sort, sortBy, zip4, zipWith4)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isJust, isNothing)
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    Type,
    dimFix,
    flatSliceDims,
    flatSliceStrides,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..), VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))

-- | The shape of an index function: one extent per dimension.
type Shape num = [num]

-- | A multidimensional index, with one entry per dimension.
type Indices num = [num]

-- | A permutation of dimensions, given as a list of dimension positions.
type Permutation = [Int]

-- | The physical element ordering along a dimension, i.e. the sign of
-- the stride.
data Monotonicity
  = -- | Increasing.
    Inc
  | -- | Decreasing.
    Dec
  | -- | Unknown.
    Unknown
  deriving (Eq, Show)

-- | A single dimension in an 'LMAD'.
data LMADDim num = LMADDim
  { -- | The stride of this dimension.
    ldStride :: num,
    -- | The number of elements (extent) of this dimension.
    ldShape :: num,
    -- | This dimension's entry in the LMAD's permutation.
    ldPerm :: Int,
    -- | The element ordering along this dimension.
    ldMon :: Monotonicity
  }
  deriving (Eq, Show)

-- | Monotonicities are totally ordered as @Unknown < Dec < Inc@.
instance Ord Monotonicity where
  a <= b = level a <= level b
    where
      level :: Monotonicity -> Int
      level Unknown = 0
      level Dec = 1
      level Inc = 2

-- | Dimensions are ordered lexicographically: first by extent, then by
-- stride, then by permutation entry, and finally by monotonicity.
instance Ord num => Ord (LMADDim num) where
  d1 <= d2 = key d1 <= key d2
    where
      key (LMADDim stride extent perm mon) = (extent, stride, perm, mon)

-- | An LMAD's representation consists of a general offset and, for each
-- dimension, a stride, a number of elements (or shape), a permutation
-- entry, and a monotonicity.  Note that the permutation is not strictly
-- necessary, in that the permutation can be performed directly on LMAD
-- dimensions, but then it is difficult to extract the permutation back
-- from an LMAD.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra.
--
-- It follows that the general representation of an index function is a
-- list of LMADs, in which each following LMAD in the list implicitly
-- corresponds to an irregular reshaping operation.
--
-- However, we expect that the common case is when the index function is
-- one LMAD -- we call this the "nice" representation.
--
-- Finally, the list of LMADs is kept in an @IxFun@ together with the
-- shape of the original array, and a bit to indicate whether the index
-- function is contiguous, i.e., if we instantiate all the points of the
-- current index function, do we get a contiguous memory interval?
--
-- By definition, the LMAD \( \sigma + \{ (n_1, s_1), \ldots, (n_k, s_k) \} \),
-- where \(n\) and \(s\) denote the shape and stride of each dimension,
-- denotes the set of points:
--
-- \[
--    \{ ~ \sigma + i_1 * s_1 + \ldots + i_k * s_k ~ | ~ 0 \leq i_1 < n_1, \ldots, 0 \leq i_k < n_k ~ \}
-- \]
data LMAD num = LMAD
  { -- | The offset \(\sigma\).
    lmadOffset :: num,
    -- | The dimensions, in unpermuted order (see 'lmadPermutation').
    lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq, Ord)

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as a non-empty sequence of 'LMAD's.
data IxFun num = IxFun
  { -- | The LMADs; see 'LMAD' for how several of them compose.
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The shape of the support array, i.e., the original array from
    -- which this index function originates.
    base :: Shape num,
    -- | Ignoring permutations, is the index function contiguous?
    contiguous :: Bool
  }
  deriving (Eq, Show)

-- | Monotonicities are rendered through their 'Show' instance.
instance Pretty Monotonicity where
  pretty m = pretty (show m)

-- | Renders the offset followed by one bracketed row per per-dimension
-- field (strides, shape, permutation, monotonicity), all inside braces.
instance Pretty num => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces $
      semistack
        [ "offset:" <+> group (pretty offset),
          "strides:" <+> row ldStride,
          "shape:" <+> row ldShape,
          "permutation:" <+> row ldPerm,
          "monotonicity:" <+> row ldMon
        ]
    where
      -- One comma-separated, bracketed row with the chosen field of
      -- every dimension.
      row field = group (brackets (align (commasep [pretty (field d) | d <- dims])))

-- | Renders the base shape, the contiguity flag and the list of LMADs,
-- all inside braces.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmads oshp cg) =
    braces $ semistack [baseDoc, contiguousDoc, lmadsDoc]
    where
      baseDoc = "base:" <+> brackets (commasep (map pretty oshp))
      contiguousDoc = "contiguous:" <+> if cg then "true" else "false"
      lmadsDoc = "LMADs:" <+> brackets (commastack (map pretty (NE.toList lmads)))

-- | Substitution is applied to every element (offset, strides, shapes).
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad

-- | Substitution is applied to every element of every LMAD and of the
-- base shape.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun

-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad

-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun

-- | The free variables of an LMAD are those of all its elements.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad

-- | The free variables of an index function are those of all its
-- elements (every LMAD plus the base shape).
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun

-- | The free variables of a dimension are those of its stride and its
-- extent; the permutation entry and monotonicity carry no variables.
instance FreeIn num => FreeIn (LMADDim num) where
  freeIn' (LMADDim stride extent _perm _mon) = freeIn' stride <> freeIn' extent

-- | Maps over the offset and over the stride and extent of every
-- dimension; permutation entries and monotonicities are untouched.
instance Functor LMAD where
  fmap f (LMAD offset dims) = LMAD (f offset) (map onDim dims)
    where
      onDim (LMADDim stride extent perm mon) = LMADDim (f stride) (f extent) perm mon

-- | Maps over every LMAD and over the base shape; the contiguity flag is
-- kept as is.
instance Functor IxFun where
  fmap f (IxFun lmads oshp cg) = IxFun (fmap (fmap f) lmads) (map f oshp) cg

-- | Folds the elements in the order in which 'traverse' visits them,
-- by logging each mapped element in a Writer.
instance Foldable LMAD where
  foldMap f lmad = snd (runWriter (traverse (\x -> tell (f x)) lmad))

-- | Folds the elements in the order in which 'traverse' visits them,
-- by logging each mapped element in a Writer.
instance Foldable IxFun where
  foldMap f ixfun = snd (runWriter (traverse (\x -> tell (f x)) ixfun))

-- | Visits the offset first, then the stride and extent of each
-- dimension in order.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse visitDim dims
    where
      -- Only the stride and extent are elements; the permutation entry
      -- and monotonicity are carried over unchanged.
      visitDim (LMADDim stride extent perm mon) =
        (\stride' extent' -> LMADDim stride' extent' perm mon)
          <$> f stride
          <*> f extent

-- It is important that the traversal order here is the same as in
-- mkExistential.
instance Traversable IxFun where
  -- Visit the LMADs in order (each using the 'LMAD' traversal), then the
  -- base shape; the contiguity flag is not an element.
  traverse f (IxFun lmads oshp cg) =
    rebuild <$> traverse (traverse f) lmads <*> traverse f oshp
    where
      rebuild lmads' oshp' = IxFun lmads' oshp' cg

-- | Prepend a list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
prefix ++@ (ne :| nes) =
  case prefix of
    [] -> ne :| nes
    p : ps -> p :| (ps ++ ne : nes)

-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| rest
  where
    rest = xs ++ (y : ys)

-- | Flip the direction of a monotonicity; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown

-- | The permutation of an LMAD, collected from the 'ldPerm' entry of
-- each of its dimensions.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]

-- | Overwrite the permutation entries of an LMAD's dimensions, pairing
-- dimensions and entries by position.  As with 'zipWith', only as many
-- dimensions as the shorter of the two lists are kept.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith withPerm (lmadDims lmad) perm}
  where
    withPerm dim p = dim {ldPerm = p}

-- | Overwrite the extents of an LMAD's dimensions, pairing dimensions and
-- extents by position.  As with 'zipWith', only as many dimensions as
-- the shorter of the two lists are kept.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith withExtent (lmadDims lmad) shp}
  where
    withExtent dim extent = dim {ldShape = extent}

-- | Substitute a name with a PrimExp in an LMAD.  The substitution is
-- applied to the offset and to every stride and extent; permutation
-- entries and monotonicities are kept.
substituteInLMAD ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab (LMAD offset dims) =
  LMAD (sub offset) (map subDim dims)
  where
    -- The substitution works on untyped expressions, so strip and
    -- restore the phantom type around it.
    tab' = fmap untyped tab
    sub e = TPrimExp $ substituteInPrimExp tab' $ untyped e
    subDim dim = dim {ldStride = sub (ldStride dim), ldShape = sub (ldShape dim)}

-- | Substitute a name with a PrimExp in an index function.  Every LMAD
-- (via 'substituteInLMAD') and every extent of the base shape is
-- rewritten; the contiguity flag is kept.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab ixfun =
  ixfun
    { ixfunLMADs = NE.map (substituteInLMAD tab) (ixfunLMADs ixfun),
      base = map sub (base ixfun)
    }
  where
    tab' = fmap untyped tab
    sub e = TPrimExp $ substituteInPrimExp tab' $ untyped e

-- | Is this a row-major array?  This holds when the index function
-- consists of a single LMAD and is flagged as contiguous.  In addition:
--
-- * the permutation must be ascending;
-- * the offset must be zero;
-- * the extents must equal the base shape;
-- * the strides must be the row-major strides of that shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  hasContiguousPerm ixfun
    && length oshp == length dims
    && offset == 0
    && and (zipWith4 matches dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides: the innermost dimension has stride 1 and every
    -- other one the product of the extents after it.  'drop 1' is total,
    -- unlike 'tail', so an empty base shape cannot crash here.
    strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
    -- Dimension number i must have the expected stride and extent, and
    -- its permutation entry must be i.
    matches (LMADDim s n p _) i d se = s == se && n == d && p == i
isDirect _ = False

-- | Is the index function "analyzable", i.e., does it consist of exactly
-- one LMAD?
hasOneLmad :: IxFun num -> Bool
hasOneLmad ixfun = null (NE.tail (ixfunLMADs ixfun))

-- | Does the index function consist of a single LMAD whose permutation
-- is in ascending order?  Index functions with several LMADs never
-- qualify.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] -> ascending (lmadPermutation lmad)
    _ -> False
  where
    -- A list equals its sorted self exactly when each element is at most
    -- its successor.
    ascending perm = and (zipWith (<=) perm (drop 1 perm))

-- | The index space of the index function.  This is the same as the
-- shape of arrays that the index function supports.  It is read off the
-- first LMAD: its unpermuted extents reordered by its permutation.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun =
  let lmad = NE.head (ixfunLMADs ixfun)
   in permuteFwd (lmadPermutation lmad) (lmadShapeBase lmad)

-- | Shape of an LMAD, i.e. its extents in the order in which it is
-- indexed.
--
-- The unpermuted extents ('lmadShapeBase') are reordered with
-- 'permuteFwd', the same direction used by 'shape' and by the
-- unflattening step in 'index'.  Reordering with 'permuteInv' would only
-- agree with those when the permutation is its own inverse.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteFwd (lmadPermutation lmad) $ lmadShapeBase lmad

-- | Shape of an LMAD, ignoring permutations: the extent of each
-- dimension in stored order.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]

-- | Compute the flat memory index for a complete set @inds@ of array
-- indices, one per dimension of the index function's 'shape'.
--
-- With a single LMAD, the result is its offset plus the sum of the
-- per-dimension contributions ('flatOneDim' of stride and index).  With
-- several LMADs, the flat index produced by one LMAD is unflattened
-- against the (permuted) shape of the next LMAD.  The resulting indices
-- are fed to that LMAD, and the last LMAD yields the final result.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index = indexFromLMADs . ixfunLMADs
  where
    indexFromLMADs ::
      (IntegralExp num, Eq num) =>
      NonEmpty (LMAD num) ->
      Indices num ->
      num
    indexFromLMADs (lmad :| []) inds = indexLMAD lmad inds
    indexFromLMADs (lmad1 :| lmad2 : lmads) inds =
      let i_flat = indexLMAD lmad1 inds
          -- Reinterpret the flat index in the index space of the next LMAD.
          new_inds = unflattenIndex (permuteFwd (lmadPermutation lmad2) $ lmadShapeBase lmad2) i_flat
       in indexFromLMADs (lmad2 :| lmads) new_inds

    indexLMAD ::
      (IntegralExp num, Eq num) =>
      LMAD num ->
      Indices num ->
      num
    indexLMAD lmad@(LMAD off dims) inds =
      let -- The indices arrive in index order; 'permuteInv' puts them in
          -- the stored order of the dimensions so they line up with the
          -- strides.
          prod =
            sum $
              zipWith
                flatOneDim
                (map ldStride dims)
                (permuteInv (lmadPermutation lmad) inds)
       in off + prod

-- | Like 'iota', but starting at offset @o@.  The result has a single LMAD
-- built by 'makeRotIota' with increasing monotonicity, uses @ns@ as its
-- base shape, and is flagged as contiguous.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns =
  IxFun
    { ixfunLMADs = NE.singleton (makeRotIota Inc o ns),
      base = ns,
      contiguous = True
    }

-- | The identity index function for the given shape, i.e. 'iotaOffset'
-- starting at offset zero.
iota :: IntegralExp num => Shape num -> IxFun num
iota = iotaOffset 0

-- | Create a contiguous single-LMAD index function that is
-- existential in everything, with the provided permutation,
-- monotonicity, and contiguousness.
--
-- Existential indices are handed out consecutively from @start@:
--
--   * the offset is @Ext start@;
--   * dimension @i@ has stride @Ext (start + 1 + 2*i)@ and shape
--     @Ext (start + 2 + 2*i)@;
--   * the @basis_rank@ base-shape entries follow after the last dimension.
mkExistential :: Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a)
mkExistential basis_rank perm contig start =
  IxFun (NE.singleton lmad) basis contig
  where
    -- First existential index not consumed by the offset or the dimensions.
    first_basis = start + 1 + 2 * length perm
    basis = [Ext k | k <- take basis_rank [first_basis ..]]
    lmad =
      LMAD
        (Ext start)
        [ LMADDim (Ext (start + 1 + 2 * i)) (Ext (start + 2 + 2 * i)) p mon
          | (i, (p, mon)) <- zip [0 ..] perm
        ]

-- | Permute dimensions.
--
-- The requested permutation is composed with the one stored in the first
-- LMAD: entry @k@ of the result is entry @perm_new !! k@ of the current
-- permutation.  The remaining LMADs, the base shape and the contiguity flag
-- are left untouched.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    current = lmadPermutation lmad
    composed = [current !! k | k <- perm_new]

-- | Handle the case where a slice can stay within a single LMAD.
--
-- The slice is first reordered with 'permuteInv' using the LMAD's
-- permutation and then applied dimension by dimension to the first LMAD.
-- Fixed ('DimFix') indices fold into the offset and remove their
-- dimension, so the permutation is renumbered to skip them.  The result is
-- marked contiguous only if the input was and the slice keeps it so.
--
-- The current implementation always returns 'Just'.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) =
  Just $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    perm = lmadPermutation lmad
    is' = permuteInv perm is
    cg' = cg && slicePreservesContiguous lmad (Slice is')
    lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) (zip is' ldims)
    -- Positions (in is') of the dimensions that are indexed away.
    fixed = [k | (k, idx) <- zip [0 ..] is', isJust (dimFix idx)]
    -- Remove the fixed dimensions from the permutation and close the gaps
    -- they leave behind.
    perm' = [p - length (filter (< p) fixed) | p <- perm, p `notElem` fixed]

    -- XXX: TODO: what happens on a negative-stride slice; is there such a
    -- case?
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    sliceOne (LMAD off dims) (idx, dim@(LMADDim s n p m)) =
      case idx of
        -- A fixed index moves the offset and drops the dimension.
        DimFix i -> LMAD (off + flatOneDim s i) dims
        DimSlice b ne st
          -- Stride-0 (repeated) dimension: only the new extent matters.
          | s == 0 -> keep off (LMADDim 0 ne p Unknown)
          -- Full unit-stride slice: the dimension is unchanged.
          | idx == unitSlice 0 n -> keep off dim
          -- Full reversal: start at the last element, negate the stride.
          | idx == DimSlice (n - 1) n (-1) ->
              keep
                (off + flatOneDim s (n - 1))
                (LMADDim (s * (-1)) n p (invertMonotonicity m))
          -- Stride-0 slice: element b repeated ne times.
          | st == 0 -> keep (off + flatOneDim s b) (LMADDim 0 ne p Unknown)
          | otherwise -> keep (off + s * b) (LMADDim (st * s) ne p (sliceMon st))
      where
        keep o d = LMAD o (dims ++ [d])
        sliceMon str = case sgn str of
          Just 1 -> m
          Just (-1) -> invertMonotonicity m
          _ -> Unknown

    slicePreservesContiguous ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    slicePreservesContiguous (LMAD _ dims) (Slice slc) =
      -- LMAD dimensions with stride 0 are dropped: if the LMAD was contiguous
      -- in memory they cannot influence the contiguousness of the result.
      -- The slice is normalised so that size-1 and 0-stride slices become
      -- fixed indices.
      let relevant =
            [ (normIndex idx, ld)
              | (ld, idx) <- zip dims slc,
                ldStride ld /= 0
            ]
          unitStride ds = ds == 1 || ds == -1
          -- Before the outermost sliced dimension: fixed indices are fine,
          -- and that outermost slice must have stride +/-1.
          outer ok [] = ok
          outer ok ((DimFix {}, _) : rest) = outer ok rest
          outer ok ((DimSlice _ _ ds, _) : rest) = inner (ok && unitStride ds) rest
          -- Inside it: no further fixed indices are allowed, and every
          -- slice must be full with stride +/-1.
          inner ok [] = ok
          inner _ ((DimFix {}, _) : rest) = inner False rest
          inner ok ((DimSlice _ ne ds, LMADDim _ n _ _) : rest) =
            inner (ok && n == ne && unitStride ds) rest
       in outer True relevant

    -- Rewrite size-1 and stride-0 slices as fixed indices.
    normIndex ::
      (Eq num, IntegralExp num) =>
      DimIndex num ->
      DimIndex num
    normIndex idx@(DimSlice b n st)
      | n == 1 || st == 0 = DimFix b
      | otherwise = idx
    normIndex idx = idx

-- | Slice an index function.
--
-- An identity slice returns the index function unchanged.  Otherwise the
-- slice is first attempted inside the first LMAD via 'sliceOneLMAD'.  If
-- that is not possible, a sliced 'iota' over the first LMAD's shape is
-- pushed on top of the existing LMADs, and the contiguity flags are and-ed.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices =
  if identity_slice
    then ixfun
    else case sliceOneLMAD ixfun dim_slices of
      Just ixfun' -> ixfun'
      Nothing -> via_iota
  where
    -- Avoid identity slicing.
    identity_slice = unSlice dim_slices == map (unitSlice 0) (shape ixfun)
    via_iota =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"

-- | Flat-slice an index function.
--
-- When the first LMAD has at least one dimension and 'hasContiguousPerm'
-- holds, the flat slice replaces that LMAD's first dimension in place.  The
-- new offset and strides are scaled by the first dimension's stride, and the
-- permutation is reset to @[0 ..]@.
--
-- Otherwise a new LMAD is pushed on top.  It describes the flat slice over
-- the row-major layout of the first LMAD's shape.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is)
  | hasContiguousPerm ixfun =
      let outer_stride = ldStride dim
          toDim (FlatDimIndex n s) =
            let ld_stride = outer_stride * s
             in LMADDim ld_stride n 0 (if ld_stride == 1 then Inc else Unknown)
          lmad =
            setLMADPermutation [0 ..] $
              LMAD (offset + new_offset * outer_stride) (map toDim is <> dims)
       in IxFun (lmad :| lmads) oshp cg
flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) =
  IxFun (LMAD (new_offset * base_stride) (new_dims <> tail_dims) :| lmad : lmads) oshp cg
  where
    -- Row-major layout of everything but the outermost dimension.
    tail_shapes = tail $ lmadShape lmad
    base_stride = product tail_shapes
    tail_strides = tail $ scanr (*) 1 tail_shapes
    new_dims =
      [ LMADDim (st * base_stride) n k Inc
        | (st, n, k) <- zip3 (flatSliceStrides s) (flatSliceDims s) [0 ..]
      ]
    tail_dims =
      [ LMADDim st n k Inc
        | (st, n, k) <- zip3 tail_strides tail_shapes [length (flatSliceDims s) ..]
      ]

-- | Handle the case where a reshape operation can stay inside a single LMAD.
--
-- There are three conditions that all must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must be trivial: applying it
--       to the dimensions must yield permutation indices @0, 1, 2, ...@, and
--       'hasContiguousPerm' must hold.
--   (2) the underlying LMAD must have no repeated dimensions, i.e. no
--       dimension with stride zero.
--   (3) finally, the underlying memory is contiguous (and monotonic, either
--       increasing or decreasing).
--
-- When all hold, the first LMAD is replaced by the LMAD that @makeRotIota@
-- builds for the new shape.  It uses the same monotonicity and the old
-- offset, and carries the identity permutation.
--
-- If any of these conditions do not hold, then the reshape operation will
-- conservatively add a new LMAD to the list, leading to a representation that
-- provides less opportunities for further analysis.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let dims_perm = permuteFwd (lmadPermutation lmad) dims
      mon = ixfunMonotonicity ixfun

  guard $
    -- checking condition (2): no stride-0 (repeated) dimensions
    all ((/= 0) . ldStride) dims_perm
      -- checking condition (1)
      && consecutive (map ldPerm dims_perm)
      -- checking condition (3)
      && hasContiguousPerm ixfun
      && cg
      && (mon == Inc || mon == Dec)

  -- Under the conditions above the result is simply a fresh layout of the
  -- new shape starting at the old offset, with the identity permutation.
  -- (An earlier formulation split dimensions into "support" and "repeat"
  -- parts, but the repeat part was always empty and the computed
  -- permutation was always the identity.)
  let lmad' = makeRotIota mon off newshape
  pure $ IxFun (setLMADPermutation [0 .. length newshape - 1] lmad' :| lmads) oldbase cg
  where
    -- Are the permutation indices exactly 0, 1, 2, ...?
    consecutive ps = and $ zipWith (==) ps [0 ..]

-- | Reshape an index function.
--
-- The reshape stays inside the first LMAD whenever 'reshapeOneLMAD'
-- allows it.  Otherwise the LMAD of an 'iota' of the new shape is pushed on
-- top of the existing LMADs.  The base shape and contiguity flag are kept.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
reshape ixfun@(IxFun (lmad0 :| lmad0s) oshp cg) new_shape =
  case reshapeOneLMAD ixfun new_shape of
    Just ixfun' -> ixfun'
    Nothing ->
      case iota new_shape of
        IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
        _ -> error "reshape: reached impossible case"

-- | Coerce an index function to look like it has a new shape.
-- Dynamically the shape must be the same.
--
-- Only the shapes of the first LMAD's dimensions are replaced, pairwise with
-- @new_shape@; if the lists differ in length the result is truncated to the
-- shorter one.  Offset, strides, permutation, monotonicity and the remaining
-- LMADs are kept as they are.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun (lmad :| lmads) oshp cg) new_shape =
  IxFun (coerced :| lmads) oshp cg
  where
    coerced = case lmad of
      LMAD offset dims ->
        LMAD offset [ld {ldShape = d} | (ld, d) <- zip dims new_shape]

-- | The number of dimensions in the domain of the input function, i.e. the
-- number of dimensions of its first LMAD.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank (IxFun (lmad :| _) _ _) = length $ lmadDims lmad

-- | Essentially @rebase new_base ixfun = ixfun o new_base@
-- Core soundness condition: @base ixfun == shape new_base@
-- Handles the case where a rebase operation can stay within m + n - 1 LMADs,
-- where m is the number of LMADs in the index function, and n is the number of
-- LMADs in the new base.  If both index functions have only one LMAD, this
-- means that we stay within the single-LMAD domain.
--
-- We can often stay in that domain if the original ixfun is essentially a
-- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`.
--
-- XXX: TODO: handle repetitions in both lmads.
--
-- How to handle repeated dimensions in the original?
--
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice :: forall num.
(Eq num, IntegralExp num) =>
IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice
  new_base :: IxFun num
new_base@(IxFun (LMAD num
lmad_base :| [LMAD num]
lmads_base) Shape num
_ Bool
cg_base)
  ixfun :: IxFun num
ixfun@(IxFun NonEmpty (LMAD num)
lmads Shape num
shp Bool
cg) = do
    let (LMAD num
lmad :| [LMAD num]
lmads') = forall a. NonEmpty a -> NonEmpty a
NE.reverse NonEmpty (LMAD num)
lmads
        dims :: [LMADDim num]
dims = forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad
        perm :: Permutation
perm = forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
        perm_base :: Permutation
perm_base = forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad_base

    forall (f :: * -> *). Alternative f => Bool -> f ()
guard forall a b. (a -> b) -> a -> b
$
      -- Core rebase condition.
      forall num. IxFun num -> Shape num
base IxFun num
ixfun forall a. Eq a => a -> a -> Bool
== forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
shape IxFun num
new_base
        -- Conservative safety conditions: ixfun is contiguous and has known
        -- monotonicity for all dimensions.
        Bool -> Bool -> Bool
&& Bool
cg
        Bool -> Bool -> Bool
&& forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((forall a. Eq a => a -> a -> Bool
/= Monotonicity
Unknown) forall {k} (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. forall num. LMADDim num -> Monotonicity
ldMon) [LMADDim num]
dims
        -- XXX: We should be able to handle some basic cases where both index
        -- functions have non-trivial permutations.
        Bool -> Bool -> Bool
&& (forall num. IxFun num -> Bool
hasContiguousPerm IxFun num
ixfun Bool -> Bool -> Bool
|| forall num. IxFun num -> Bool
hasContiguousPerm IxFun num
new_base)
        -- We need the permutations to be of the same size if we want to compose
        -- them.  They don't have to be of the same size if the ixfun has a trivial
        -- permutation.  Supporting this latter case allows us to rebase when ixfun
        -- has been created by slicing with fixed dimensions.
        Bool -> Bool -> Bool
&& (forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm_base Bool -> Bool -> Bool
|| forall num. IxFun num -> Bool
hasContiguousPerm IxFun num
ixfun)
        -- To not have to worry about ixfun having non-1 strides, we also check that
        -- it is a row-major array (modulo permutation, which is handled
        -- separately).  Accept a non-full innermost dimension.  XXX: Maybe this can
        -- be less conservative?
        Bool -> Bool -> Bool
&& forall (t :: * -> *). Foldable t => t Bool -> Bool
and
          ( forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3
              (\num
sn LMADDim num
ld Bool
inner -> num
sn forall a. Eq a => a -> a -> Bool
== forall num. LMADDim num -> num
ldShape LMADDim num
ld Bool -> Bool -> Bool
|| (Bool
inner Bool -> Bool -> Bool
&& forall num. LMADDim num -> num
ldStride LMADDim num
ld forall a. Eq a => a -> a -> Bool
== num
1))
              Shape num
shp
              [LMADDim num]
dims
              (forall a. Int -> a -> [a]
replicate (forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims forall a. Num a => a -> a -> a
- Int
1) Bool
False forall a. [a] -> [a] -> [a]
++ [Bool
True])
          )

    -- Compose permutations, reverse strides and adjust offset if necessary.
    let perm_base' :: Permutation
perm_base' =
          if forall num. IxFun num -> Bool
hasContiguousPerm IxFun num
ixfun
            then Permutation
perm_base
            else forall a b. (a -> b) -> [a] -> [b]
map (Permutation
perm !!) Permutation
perm_base
        lmad_base' :: LMAD num
lmad_base' = forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm_base' LMAD num
lmad_base
        dims_base :: [LMADDim num]
dims_base = forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad_base'
        n_fewer_dims :: Int
n_fewer_dims = forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims_base forall a. Num a => a -> a -> a
- forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
        ([LMADDim num]
dims_base', Shape num
offs_contrib) =
          forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$
            forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
              ( \(LMADDim num
s1 num
n1 Int
p1 Monotonicity
_) (LMADDim num
_ num
_ Int
_ Monotonicity
m2) ->
                  let (num
s', num
off')
                        | Monotonicity
m2 forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc = (num
s1, num
0)
                        | Bool
otherwise = (num
s1 forall a. Num a => a -> a -> a
* (-num
1), num
s1 forall a. Num a => a -> a -> a
* (num
n1 forall a. Num a => a -> a -> a
- num
1))
                   in (forall num. num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
s' num
n1 (Int
p1 forall a. Num a => a -> a -> a
- Int
n_fewer_dims) Monotonicity
Inc, num
off')
              )
              -- If @dims@ is morally a slice, it might have fewer dimensions than
              -- @dims_base@.  Drop extraneous outer dimensions.
              (forall a. Int -> [a] -> [a]
drop Int
n_fewer_dims [LMADDim num]
dims_base)
              [LMADDim num]
dims
        off_base :: num
off_base = forall num. LMAD num -> num
lmadOffset LMAD num
lmad_base' forall a. Num a => a -> a -> a
+ forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum Shape num
offs_contrib
        lmad_base'' :: LMAD num
lmad_base''
          | forall num. LMAD num -> num
lmadOffset LMAD num
lmad forall a. Eq a => a -> a -> Bool
== num
0 = forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off_base [LMADDim num]
dims_base'
          | Bool
otherwise =
              -- If the innermost dimension of the ixfun was not full (but still
              -- had a stride of 1), add its offset relative to the new base.
              forall num. Shape num -> LMAD num -> LMAD num
setLMADShape
                (forall num. (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape LMAD num
lmad)
                ( forall num. num -> [LMADDim num] -> LMAD num
LMAD
                    (num
off_base forall a. Num a => a -> a -> a
+ forall num. LMADDim num -> num
ldStride (forall a. [a] -> a
last [LMADDim num]
dims_base) forall a. Num a => a -> a -> a
* forall num. LMAD num -> num
lmadOffset LMAD num
lmad)
                    [LMADDim num]
dims_base'
                )
        new_base' :: IxFun num
new_base' = forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (LMAD num
lmad_base'' forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads_base) Shape num
shp Bool
cg_base
        IxFun NonEmpty (LMAD num)
lmads_base' Shape num
_ Bool
_ = IxFun num
new_base'
        lmads'' :: NonEmpty (LMAD num)
lmads'' = [LMAD num]
lmads' forall a. [a] -> NonEmpty a -> NonEmpty a
++@ NonEmpty (LMAD num)
lmads_base'
    forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun NonEmpty (LMAD num)
lmads'' Shape num
shp (Bool
cg Bool -> Bool -> Bool
&& Bool
cg_base)

-- | Rebase an index function on top of a new base.
--
-- The specialised 'rebaseNice' is tried first.  When it gives up, the LMADs
-- of @ixfun@ are put in front of those of @new_base@; if the base shape of
-- @ixfun@ differs from the shape of @new_base@, @new_base@ is first reshaped
-- to that base shape.  The result is contiguous only when both inputs are.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just ixfun' -> ixfun'
    -- In the general case just concatenate LMADs since this refers to index
    -- function composition, which is always safe.
    Nothing -> IxFun (lmads @++@ outer_lmads) outer_shp (cg && cg_base)
  where
    (outer_lmads, outer_shp)
      | base ixfun == shape new_base = (lmads_base, shp_base)
      | otherwise =
          let IxFun reshaped_lmads reshaped_shp _ = reshape new_base shp
           in (reshaped_lmads, reshaped_shp)

-- | Return the offset, scaled by @elem_size@, at which the memory support of
-- the index function starts, provided that the index function
--
--   * consists of exactly one LMAD,
--   * has a contiguous permutation (as decided by 'hasContiguousPerm'),
--   * is flagged as contiguous, and
--   * is increasing according to 'ixfunMonotonicity'.
--
-- In every other case the result is 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ cg
      | hasContiguousPerm ixfun,
        cg,
        ixfunMonotonicity ixfun == Inc ->
          Just $ lmadOffset lmad * elem_size
    _ -> Nothing

-- | Like 'linearWithOffset', but permutations are allowed: the offset is
-- computed as if the single LMAD had an identity permutation, and is
-- returned together with each permutation entry paired with the
-- corresponding (forward-permuted) base-shape size.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size = do
  -- The contiguity flag @cg@ ignores permutations, so linearity is checked
  -- on the same LMAD with its permutation reset to the identity.
  offset <- linearWithOffset (IxFun (unpermuted :| []) oshp cg) elem_size
  pure (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
  where
    perm = lmadPermutation lmad
    unpermuted = setLMADPermutation [0 .. length perm - 1] lmad
rearrangeWithOffset _ _ = Nothing

-- | Is this a row-major array starting at offset zero?  That is, does
-- 'linearWithOffset' with an element size of 1 report offset 0?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0

-- | Apply a permutation: position @i@ of the result is @elems !! (ps !! i)@.
-- Partial: an entry of @ps@ that is out of range for @elems@ crashes.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]

-- | Undo a permutation: for a valid permutation @ps@, the element at index
-- @k@ of @elems@ ends up at position @ps !! k@, so
-- @permuteInv ps (permuteFwd ps xs) == xs@.  Implemented as a stable sort
-- of @elems@ on their tags from @ps@; surplus elements beyond the length of
-- @ps@ are dropped by the 'zip'.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  let tagged = zip ps elems
      byTarget (i, _) (j, _) = compare i j
   in map snd (sortBy byTarget tagged)

-- | Flat contribution @i * s@ of index @i@ along a dimension with stride
-- @s@.  When the stride compares equal to 0 the result is 0 outright and
-- @i@ is not used.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  num ->
  num ->
  num
flatOneDim s i = if s == 0 then 0 else i * s

-- | Generalised iota with user-specified offset and monotonicity.
--
-- Builds a single row-major LMAD with offset @off@ and shape @ns@:
--
--   * the innermost stride is 1 and every outer stride is the product of
--     the sizes of the dimensions inside it (all strides negated for 'Dec'),
--   * the permutation is the identity @[0 .. rank - 1]@, and
--   * every dimension is tagged with @mon@.
--
-- Only 'Inc' and 'Dec' are accepted; any other monotonicity is an error.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  -- | Offset
  num ->
  -- | Shape
  [num] ->
  LMAD num
makeRotIota mon off ns
  | mon == Inc || mon == Dec =
      LMAD off $ zipWith4 LMADDim strides ns [0 .. rk - 1] (replicate rk mon)
  | otherwise = error "makeRotIota: requires Inc or Dec"
  where
    rk = length ns
    -- Suffix products of the shape: scan from the innermost dimension
    -- outwards, keep one stride per dimension, then restore the order.
    row_major = reverse $ take rk $ scanl (*) 1 $ reverse ns
    strides
      | mon == Inc = row_major
      | otherwise = map (* (-1)) row_major

-- | Check monotonicity of an index function.
--
-- Each LMAD is classified as 'Inc' when every dimension either has stride 0
-- or is tagged 'Inc'; failing that as 'Dec' under the same test with 'Dec';
-- and as 'Unknown' otherwise.  The index function takes the classification
-- of its first LMAD if all remaining LMADs agree with it, else 'Unknown'.
ixfunMonotonicity ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Monotonicity
ixfunMonotonicity (IxFun (lmad :| lmads) _ _)
  | all ((== first) . classify) lmads = first
  | otherwise = Unknown
  where
    first = classify lmad

    classify ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Monotonicity
    classify (LMAD _ dims)
      | all (dimIs Inc) dims = Inc
      | all (dimIs Dec) dims = Dec
      | otherwise = Unknown

    -- A zero-stride dimension is compatible with either direction.
    dimIs ::
      (Eq num, IntegralExp num) =>
      Monotonicity ->
      LMADDim num ->
      Bool
    dimIs mon (LMADDim s _ _ ldmon) = s == 0 || mon == ldmon

-- | Turn all the leaves of the index function into 'Ext's.
--
-- Every value visited by the 'Traversable' instance of 'IxFun' is replaced
-- by a fresh existential @'LeafExp' ('Ext' i) 'int64'@, with @i@ counting
-- up from 0 in traversal order; the original values are discarded.
--
-- No structural conditions (number of LMADs, contiguity, rank of the base
-- shape) are checked here.  Callers that rely on such restrictions must
-- establish them themselves.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const mkExt) ixfun) 0
  where
    -- Hand out the current counter as an existential and advance it.
    mkExt = do
      i <- get
      put $ i + 1
      pure $ TPrimExp $ LeafExp (Ext i) int64

-- | Decide whether two index functions are "close enough" to be treated as
-- matching.
--
-- The KernelsMem type checker compares index functions, but the simplifier
-- can get in the way: index functions may be generalised over if-then-else
-- expressions, and the simplifier may hoist code out of the branches (for
-- example the computation of an array offset).  The type checker then can
-- no longer verify the generalised index function, because some of its
-- existentials are computed elsewhere.  This function is the relaxed check
-- used instead of @ixfun1 == ixfun2@.
--
-- Two index functions are close enough when their base shapes have the same
-- rank, they have the same number of LMADs, corresponding LMADs have the same
-- number of dimensions and the same permutation, and @ixf2@ is contiguous
-- whenever @ixf1@ is.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  length (base ixf1) == length (base ixf2)
    && NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLayout lmads1 lmads2)
    -- ixf1 plays the role of the "declared type" being matched against.
    && (contiguous ixf1 <= contiguous ixf2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    permOf l = map ldPerm (lmadDims l)

    sameLayout l1 l2 =
      length (lmadDims l1) == length (lmadDims l2)
        && permOf l1 == permOf l2

-- | Returns true if two 'IxFun's are equivalent.
--
-- Equivalence here means having the same number of LMADs, with each pair of
-- corresponding LMADs agreeing on the number of dimensions, the permutation,
-- the offset and the strides.  Dimension sizes, monotonicity tags, the base
-- shape and the contiguity flag are not compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  NE.length (ixfunLMADs ixf1) == NE.length (ixfunLMADs ixf2)
    && all equivalentLMADs (NE.zip (ixfunLMADs ixf1) (ixfunLMADs ixf2))
  where
    equivalentLMADs (lmad1, lmad2) =
      length (lmadDims lmad1) == length (lmadDims lmad2)
        && map ldPerm (lmadDims lmad1) == map ldPerm (lmadDims lmad2)
        && lmadOffset lmad1 == lmadOffset lmad2
        && map ldStride (lmadDims lmad1) == map ldStride (lmadDims lmad2)

-- | Computes the flat span of an 'LMAD' relative to its offset: the sum over
-- all dimensions of @stride * (size - 1)@.  The offset itself is ignored.
--
-- With non-negative strides this is the distance between the lowest and the
-- highest flat index the 'LMAD' can touch.  A negative stride contributes a
-- negative term, so in that case the result is not an upper bound.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) = foldr addSpan 0 dims
  where
    addSpan dim acc = ldStride dim * (ldShape dim - 1) + acc

-- | Conservatively flatten a list of LMAD dimensions
--
-- Since not all LMADs can actually be flattened, we try to overestimate the
-- flattened array instead. This means that any "holes" in between dimensions
-- will get filled out.
--
-- * A zero-dimensional LMAD becomes one dimension of size 1 and stride 1.
-- * A one-dimensional LMAD is returned unchanged.
-- * Otherwise the result has one dimension whose stride is the GCD of all
--   strides and whose size is @'flatSpan' + 1@, tagged 'Unknown'.  Returns
--   'Nothing' if that GCD cannot be determined (see 'gcd').
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1 0 Inc]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
conservativeFlatten l@(LMAD offset dims@(first_dim : _)) = do
  strd <- foldM gcd (ldStride first_dim) $ map ldStride dims
  -- 'flatSpan' is measured in elements rather than in multiples of @strd@,
  -- so for @strd > 1@ this overestimates further; that is still
  -- conservative.
  pure $ LMAD offset [LMADDim strd (shp + 1) 0 Unknown]
  where
    shp = flatSpan l

-- | Very conservative GCD calculation. Returns 'Nothing' if the result cannot
-- be immediately determined. Does not recurse at all.
--
-- Both arguments are passed through 'abs' first.  The result is known when
-- the two values compare equal, when either of them is 1 (result 1), or when
-- either of them is 0 (result is the other value).  The zero case is handled
-- in both argument positions, so the answer does not depend on argument
-- order.
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b | a == b = Just a
    gcd' 1 _ = Just 1
    gcd' _ 1 = Just 1
    gcd' a 0 = Just a
    gcd' 0 b = Just b
    gcd' _ _ = Nothing

-- | Returns @True@ if the two 'LMAD's could be proven disjoint; @False@ means
-- only that no proof was found.
--
-- For two one-dimensional LMADs, they are reported disjoint when
--
--   * the GCD of the two strides is known and provably does not divide the
--     difference of the offsets, or
--   * the last element of one LMAD is 'AlgSimplify.lessThanish' the first
--     element of the other, given @less_thans@ and @non_negatives@.
--
-- Any other pair of LMADs is first flattened with 'conservativeFlatten',
-- which yields one-dimensional LMADs, and the test is retried.  If
-- flattening fails, the result is @False@.
--
-- NOTE(review): the "last element" is computed as
-- @offset + (size - 1) * stride@, which is only the highest address when
-- the stride is non-negative; confirm callers never pass negative strides.
disjoint :: [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2)
    || endsBefore offset2 dim2 offset1
    || endsBefore offset1 dim1 offset2
  where
    -- Is the last element of the one-dimensional LMAD (off, dim) below
    -- @start@?
    endsBefore off dim start =
      AlgSimplify.lessThanish
        less_thans
        non_negatives
        (off + (ldShape dim - 1) * ldStride dim)
        start

    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      Futhark.Util.IntegralExp.mod y x
        & untyped
        & constFoldPrimExp
        & TPrimExp
        & (.==.) (0 :: TPrimExp Int64 VName)
        & primBool
        & maybe False not
    doesNotDivide _ _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  case (conservativeFlatten lmad1, conservativeFlatten lmad2) of
    (Just lmad1', Just lmad2') -> disjoint less_thans non_negatives lmad1' lmad2'
    _ -> False

-- | Attempt to prove two 'LMAD's disjoint by comparing them interval by
-- interval.  The first two arguments are currently ignored.
--
-- Both LMADs are turned into an offset plus intervals ('lmadToIntervals').
-- The difference of the offsets is split into negated and non-negated terms
-- ('AlgSimplify.negated').  The non-negated terms are distributed over the
-- first LMAD's intervals and the negated ones, negated back, over the
-- second's.  Intervals are paired ('intervalPairs') and ordered by the
-- reverse of 'AlgSimplify.compareComplexity' on the first interval's stride.
-- The result is @True@ when neither side overlaps itself and at least one
-- pair of intervals does not overlap.  If either distribution fails the
-- result is @False@.
disjoint2 :: scope -> asserts -> [(VName, PrimExp VName)] -> Names -> LMAD (TPrimExp Int64 VName) -> LMAD (TPrimExp Int64 VName) -> Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  case ( distributeOffset pos_offset sorted1,
         distributeOffset (map AlgSimplify.negate neg_offset) sorted2
       ) of
    (Just ivs1, Just ivs2) ->
      noSelfOverlap ivs1
        && noSelfOverlap ivs2
        && any
          (not . uncurry (intervalOverlap less_thans non_negatives))
          (zip ivs1 ivs2)
    _ -> False
  where
    (offset1, intervals1) = lmadToIntervals lmad1
    (offset2, intervals2) = lmadToIntervals lmad2

    (neg_offset, pos_offset) =
      partition AlgSimplify.negated (offset1 `AlgSimplify.sub` offset2)

    strideKey (iv, _) = AlgSimplify.simplify0 (untyped (stride iv))

    (sorted1, sorted2) =
      unzip $
        sortBy (flip AlgSimplify.compareComplexity `on` strideKey) $
          intervalPairs intervals1 intervals2

    -- The non-negative names as Int64 leaves, shared by both self-overlap
    -- checks.
    non_negative_exps = map (flip LeafExp $ IntType Int64) $ namesToList non_negatives

    noSelfOverlap ivs = isNothing (selfOverlap () () less_thans non_negative_exps ivs)

-- | Conservatively decide whether two LMADs are disjoint.
--
-- Both LMADs are converted with 'lmadToIntervals'. The interval lists are
-- sorted by decreasing stride complexity, normalised with 'joinDims' and
-- 'mergeDims' until a fixed point is reached, and paired by stride with
-- 'intervalPairs'. The difference of the two offsets is then distributed
-- into the intervals with 'distributeOffset'.
--
-- * If neither side overlaps itself (according to 'selfOverlap'), the
--   LMADs are reported disjoint when some pair of corresponding
--   offset-adjusted intervals does not overlap.
-- * If one side overlaps itself, the offending dimension is split with
--   'splitDim' and every alternative is checked recursively. An
--   alternative attempt uses 'expandOffset'. Recursion is limited to
--   four levels.
--
-- 'True' means disjointness was established. 'False' means the accesses
-- may overlap, or the analysis could not decide.
disjoint3 ::
  M.Map VName Type ->
  [PrimExp VName] ->
  [(VName, PrimExp VName)] ->
  [PrimExp VName] ->
  LMAD (TPrimExp Int64 VName) ->
  LMAD (TPrimExp Int64 VName) ->
  Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  let (offset1, interval1) = lmadToIntervals lmad1
      (offset2, interval2) = lmadToIntervals lmad2
      (interval1', interval2') =
        pairByStride (normaliseDims interval1) (normaliseDims interval2)
   in disjointHelper maxSplitDepth interval1' interval2' $
        offset1 `AlgSimplify.sub` offset2
  where
    -- Maximum number of nested 'splitDim' rounds before giving up.
    maxSplitDepth :: Int
    maxSplitDepth = 4

    -- Orders intervals by decreasing complexity of their simplified stride.
    byStride :: Interval -> Interval -> Ordering
    byStride =
      flip AlgSimplify.compareComplexity
        `on` (AlgSimplify.simplify0 . untyped . stride)

    -- Sort by stride, then join/merge adjacent dimensions until stable.
    normaliseDims :: [Interval] -> [Interval]
    normaliseDims = fixPoint (mergeDims . joinDims) . sortBy byStride

    -- Align two interval lists by stride and order the pairs by the
    -- stride of their first component.
    pairByStride :: [Interval] -> [Interval] -> ([Interval], [Interval])
    pairByStride xs ys =
      unzip $ sortBy (byStride `on` fst) $ intervalPairs xs ys

    -- @offset@ is (offset of the first LMAD) minus (offset of the second).
    disjointHelper :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    disjointHelper 0 _ _ _ = False
    disjointHelper i is10 is20 offset =
      let (is1, is2) = pairByStride is10 is20
          -- Positive terms move into the first side, negated negative
          -- terms into the second, so both sides become offset-free.
          (neg_offset, pos_offset) = partition AlgSimplify.negated offset
       in case ( distributeOffset pos_offset is1,
                 distributeOffset (map AlgSimplify.negate neg_offset) is2
               ) of
            (Just is1', Just is2') ->
              case ( selfOverlap scope asserts less_thans non_negatives is1',
                     selfOverlap scope asserts less_thans non_negatives is2'
                   ) of
                (Nothing, Nothing) ->
                  case namesFromList <$> mapM justLeafExp non_negatives of
                    Just non_negatives' ->
                      -- Compare the offset-adjusted intervals, as disjoint2
                      -- does. The unadjusted is1/is2 do not account for
                      -- @offset@, so using them would ignore the distance
                      -- between the two LMADs.
                      any
                        (not . uncurry (intervalOverlap less_thans non_negatives'))
                        (zip is1' is2')
                    Nothing -> False
                (Just overlapping_dim, _) ->
                  let expanded_offset =
                        AlgSimplify.simplifySofP' <$> expandOffset offset is1
                      splits = splitDim overlapping_dim is1'
                   in all
                        ( \(new_offset, new_is1) ->
                            disjointHelper (i - 1) (joinDims new_is1) (joinDims is2') new_offset
                        )
                        splits
                        || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
                (_, Just overlapping_dim) ->
                  let expanded_offset =
                        AlgSimplify.simplifySofP' <$> expandOffset offset is2
                      splits = splitDim overlapping_dim is2'
                   in all
                        ( \(new_offset, new_is2) ->
                            -- A split offset on the second side enters the
                            -- difference with the opposite sign.
                            disjointHelper (i - 1) (joinDims is1') (joinDims new_is2) $
                              map AlgSimplify.negate new_offset
                        )
                        splits
                        || maybe False (disjointHelper (i - 1) is1 is2) expanded_offset
            _ -> False

-- | Scan the intervals from left to right. Two neighbouring intervals with
-- equal strides, both starting at 0, are replaced by one interval whose
-- element count is the product of their counts. The combined interval
-- stays at the front of the remaining list, so it can be combined again
-- with the interval that follows it.
--
-- NOTE(review): for counts of at least one, the product is presumably
-- meant as an over-approximation of the positions the pair can reach.
-- Confirm against the 'Interval' semantics.
joinDims :: [Interval] -> [Interval]
joinDims = go []
  where
    go seen (a : b : more)
      | sameStrideAtZero a b =
          go seen (a {numElements = numElements a * numElements b} : more)
      | otherwise = go (a : seen) (b : more)
    -- Zero or one interval left: nothing more to join.
    go seen leftover = reverse seen <> leftover

    sameStrideAtZero a b =
      stride a == stride b && lowerBound a == 0 && lowerBound b == 0

-- | Merge contiguous neighbouring intervals, scanning from the end of the
-- list towards its front.
--
-- An interval @a@ and its predecessor @b@ in the list are merged when both
-- start at 0 and @stride a * numElements a == stride b@. The merged
-- interval keeps the stride of @a@ and gets the product of the two element
-- counts. A merged interval may be merged again with its own predecessor.
-- The result keeps the input order.
mergeDims :: [Interval] -> [Interval]
mergeDims = go [] . reverse
  where
    go kept (inner : outer : more)
      | contiguous inner outer =
          go kept (inner {numElements = numElements inner * numElements outer} : more)
      | otherwise = go (inner : kept) (outer : more)
    -- Zero or one interval left; prepending restores the original order.
    go kept leftover = leftover <> kept

    contiguous inner outer =
      stride inner * numElements inner == stride outer
        && lowerBound inner == 0
        && lowerBound outer == 0

-- | Split the dimension that follows @overlapping_dim0@ in @is@ into
-- alternatives that together cover the same elements.
--
-- Each result pairs an extra offset with a new interval list. The caller
-- ('disjoint3') must prove disjointness for every alternative.
--
-- * If the stride of @overlapping_dim0@ divides the span (stride times
--   count) of the next dimension, the next dimension's stride divides the
--   stride of @overlapping_dim0@, and the next dimension starts at 0, then
--   the next dimension is rewritten exactly as two dimensions. The first
--   has @overlapping_dim0@'s stride; the second has its own stride and
--   covers one period of it.
-- * Otherwise the last element of the next dimension is peeled off as a
--   separate point (with its offset), leaving a dimension one element
--   shorter.
--
-- If @overlapping_dim0@ is not in @is@, or no dimension follows it, there
-- is nothing to split. The input is then returned unchanged as the only
-- alternative. An empty result would make the caller's 'all' succeed
-- vacuously and claim disjointness, which is unsound.
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is =
  case elemIndex overlapping_dim0 is >>= flip focusNth is of
    Just (outer, dim0, overlapping_dim : after) ->
      splitNext (outer <> [dim0]) overlapping_dim after
    _ -> [([], is)]
  where
    -- @before@ ends with @overlapping_dim0@; @overlapping_dim@ is split.
    splitNext ::
      [Interval] -> Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
    splitNext before overlapping_dim after
      | [st] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim0,
        [st1] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim,
        [spn] <-
          AlgSimplify.simplify0 $
            untyped $
              stride overlapping_dim * numElements overlapping_dim,
        lowerBound overlapping_dim == 0,
        Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
        Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
          -- Keep overlapping_dim0. The two new dimensions only describe
          -- overlapping_dim, so dropping overlapping_dim0 would lose its
          -- elements. The caller's joinDims can fold it together with the
          -- new dimension of equal stride.
          [ ( [],
              before
                <> [ Interval 0 (isInt64 $ AlgSimplify.prodToExp big_dim_elems) (stride overlapping_dim0),
                     Interval 0 (isInt64 $ AlgSimplify.prodToExp small_dim_elems) (stride overlapping_dim)
                   ]
                <> after
            )
          ]
      | otherwise =
          let shrunk_dim =
                overlapping_dim {numElements = numElements overlapping_dim - 1}
              -- Position of the last element of overlapping_dim.
              point_offset =
                AlgSimplify.simplify0 $
                  untyped $
                    (numElements overlapping_dim - 1 + lowerBound overlapping_dim)
                      * stride overlapping_dim
           in [ (point_offset, before <> after),
                ([], before <> [shrunk_dim] <> after)
              ]

-- | Split an LMAD into its offset, simplified to a sum of products, and
-- one 'Interval' per dimension.
--
-- The dimensions are reordered with 'permuteInv' using the LMAD's
-- permutation. Each interval starts at 0 and carries the simplified shape
-- (as element count) and stride of its dimension. Monotonicity is not
-- consulted. An LMAD without dimensions yields a single one-element
-- interval with stride 1.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals lmad@(LMAD offset dims) =
  (AlgSimplify.simplify0 $ untyped offset, intervals)
  where
    intervals
      | null dims = [Interval 0 1 1]
      | otherwise = map dimInterval $ permuteInv (lmadPermutation lmad) dims

    dimInterval :: LMADDim (TPrimExp Int64 VName) -> Interval
    dimInterval dim =
      Interval
        0
        (AlgSimplify.simplify' $ ldShape dim)
        (AlgSimplify.simplify' $ ldStride dim)

-- | Dynamically determine if two 'LMADDim' are equal.
--
-- Builds a boolean expression that compares the strides and shapes at run
-- time. The permutation indices and monotonicities are compared here and
-- embedded as constants with 'fromBool'.
dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  sameStride .&&. sameShape .&&. samePerm .&&. sameMon
  where
    sameStride = ldStride dim1 .==. ldStride dim2
    sameShape = ldShape dim1 .==. ldShape dim2
    samePerm = fromBool $ ldPerm dim1 == ldPerm dim2
    sameMon = fromBool $ ldMon dim1 == ldMon dim2

-- | Dynamically determine if two 'LMAD' are equal.
--
-- True if the offsets are equal and both LMADs have the same number of
-- dimensions, with each pair of corresponding dimensions equal according to
-- 'dynamicEqualsLMADDim'. LMADs of different rank are never equal.
dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  -- The rank is known statically. Without this check 'zip' would silently
  -- drop the extra dimensions and compare only a common prefix.
  | length dims1 /= length dims2 = false
  | otherwise =
      lmadOffset lmad1 .==. lmadOffset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip dims1 dims2)
  where
    dims1 = lmadDims lmad1
    dims2 = lmadDims lmad2