{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
-- | This module contains a representation for the index function based on
-- linear-memory access descriptors (LMADs); see the work of Zhu, Hoeflinger
-- and David.
module Futhark.IR.Mem.IxFun
       ( IxFun(..)
       , index
       , iota
       , offsetIndex
       , permute
       , rotate
       , reshape
       , slice
       , rebase
       , repeat
       , shape
       , rank
       , linearWithOffset
       , rearrangeWithOffset
       , isDirect
       , isLinear
       , substituteInIxFun
       , leastGeneralGeneralization
       , closeEnough
       )
       where

import Prelude hiding (mod, repeat)
import Data.List (sort, sortBy, zip4, zip5, zipWith5)
import qualified Data.List.NonEmpty as NE
import Data.List.NonEmpty (NonEmpty(..))
import Data.Function (on)
import Data.Maybe (isJust)
import Control.Monad.Identity
import Control.Monad.Writer
import qualified Data.Map.Strict as M

import Futhark.Analysis.PrimExp (PrimExp(..))
import Futhark.IR.Syntax.Core (Ext(..))
import Futhark.Transform.Substitute
import Futhark.Transform.Rename
import Futhark.IR.Syntax
  (ShapeChange, DimChange(..), DimIndex(..), Slice, unitSlice, dimFix)
import Futhark.IR.Prop
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import qualified Futhark.Analysis.PrimExp.Generalize as PEG

-- | A shape: one extent per dimension.
type Shape d = [d]

-- | A complete set of indices: one index per dimension.
type Indices d = [d]

-- | A permutation of dimensions, given as a list of dimension positions.
type Permutation = [Int]

-- | The monotonicity of an LMAD dimension.
data Monotonicity
  = Inc     -- ^ monotonously increasing
  | Dec     -- ^ monotonously decreasing
  | Unknown -- ^ monotonicity unknown
  deriving (Eq, Show)

-- | A single dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { ldStride :: num
    -- ^ Stride of the dimension.
  , ldRotate :: num
    -- ^ Rotation factor of the dimension.
  , ldShape :: num
    -- ^ Number of elements in the dimension.
  , ldPerm :: Int
    -- ^ This dimension's entry in the LMAD's permutation.
  , ldMon :: Monotonicity
    -- ^ Monotonicity of the dimension.
  }
  deriving (Eq, Show)

-- | An LMAD's representation consists of a general offset and, for each
-- dimension, a stride, rotate factor, number of elements (or shape),
-- permutation, and monotonicity. Note that the permutation is not strictly
-- necessary in that the permutation can be performed directly on LMAD
-- dimensions, but then it is difficult to extract the permutation back from
-- an LMAD.
--
-- LMAD algebra is closed under composition w.r.t. operators such as permute,
-- repeat, index and slice.  However, other operations, such as reshape, cannot
-- always be represented inside the LMAD algebra.
--
-- It follows that the general representation of an index function is a list of
-- LMADs, in which each following LMAD in the list implicitly corresponds to an
-- irregular reshaping operation.
--
-- However, we expect that the common case is when the index function is one
-- LMAD -- we call this the "nice" representation.
--
-- Finally, the list of LMADs is kept in an @IxFun@ together with the shape of
-- the original array, and a flag indicating whether the index function is
-- contiguous, i.e., if we instantiate all the points of the current index
-- function, do we get a contiguous memory interval?
--
-- By definition, an LMAD with @k@ dimensions denotes the set of points
-- (simplified):
--
--   \{ o + \Sigma_{j=0}^{k-1} ((i_j+r_j) `mod` n_j)*s_j,
--      \forall i_j such that 0<=i_j<n_j, j=0..k-1 \}
--
-- where @o@ is the offset, and @s_j@, @r_j@ and @n_j@ are the stride, rotate
-- factor and number of elements of dimension @j@.
data LMAD num = LMAD
  { lmadOffset :: num
  , lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq)

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as a sequence of 'LMAD's.
data IxFun num = IxFun
  { ixfunLMADs :: NonEmpty (LMAD num)
    -- ^ The (non-empty) sequence of LMADs making up the index function.
  , base :: Shape num
    -- ^ The shape of the original array.
  , ixfunContig :: Bool
    -- ^ ignoring permutations, is the index function contiguous?
  }
  deriving (Show, Eq)


instance Pretty Monotonicity where
  -- Render the constructor name exactly as 'show' produces it.
  ppr mon = text (show mon)

instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces . semisep $
      [ text "offset: " <> oneLine (ppr offset)
      , text "strides: " <> perDim ldStride
      , text "rotates: " <> perDim ldRotate
      , text "shape: " <> perDim ldShape
      , text "permutation: " <> perDim ldPerm
      , text "monotonicity: " <> perDim ldMon
      ]
    where
      -- Render one field of every dimension as a single-line, bracketed,
      -- comma-separated list.
      perDim field = oneLine . brackets . commasep $ map (ppr . field) dims

instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces . semisep $ [baseDoc, contigDoc, lmadsDoc]
    where
      baseDoc = text "base: " <> brackets (commasep (map ppr oshp))
      contigDoc = text "contiguous: " <> text (show cg)
      lmadsDoc =
        text "LMADs: " <> brackets (commasep (map ppr (NE.toList lmads)))


instance Substitute num => Substitute (LMAD num) where
  -- Rename inside every @num@-typed component (offset, strides, rotates,
  -- shapes) through the 'Functor' instance.
  substituteNames substs lmad = fmap (substituteNames substs) lmad

instance Substitute num => Substitute (IxFun num) where
  -- Rename inside every @num@-typed component (all LMADs and the base
  -- shape) through the 'Functor' instance.
  substituteNames substs ixfun = fmap (substituteNames substs) ixfun

instance Substitute num => Rename (LMAD num) where
  -- Delegates to 'substituteRename', which relies on the 'Substitute' instance.
  rename lmad = substituteRename lmad

instance Substitute num => Rename (IxFun num) where
  -- Delegates to 'substituteRename', which relies on the 'Substitute' instance.
  rename ixfun = substituteRename ixfun


instance FreeIn num => FreeIn (LMAD num) where
  -- The free variables are the union of those of every @num@ component.
  freeIn' lmad = foldMap freeIn' lmad

instance FreeIn num => FreeIn (IxFun num) where
  -- The free variables are the union of those of every @num@ component.
  freeIn' ixfun = foldMap freeIn' ixfun

instance Functor LMAD where
  -- Derived from the 'Traversable' instance using the 'Identity' applicative.
  fmap f lmad = runIdentity $ traverse (Identity . f) lmad

instance Functor IxFun where
  -- Derived from the 'Traversable' instance using the 'Identity' applicative.
  fmap f ixfun = runIdentity $ traverse (Identity . f) ixfun


instance Foldable LMAD where
  -- Derived from the 'Traversable' instance: the image of each element under
  -- @f@ is accumulated in a 'Writer'.
  foldMap f lmad = execWriter $ traverse (tell . f) lmad

instance Foldable IxFun where
  -- Derived from the 'Traversable' instance: the image of each element under
  -- @f@ is accumulated in a 'Writer'.
  foldMap f ixfun = execWriter $ traverse (tell . f) ixfun


instance Traversable LMAD where
  -- Visits the offset first, then each dimension in order.
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse traverseDim dims
    where
      -- Only the stride, rotate and shape fields (in that order) are
      -- traversed; the permutation entry and monotonicity are carried over.
      traverseDim (LMADDim stride rot n perm mon) =
        (\stride' rot' n' -> LMADDim stride' rot' n' perm mon)
          <$> f stride <*> f rot <*> f n

instance Traversable IxFun where
  -- Visits all LMADs first, then the base shape; the contiguity flag is
  -- carried over unchanged.
  traverse f (IxFun lmads oshp cg) =
    rebuild <$> traverse (traverse f) lmads <*> traverse f oshp
    where rebuild lmads' oshp' = IxFun lmads' oshp' cg

-- | Prepend a (possibly empty) list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
[] ++@ rest = rest
(x : xs) ++@ (y :| ys) = x :| (xs ++ y : ys)

-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| (xs ++ (y : ys))

-- | Swap increasing and decreasing monotonicity; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown

-- | The permutation of an LMAD: the 'ldPerm' entry of each dimension, in
-- dimension order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]

-- | Overwrite the 'ldPerm' entry of each dimension with the corresponding
-- element of @perm@.  As with 'zipWith', the result keeps only as many
-- dimensions as the shorter of the two lists.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad { lmadDims = zipWith withPerm (lmadDims lmad) perm }
  where withPerm dim p = dim { ldPerm = p }

-- | Overwrite the 'ldShape' of each dimension with the corresponding element
-- of @shp@.  As with 'zipWith', the result keeps only as many dimensions as
-- the shorter of the two lists.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad { lmadDims = zipWith withShape (lmadDims lmad) shp }
  where withShape dim s = dim { ldShape = s }

-- | Substitute a name with a PrimExp in an LMAD.
--
-- The substitution is applied to the offset and to the stride, rotate and
-- shape of every dimension.
substituteInLMAD :: Ord a => M.Map a (PrimExp a) -> LMAD (PrimExp a)
                 -> LMAD (PrimExp a)
substituteInLMAD tab (LMAD offset dims) =
  LMAD (subst offset) (map substDim dims)
  where
    subst = substituteInPrimExp tab
    -- The permutation entry and monotonicity are not expressions and are
    -- kept as they are.
    substDim (LMADDim s r n p m) = LMADDim (subst s) (subst r) (subst n) p m

-- | Substitute a name with a PrimExp in an index function.
--
-- The substitution is applied to every LMAD and to the base shape; the
-- contiguity flag is left unchanged.
substituteInIxFun :: (Ord a) => M.Map a (PrimExp a) -> IxFun (PrimExp a)
                  -> IxFun (PrimExp a)
substituteInIxFun tab (IxFun lmads oshp cg) =
  IxFun (fmap (substituteInLMAD tab) lmads)
        (map (substituteInPrimExp tab) oshp)
        cg

-- | Is this a row-major array?
--
-- This holds when the index function is marked contiguous and consists of a
-- single LMAD with an ascending permutation, zero offset, and as many
-- dimensions as the base shape. In addition, every dimension @i@ must have
-- permutation entry @i@, no rotation, the base shape's extent, and the
-- row-major stride of the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  let -- Row-major strides of the base shape: the stride of dimension i is the
      -- product of the extents of all later dimensions.  'drop 1' (instead of
      -- the partial 'tail') keeps this total for a rank-0 base shape.
      strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
  in hasContiguousPerm ixfun &&
     length oshp == length dims &&
     offset == 0 &&
     all (\(LMADDim s r n p _, m, d, se) ->
            s == se && r == 0 && n == d && p == m)
         (zip4 dims [0 .. length dims - 1] oshp strides_expected)
isDirect _ = False

-- | Does the index function have an ascending permutation?
--
-- Only index functions consisting of a single LMAD can qualify; anything
-- with more than one LMAD yields 'False'.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] -> isAscending (lmadPermutation lmad)
    _ -> False
  where isAscending xs = xs == sort xs

-- | Shape of an index function: the shape of its first LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun =
  case ixfunLMADs ixfun of
    lmad :| _ -> lmadShape lmad

-- | Shape of an LMAD: the per-dimension extents rearranged by 'permuteInv'
-- of the LMAD's permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm (lmadShapeBase lmad)
  where perm = lmadPermutation lmad

-- | Shape of an LMAD, ignoring permutations: the 'ldShape' of each
-- dimension, in dimension order.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]

-- | Compute the flat memory index for a complete set @inds@ of array
-- indices.
--
-- The indices are first mapped through the first LMAD of the index function.
-- If further LMADs follow, the resulting flat index is unflattened into
-- indices for the next LMAD's base shape (arranged by that LMAD's
-- permutation), and the process repeats.  The flat index produced by the last
-- LMAD is the result.
index :: (IntegralExp num, Eq num) =>
          IxFun num -> Indices num -> num
index = indexFromLMADs . ixfunLMADs
  where indexFromLMADs :: (IntegralExp num, Eq num) =>
                          NonEmpty (LMAD num) -> Indices num -> num
        indexFromLMADs (lmad :| []) inds = indexLMAD lmad inds
        indexFromLMADs (lmad1 :| lmad2 : lmads) inds =
          let i_flat   = indexLMAD lmad1 inds
              -- Re-express the flat index as indices into the next LMAD.
              new_inds = unflattenIndex (permuteFwd (lmadPermutation lmad2) $
                                         lmadShapeBase lmad2) i_flat
          in indexFromLMADs (lmad2 :| lmads) new_inds

        -- Compute the flat index of a single LMAD: its offset plus the
        -- per-dimension contributions computed by 'flatOneDim', after the
        -- indices have been rearranged by 'permuteInv' of the LMAD's
        -- permutation.  NOTE: 'zipWith' silently truncates if the number of
        -- indices differs from the number of dimensions; this is not checked.
        indexLMAD :: (IntegralExp num, Eq num) =>
                     LMAD num -> Indices num -> num
        indexLMAD lmad@(LMAD off dims) inds =
          let prod = sum $ zipWith flatOneDim
                             (map (\(LMADDim s r n _ _) -> (s, r, n)) dims)
                             (permuteInv (lmadPermutation lmad) inds)
          in off + prod

-- | iota: the index function of an array of shape @ns@ that is built from a
-- single LMAD.  That LMAD is produced by 'makeRotIota' with monotonicity
-- 'Inc', offset 0, and a 0 paired with every dimension size.  The base
-- shape of the result is @ns@ and it is flagged as contiguous.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = IxFun (makeRotIota Inc 0 pairs :| []) ns True
  where
    -- Each size is paired with a leading 0 (presumably the rotation,
    -- judging by 'makeRotIota''s name).
    pairs = map (\n -> (0, n)) ns

-- | Permute dimensions.
--
-- @perm_new@ is composed with the permutation currently stored on the
-- outermost LMAD: entry @k@ of the result is the current permutation's
-- entry at position @perm_new !! k@.  Only the outermost LMAD changes; the
-- remaining LMADs, the base shape and the contiguity flag are kept.
-- Assumes every entry of @perm_new@ is a valid dimension position
-- (otherwise '!!' fails).
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    current = lmadPermutation lmad
    composed = [current !! k | k <- perm_new]

-- | Repeat dimensions.
--
-- @shps@ has one entry per existing (user-visible) dimension: the shape of
-- the new dimensions inserted immediately in front of that dimension.
-- @shp@ is the shape of the new dimensions appended after all existing
-- ones.  Every inserted dimension has stride 0, rotation 0 and 'Unknown'
-- monotonicity.  Only the outermost LMAD is changed, and the result is
-- conservatively flagged as non-contiguous.
repeat ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  [Shape num] ->
  Shape num ->
  IxFun num
repeat (IxFun (lmad@(LMAD off dims) :| lmads) oshp _) shps shp =
  IxFun (lmad' :| lmads) oshp False -- XXX: Can we be less conservative?
  where
    perm = lmadPermutation lmad
    -- Every existing dimension turns into a group: its inserted dimensions
    -- followed by the dimension itself.
    groupLens = [1 + length s | s <- shps]
    -- The per-dimension data, inverse-permuted into the LMAD's own order.
    (lmadShps, lmadGroupLens) = unzip $ permuteInv perm $ zip shps groupLens
    -- groupEnds !! j is one past the last new position of LMAD group j.
    groupEnds = drop 1 $ scanl (+) 0 lmadGroupLens
    -- For every user-visible dimension (living in LMAD group p, whose size
    -- is l), the new positions of its whole group, in order.
    groupPerm =
      [ start + i
        | (p, l) <- zip perm groupLens,
          let start = (groupEnds !! p) - l,
          i <- [0 .. l - 1]
      ]
    -- The dimensions coming from @shp@ go after everything else.
    fullPerm = groupPerm ++ take (length shp) [length groupPerm ..]
    allDims =
      concat (zipWith (\extra d -> map fakeDim extra ++ [d]) lmadShps dims)
        ++ map fakeDim shp
    lmad' = setLMADPermutation fullPerm $ LMAD off allDims
    -- A repeated dimension: stride 0, rotation 0, permutation entry 0 (it is
    -- overwritten by setLMADPermutation above), unknown monotonicity.
    fakeDim x = LMADDim 0 0 x 0 Unknown

-- | Rotate an index function.
--
-- @offs@ holds one rotation amount per user-visible dimension.  It is put
-- into the LMAD's own order with the inverse of the outermost LMAD's
-- permutation, and each amount is added to that dimension's rotation.
-- Stride-0 dimensions instead get their rotation reset to 0 and their
-- monotonicity set to 'Unknown'.  Only the outermost LMAD is changed.
rotate ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  IxFun (LMAD off (zipWith rot dims offs') :| lmads) oshp cg
  where
    offs' = permuteInv (lmadPermutation lmad) offs
    rot (LMADDim s r n p f) o
      | s == 0 = LMADDim 0 0 n p Unknown
      | otherwise = LMADDim s (r + o) n p f

-- | Handle the case where a slice can stay within a single LMAD.
--
-- The slice is applied to the outermost LMAD only; the remaining LMADs are
-- kept as they are.  Returns 'Nothing' when some dimension with both a
-- nonzero stride and a nonzero rotation is sliced by anything other than a
-- fixed index, @unitSlice 0 n@, or @DimSlice (n - 1) n (-1)@.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) is = do
  let perm = lmadPermutation lmad
      -- The slice comes in user-visible order; reorder it to line up with
      -- the LMAD's own dimensions.
      is' = permuteInv perm is
  guard $ harmlessRotation lmad is'
  let cg' = cg && slicePreservesContiguous lmad is'
      lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims
      -- Positions of the fixed dimensions.  They vanish from the result, so
      -- they must also be dropped from the permutation.
      fixed = [k | (k, d) <- zip [0 ..] is', isJust (dimFix d)]
      perm' = updatePerm perm fixed
  Just $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    -- Remove every permutation entry naming a removed dimension, and lower
    -- each remaining entry by the number of removed dimensions below it so
    -- that the entries stay dense.
    updatePerm :: Permutation -> [Int] -> Permutation
    updatePerm ps removed =
      [p - length (filter (< p) removed) | p <- ps, p `notElem` removed]

    -- True when no dimension that is both strided and rotated is sliced in
    -- a way a single LMAD cannot express.
    harmlessRotation ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    harmlessRotation (LMAD _ dims) iss = and $ zipWith harmless dims iss
      where
        harmless _ DimFix {} = True
        harmless (LMADDim s r n _ _) dslc =
          s == 0
            || r == 0
            || dslc == DimSlice (n - 1) n (-1)
            || dslc == unitSlice 0 n

    -- Slice one dimension and append the outcome to the LMAD being built.
    -- A fixed index only moves the offset; every other case appends exactly
    -- one dimension.  Equations are tried top to bottom, so the specific
    -- cases must stay ahead of the general ones.
    --
    -- XXX: TODO: what happens to r on a negative-stride slice; is there
    -- such a case?
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    sliceOne (LMAD off dims) (DimFix i, LMADDim s r n _ _) =
      LMAD (off + flatOneDim (s, r, n) i) dims
    -- Stride-0 dimension: only the size of the slice is kept.
    sliceOne (LMAD off dims) (DimSlice _ ne _, LMADDim 0 _ _ p _) =
      LMAD off (dims ++ [LMADDim 0 0 ne p Unknown])
    -- The whole dimension: keep it untouched.
    sliceOne (LMAD off dims) (dmind, dim@(LMADDim _ _ n _ _))
      | dmind == unitSlice 0 n =
        LMAD off (dims ++ [dim])
    -- The whole dimension reversed: negate the stride, start the offset at
    -- the last element and mirror the rotation.
    sliceOne (LMAD off dims) (dmind, LMADDim s r n p m)
      | dmind == DimSlice (n - 1) n (-1) =
        let r' = if r == 0 then 0 else n - r
            off' = off + flatOneDim (s, 0, n) (n - 1)
         in LMAD off' (dims ++ [LMADDim (s * (-1)) r' n p (invertMonotonicity m)])
    -- Stride-0 slice: every element is the one at index b.
    sliceOne (LMAD off dims) (DimSlice b ne 0, LMADDim s r n p _) =
      LMAD (off + flatOneDim (s, r, n) b) (dims ++ [LMADDim 0 0 ne p Unknown])
    -- General slice of an unrotated dimension.
    sliceOne (LMAD off dims) (DimSlice bs ns ss, LMADDim s 0 _ p m) =
      let m' = case sgn ss of
            Just 1 -> m
            Just (-1) -> invertMonotonicity m
            _ -> Unknown
       in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) 0 ns p m'])
    sliceOne _ _ = error "slice: reached impossible case"

    -- Given a slice already in LMAD order, is the result still contiguous
    -- (assuming the LMAD was)?
    slicePreservesContiguous ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    slicePreservesContiguous (LMAD _ dims) slc =
      -- All fixed dimensions must come before the sliced ones.  The
      -- outermost sliced dimension needs stride +/-1 and must be unrotated
      -- or taken in full; every inner sliced dimension needs stride +/-1
      -- and must be taken in full.
      case dropWhile isFixed kept of
        [] -> True
        outer : inner -> outerOk outer && all innerOk inner
      where
        -- Dimensions of stride 0 are ignored: if the LMAD was contiguous,
        -- they cannot influence the contiguity of the result.  The slice
        -- is normalised first, so size-1 and 0-stride slices count as fixed.
        kept =
          [ (d, ld)
            | (ld, d) <- zip dims (map normIndex slc),
              ldStride ld /= 0
          ]
        isFixed (DimFix {}, _) = True
        isFixed _ = False
        unitStride ds = ds == 1 || ds == -1
        outerOk (DimSlice _ ne ds, LMADDim _ r n _ _) =
          (r == 0 || n == ne) && unitStride ds
        outerOk _ = False
        innerOk (DimSlice _ ne ds, LMADDim _ _ n _ _) =
          n == ne && unitStride ds
        innerOk _ = False

    -- Rewrite size-1 and stride-0 slices as fixed indices.
    normIndex ::
      (Eq num, IntegralExp num) =>
      DimIndex num ->
      DimIndex num
    normIndex (DimSlice b ne ds)
      | ne == 1 || ds == 0 = DimFix b
    normIndex d = d

-- | Slice an index function.
--
-- An empty slice is rejected with an error.  A slice that takes every
-- dimension in full (@unitSlice 0@ of each size) returns the index function
-- unchanged.  Otherwise the slice is applied to the outermost LMAD when
-- 'sliceOneLMAD' allows it.  Failing that, an 'iota' over the outermost
-- LMAD's shape is sliced instead and pushed on top of the existing LMADs.
-- The result is contiguous only if both inputs to that step are.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice _ [] = error "slice: empty slice"
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  | isIdentity = ixfun
  | otherwise =
    case sliceOneLMAD ixfun dim_slices of
      Just ixfun' -> ixfun'
      Nothing -> stacked
  where
    -- Avoid identity slicing.
    isIdentity = dim_slices == map (unitSlice 0) (shape ixfun)
    stacked =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"

-- | Handle the simple case where all reshape dimensions are coercions.
--
-- 'splitCoercions' splits the new shape into leading coercions, the
-- remaining reshaped dimensions, and trailing coercions.  The result is
-- 'Nothing' unless either there are no reshaped dimensions, or there is
-- exactly one and it stands in for exactly one existing LMAD dimension.
-- When it succeeds, only the sizes of the outermost LMAD's dimensions are
-- replaced.  Strides, rotations, permutation entries, monotonicity, the
-- offset and the contiguity flag are kept.
reshapeCoercion ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) _ cg) newshape = do
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      -- The LMAD dimensions reordered with permuteFwd by the permutation.
      permuted = permuteFwd (lmadPermutation lmad) dims
      -- The dimensions lying between the leading and trailing coercions.
      mid_dims =
        take (length dims - hd_len - length tail_coercions) $
          drop hd_len permuted
  guard $ null reshapes || (length reshapes == 1 && length mid_dims == 1)
  let new_shape = newDims newshape
      resized = zipWith (\ld n -> ld {ldShape = n}) permuted new_shape
      -- Sort the resized dimensions by their ldPerm field.
      dims' =
        map snd $ sortBy (compare `on` fst) [(ldPerm ld, ld) | ld <- resized]
  -- NOTE(review): the base shape is replaced by the new shape here, whereas
  -- permute, rotate, repeat and slice keep it; confirm this is intended when
  -- @lmads@ is non-empty.
  Just $ IxFun (LMAD off dims' :| lmads) new_shape cg

-- | Handle the case where a reshape operation can stay inside a single LMAD.
--
-- There are four conditions that all must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--   (3) similarly, the rotated dimensions must refer only to
--       dimensions that are coerced by the reshape operation.
--   (4) finally, the underlying memory is contiguous (and monotonic).
--
-- If any of these conditions do not hold, then the reshape operation will
-- conservatively add a new LMAD to the list, leading to a representation that
-- provides fewer opportunities for further analysis.
--
-- A reshape only changes how the domain of the index function is viewed; the
-- support array is unchanged.  The base shape of the input index function is
-- therefore kept as-is, exactly like the general fallback in 'reshape' does.
-- (Previously this path replaced the base shape with the new shape, so the
-- resulting base shape depended on which reshape path was taken.)
--
-- Returns 'Nothing' when the conditions above cannot be established.
reshapeOneLMAD :: (Eq num, IntegralExp num) =>
                   IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims_perm = permuteFwd perm dims
      -- The (permuted) LMAD dimensions that lie between the leading and the
      -- trailing coercions, i.e. the ones that are actually reshaped.
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims_perm
      -- Ignore rotates, as we only care about not having rotates in the
      -- dimensions that aren't coercions (@mid_dims@), which we check
      -- separately.
      mon = ixfunMonotonicityRots True ixfun

  guard $
    -- checking conditions (2) and (3)
    all (\(LMADDim s r _ _ _) -> s /= 0 && r == 0) mid_dims &&
    -- checking condition (1)
    consecutive hd_len (map ldPerm mid_dims) &&
    -- checking condition (4)
    hasContiguousPerm ixfun && cg && (mon == Inc || mon == Dec)

  -- make new permutation
  let rsh_len = length reshapes
      -- change in rank caused by the reshape
      diff = length newshape - length dims
      iota_shape = [0 .. length newshape - 1]
      perm' = map newPerm iota_shape
      newPerm i
        | i >= hd_len && i < hd_len + rsh_len =
            i -- already checked mid_dims not affected
        | otherwise =
            -- NOTE(review): this indexes the unpermuted @dims@, whereas the
            -- split below indexes @dims_perm@; confirm this is intended.
            let ind = if i < hd_len then i else i - diff
                p = ldPerm (dims !! ind)
            in if p < hd_len then p else p + diff

      -- Split the new dimensions into support dimensions (with their
      -- rotation and size) and repeated (stride-0) dimensions, each tagged
      -- with its entry in the new permutation.
      (support_inds, repeat_inds) =
        foldr splitDim ([], []) $ zip3 iota_shape newshape perm'
      splitDim (i, shpdim, ip) (sup, rpt) =
        case (i < hd_len, i >= hd_len + rsh_len, shpdim) of
          (True, _, DimCoercion n) ->
            case dims_perm !! i of
              LMADDim 0 _ _ _ _ -> (sup, (ip, n) : rpt)
              LMADDim _ r _ _ _ -> ((ip, (r, n)) : sup, rpt)
          (_, True, DimCoercion n) ->
            case dims_perm !! (i - diff) of
              LMADDim 0 _ _ _ _ -> (sup, (ip, n) : rpt)
              LMADDim _ r _ _ _ -> ((ip, (r, n)) : sup, rpt)
          -- Reshaped dimensions: already checked that the reshaped
          -- dims cannot be repeats or rotates, so their rotation is 0.
          (False, False, _) ->
            ((ip, (0, newDim shpdim)) : sup, rpt)
          _ -> error "reshape: reached impossible case"

      (sup_inds, support) = unzip $ sortBy (compare `on` fst) support_inds
      (rpt_inds, repeats) = unzip repeat_inds
      LMAD off' dims_sup = makeRotIota mon off support
      repeats' = map (\n -> LMADDim 0 0 n 0 Unknown) repeats
      dims' = map snd $ sortBy (compare `on` fst) $
              zip sup_inds dims_sup ++ zip rpt_inds repeats'
      lmad' = LMAD off' dims'
  -- Keep the original base shape: reshaping does not change the support array.
  return $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg
  where -- Is the list exactly @[i, i+1, i+2, ...]@?
        consecutive _ [] = True
        consecutive i [p] = i == p
        consecutive i ps = and $ zipWith (==) ps [i, i + 1 ..]

-- | Split a shape change into three consecutive parts: the leading run of
-- coercions, the maximal run of non-coercions that follows it, and the rest.
--
-- Succeeds only when the rest consists solely of coercions, i.e. when all
-- non-coercion dimension changes form one contiguous run; otherwise returns
-- 'Nothing'.
splitCoercions :: (Eq num, IntegralExp num) =>
                  ShapeChange num -> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions newshape' =
  if all isCoercion trailing
    then Just (leading, middle, trailing)
    else Nothing
  where
    (leading, rest) = span isCoercion newshape'
    (middle, trailing) = break isCoercion rest
    isCoercion DimCoercion{} = True
    isCoercion _ = False

-- | Reshape an index function.
--
-- First tries 'reshapeCoercion', then 'reshapeOneLMAD'.  When neither
-- succeeds, the single LMAD of @iota@ of the new shape is pushed in front of
-- the existing LMADs, keeping the base shape and the contiguity flag.
--
-- NOTE(review): 'reshapeCoercion' returns the new shape as the base shape,
-- whereas the fallback below keeps the original base shape; confirm which is
-- intended.
reshape :: (Eq num, IntegralExp num) =>
           IxFun num -> ShapeChange num -> IxFun num
reshape ixfun new_shape =
  case reshapeCoercion ixfun new_shape of
    Just ixfun' -> ixfun'
    Nothing ->
      case reshapeOneLMAD ixfun new_shape of
        Just ixfun' -> ixfun'
        Nothing -> pushIota ixfun
  where
    -- @iota@ of a shape is expected to consist of exactly one LMAD.
    pushIota (IxFun (lmad0 :| lmad0s) oshp cg) =
      case iota (newDims new_shape) of
        IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
        _ -> error "reshape: reached impossible case"

-- | The number of dimensions in the domain of the input function, i.e. the
-- number of dimensions of its first LMAD.
rank :: IntegralExp num =>
        IxFun num -> Int
rank (IxFun lmads _ _) = length $ lmadDims $ NE.head lmads

-- | Handle the case where a rebase operation can stay within m + n - 1 LMADs,
-- where m is the number of LMADs in the index function, and n is the number of
-- LMADs in the new base.  If both index functions have only one LMAD, this means
-- that we stay within the single-LMAD domain.
--
-- We can often stay in that domain if the original ixfun is essentially a
-- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`.
--
-- XXX: TODO: handle repetitions in both lmads.
--
-- How to handle repeated dimensions in the original?
--
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
rebaseNice :: (Eq num, IntegralExp num) =>
              IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice :: IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice
  new_base :: IxFun num
new_base@(IxFun (LMAD num
lmad_base :| [LMAD num]
lmads_base) Shape num
_ Bool
cg_base)
  ixfun :: IxFun num
ixfun@(IxFun NonEmpty (LMAD num)
lmads Shape num
shp Bool
cg) = do
  let (LMAD num
lmad_full :| [LMAD num]
lmads') = NonEmpty (LMAD num) -> NonEmpty (LMAD num)
forall a. NonEmpty a -> NonEmpty a
NE.reverse NonEmpty (LMAD num)
lmads
      (([Shape num]
outer_shapes, Shape num
inner_shape), LMAD num
lmad) = LMAD num -> (([Shape num], Shape num), LMAD num)
forall num.
(Eq num, IntegralExp num) =>
LMAD num -> (([Shape num], Shape num), LMAD num)
shaveoffRepeats LMAD num
lmad_full
      dims :: [LMADDim num]
dims = LMAD num -> [LMADDim num]
forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad
      perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
      perm_base :: Permutation
perm_base = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad_base

  Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$
    -- Core rebase condition.
    IxFun num -> Shape num
forall a. IxFun a -> [a]
base IxFun num
ixfun Shape num -> Shape num -> Bool
forall a. Eq a => a -> a -> Bool
== IxFun num -> Shape num
forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
shape IxFun num
new_base
    -- Conservative safety conditions: ixfun is contiguous and has known
    -- monotonicity for all dimensions.
    Bool -> Bool -> Bool
&& Bool
cg Bool -> Bool -> Bool
&& (LMADDim num -> Bool) -> [LMADDim num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
/= Monotonicity
Unknown) (Monotonicity -> Bool)
-> (LMADDim num -> Monotonicity) -> LMADDim num -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. LMADDim num -> Monotonicity
forall num. LMADDim num -> Monotonicity
ldMon) [LMADDim num]
dims
    -- XXX: We should be able to handle some basic cases where both index
    -- functions have non-trivial permutations.
    Bool -> Bool -> Bool
&& (IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun Bool -> Bool -> Bool
|| IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
new_base)
    -- We need the permutations to be of the same size if we want to compose
    -- them.  They don't have to be of the same size if the ixfun has a trivial
    -- permutation.  Supporting this latter case allows us to rebase when ixfun
    -- has been created by slicing with fixed dimensions.
    Bool -> Bool -> Bool
&& (Permutation -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Permutation -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm_base Bool -> Bool -> Bool
|| IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun)
    -- To not have to worry about ixfun having non-1 strides, we also check that
    -- it is a row-major array (modulo permutation, which is handled
    -- separately).  Accept a non-full innermost dimension.  XXX: Maybe this can
    -- be less conservative?
    Bool -> Bool -> Bool
&& [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ((num -> LMADDim num -> Bool -> Bool)
-> Shape num -> [LMADDim num] -> [Bool] -> [Bool]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 (\num
sn LMADDim num
ld Bool
inner -> num
sn num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== LMADDim num -> num
forall num. LMADDim num -> num
ldShape LMADDim num
ld Bool -> Bool -> Bool
|| (Bool
inner Bool -> Bool -> Bool
&& LMADDim num -> num
forall num. LMADDim num -> num
ldStride LMADDim num
ld num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
1))
            Shape num
shp [LMADDim num]
dims (Int -> Bool -> [Bool]
forall a. Int -> a -> [a]
replicate ([LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Bool
False [Bool] -> [Bool] -> [Bool]
forall a. [a] -> [a] -> [a]
++ [Bool
True]))

  -- Compose permutations, reverse strides and adjust offset if necessary.
  let perm_base' :: Permutation
perm_base' = if IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun
                   then Permutation
perm_base
                   else (Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map (Permutation
perm Permutation -> Int -> Int
forall a. [a] -> Int -> a
!!) Permutation
perm_base
      lmad_base' :: LMAD num
lmad_base' = Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm_base' LMAD num
lmad_base
      dims_base :: [LMADDim num]
dims_base = LMAD num -> [LMADDim num]
forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad_base'
      n_fewer_dims :: Int
n_fewer_dims = [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims_base Int -> Int -> Int
forall a. Num a => a -> a -> a
- [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
      ([LMADDim num]
dims_base', Shape num
offs_contrib) = [(LMADDim num, num)] -> ([LMADDim num], Shape num)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(LMADDim num, num)] -> ([LMADDim num], Shape num))
-> [(LMADDim num, num)] -> ([LMADDim num], Shape num)
forall a b. (a -> b) -> a -> b
$
        (LMADDim num -> LMADDim num -> (LMADDim num, num))
-> [LMADDim num] -> [LMADDim num] -> [(LMADDim num, num)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (\(LMADDim num
s1 num
r1 num
n1 Int
p1 Monotonicity
_) (LMADDim num
_ num
r2 num
_ Int
_ Monotonicity
m2) ->
                   let (num
s', num
off') | Monotonicity
m2 Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc = (num
s1, num
0)
                                  | Bool
otherwise = (num
s1 num -> num -> num
forall a. Num a => a -> a -> a
* (-num
1), num
s1 num -> num -> num
forall a. Num a => a -> a -> a
* (num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
1))
                       r' :: num
r' | Monotonicity
m2 Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc = if num
r2 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 then num
r1 else num
r1 num -> num -> num
forall a. Num a => a -> a -> a
+ num
r2
                          | num
r1 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
r2
                          | num
r2 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
r1
                          | Bool
otherwise = num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
r1 num -> num -> num
forall a. Num a => a -> a -> a
+ num
r2
                   in (num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
s' num
r' num
n1 (Int
p1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
n_fewer_dims) Monotonicity
Inc, num
off'))
        -- If @dims@ is morally a slice, it might have fewer dimensions than
        -- @dims_base@.  Drop extraneous outer dimensions.
        (Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
drop Int
n_fewer_dims [LMADDim num]
dims_base) [LMADDim num]
dims
      off_base :: num
off_base = LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad_base' num -> num -> num
forall a. Num a => a -> a -> a
+ Shape num -> num
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum Shape num
offs_contrib
      lmad_base'' :: LMAD num
lmad_base''
        | LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off_base [LMADDim num]
dims_base'
        | Bool
otherwise =
            -- If the innermost dimension of the ixfun was not full (but still
            -- had a stride of 1), add its offset relative to the new base.
            Shape num -> LMAD num -> LMAD num
forall num. Shape num -> LMAD num -> LMAD num
setLMADShape (LMAD num -> Shape num
forall num. (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape LMAD num
lmad)
            (num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD (num
off_base num -> num -> num
forall a. Num a => a -> a -> a
+ LMADDim num -> num
forall num. LMADDim num -> num
ldStride ([LMADDim num] -> LMADDim num
forall a. [a] -> a
last [LMADDim num]
dims_base) num -> num -> num
forall a. Num a => a -> a -> a
* LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad)
             [LMADDim num]
dims_base')
      new_base' :: IxFun num
new_base' = NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (LMAD num
lmad_base'' LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads_base) Shape num
shp Bool
cg_base
      IxFun NonEmpty (LMAD num)
lmads_base' Shape num
_ Bool
_ = if (Shape num -> Bool) -> [Shape num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Shape num -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Shape num]
outer_shapes Bool -> Bool -> Bool
&& Shape num -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Shape num
inner_shape
                              then IxFun num
new_base'
                              else IxFun num -> [Shape num] -> Shape num -> IxFun num
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> [Shape num] -> Shape num -> IxFun num
repeat IxFun num
new_base' [Shape num]
outer_shapes Shape num
inner_shape
      lmads'' :: NonEmpty (LMAD num)
lmads'' = [LMAD num]
lmads' [LMAD num] -> NonEmpty (LMAD num) -> NonEmpty (LMAD num)
forall a. [a] -> NonEmpty a -> NonEmpty a
++@ NonEmpty (LMAD num)
lmads_base'
  IxFun num -> Maybe (IxFun num)
forall (m :: * -> *) a. Monad m => a -> m a
return (IxFun num -> Maybe (IxFun num)) -> IxFun num -> Maybe (IxFun num)
forall a b. (a -> b) -> a -> b
$ NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun NonEmpty (LMAD num)
lmads'' Shape num
shp (Bool
cg Bool -> Bool -> Bool
&& Bool
cg_base)
  where shaveoffRepeats :: (Eq num, IntegralExp num) =>
                           LMAD num -> (([Shape num], Shape num), LMAD num)
        shaveoffRepeats :: LMAD num -> (([Shape num], Shape num), LMAD num)
shaveoffRepeats LMAD num
lmad =
        -- Given an input lmad, this function computes a repetition @r@ and a new lmad
        -- @res@, such that @repeat r res@ is identical to the input lmad.
          let perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
              dims :: [LMADDim num]
dims = LMAD num -> [LMADDim num]
forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad
              -- compute the Repeat:
              resacc :: [Shape num]
resacc= ([Shape num] -> LMADDim num -> [Shape num])
-> [Shape num] -> [LMADDim num] -> [Shape num]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\[Shape num]
acc (LMADDim num
s num
_ num
n Int
_ Monotonicity
_) ->
                              case [Shape num]
acc of
                                Shape num
rpt : [Shape num]
acc0 ->
                                    if num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 then (num
n num -> Shape num -> Shape num
forall a. a -> [a] -> [a]
: Shape num
rpt) Shape num -> [Shape num] -> [Shape num]
forall a. a -> [a] -> [a]
: [Shape num]
acc0
                                    else [] Shape num -> [Shape num] -> [Shape num]
forall a. a -> [a] -> [a]
: (Shape num
rpt Shape num -> [Shape num] -> [Shape num]
forall a. a -> [a] -> [a]
: [Shape num]
acc0)
                                [Shape num]
_ -> String -> [Shape num]
forall a. HasCallStack => String -> a
error String
"shaveoffRepeats: empty accumulator"
                            ) [[]] ([LMADDim num] -> [Shape num]) -> [LMADDim num] -> [Shape num]
forall a b. (a -> b) -> a -> b
$ [LMADDim num] -> [LMADDim num]
forall a. [a] -> [a]
reverse ([LMADDim num] -> [LMADDim num]) -> [LMADDim num] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$ Permutation -> [LMADDim num] -> [LMADDim num]
forall a. Permutation -> [a] -> [a]
permuteFwd Permutation
perm [LMADDim num]
dims
              last_shape :: Shape num
last_shape = [Shape num] -> Shape num
forall a. [a] -> a
last [Shape num]
resacc
              shapes :: [Shape num]
shapes = Int -> [Shape num] -> [Shape num]
forall a. Int -> [a] -> [a]
take ([Shape num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Shape num]
resacc Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) [Shape num]
resacc
              -- update permutation and lmad:
              howManyRepLT :: Int -> a
howManyRepLT Int
k =
                (a -> LMADDim num -> a) -> a -> [LMADDim num] -> a
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\a
i (LMADDim num
s num
_ num
_ Int
p Monotonicity
_) ->
                         if num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 Bool -> Bool -> Bool
&& Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
k then a
i a -> a -> a
forall a. Num a => a -> a -> a
+ a
1 else a
i
                      ) a
0 [LMADDim num]
dims
              dims' :: [LMADDim num]
dims' = ([LMADDim num] -> LMADDim num -> [LMADDim num])
-> [LMADDim num] -> [LMADDim num] -> [LMADDim num]
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (\[LMADDim num]
acc (LMADDim num
s num
r num
n Int
p Monotonicity
info) ->
                               if num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 then [LMADDim num]
acc
                               else let p' :: Int
p' = Int
p Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int -> Int
forall a. Num a => Int -> a
howManyRepLT Int
p
                                    in num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
s num
r num
n Int
p' Monotonicity
info LMADDim num -> [LMADDim num] -> [LMADDim num]
forall a. a -> [a] -> [a]
: [LMADDim num]
acc
                             ) [] ([LMADDim num] -> [LMADDim num]) -> [LMADDim num] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$ [LMADDim num] -> [LMADDim num]
forall a. [a] -> [a]
reverse [LMADDim num]
dims
              lmad' :: LMAD num
lmad' = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD (LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad) [LMADDim num]
dims'
          in (([Shape num]
shapes, Shape num
last_shape), LMAD num
lmad')

-- | Rebase an index function on top of a new base.
--
-- The simplifying 'rebaseNice' is tried first.  When it does not apply, the
-- result is the composition of the two index functions: the LMADs of @ixfun@
-- followed by those of @new_base@.  If the base shape of @ixfun@ differs from
-- the shape of @new_base@, then @new_base@ is first reshaped (by dimension
-- coercion) to that base shape.  The result is contiguous only if both inputs
-- are.
rebase :: (Eq num, IntegralExp num) =>
          IxFun num -> IxFun num -> IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just nice -> nice
    -- General case: index-function composition is always safe, so we simply
    -- concatenate the LMADs.
    Nothing -> IxFun (lmads @++@ outer_lmads) outer_shp (cg && cg_base)
  where
    (outer_lmads, outer_shp)
      | base ixfun == shape new_base = (lmads_base, shp_base)
      | otherwise =
          let IxFun coerced_lmads coerced_shp _ =
                reshape new_base (map DimCoercion shp)
          in (coerced_lmads, coerced_shp)

-- | Monotonicity of an index function, with rotations taken into account:
-- a dimension with non-zero stride and non-zero rotation is not considered
-- monotonic (see 'ixfunMonotonicityRots' with @ignore_rots = False@).
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity ixfun = ixfunMonotonicityRots False ixfun

-- | Offset index.  Results in the index function corresponding to indexing
-- with @i@ on the outermost dimension.  The outermost dimension (of size @d@)
-- is sliced to start at @i@ with @d - i@ elements and stride one, and every
-- inner dimension is kept whole.  An offset of zero returns the index
-- function unchanged.  Calls 'error' if the index function has rank zero
-- and @i@ is non-zero.
offsetIndex :: (Eq num, IntegralExp num) =>
               IxFun num -> num -> IxFun num
offsetIndex ixfun i
  | i == 0 = ixfun
  | otherwise =
      case shape ixfun of
        [] -> error "offsetIndex: underlying index function has rank zero"
        outer : inner ->
          slice ixfun $ DimSlice i (outer - i) 1 : [unitSlice 0 n | n <- inner]

-- | If the memory support of the index function is contiguous and row-major
-- (i.e., no transpositions, repetitions, rotates, etc.), then this should
-- return the offset from which the memory-support of this index function
-- starts, multiplied by @elem_size@.
--
-- Only index functions made of a single LMAD qualify.  For those,
-- 'hasContiguousPerm' must hold, the contiguity flag must be set, and the
-- monotonicity must be 'Inc'.  Everything else yields 'Nothing'.
linearWithOffset :: (Eq num, IntegralExp num) =>
                    IxFun num -> num -> Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ cg
      | hasContiguousPerm ixfun, cg, ixfunMonotonicity ixfun == Inc ->
          Just (lmadOffset lmad * elem_size)
    _ -> Nothing

-- | Similar restrictions to @linearWithOffset@ except for transpositions,
-- which are returned together with the offset.
--
-- The transposition is a list that pairs each entry @p@ of the LMAD
-- permutation with element @p@ of 'lmadShapeBase'.  Only single-LMAD index
-- functions are handled; anything else yields 'Nothing'.
rearrangeWithOffset :: (Eq num, IntegralExp num) =>
                       IxFun num -> num -> Maybe (num, [(Int,num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  -- Note that @cg@ describes whether the index function is contiguous,
  -- *ignoring permutations*.  So linearity is checked on the same LMAD with
  -- its permutation reset to the identity.
  withTranspose <$> linearWithOffset unpermuted elem_size
  where
    perm = lmadPermutation lmad
    identity_perm = [0 .. length perm - 1]
    unpermuted = IxFun (setLMADPermutation identity_perm lmad :| []) oshp cg
    withTranspose offset =
      (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing

-- | Is this a row-major array starting at offset zero?  That is, does
-- 'linearWithOffset' with an element size of one report offset zero?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0

-- | Apply a permutation.  Element @k@ of the result is the element of
-- @elems@ at position @ps !! k@.  Calls 'error' (through '!!') if @ps@
-- contains an out-of-range index.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]

-- | Undo a permutation.  Element @k@ of @elems@ is moved to position
-- @ps !! k@.  This works by tagging every element with its target position
-- and stably sorting on that tag.  For a valid permutation @ps@ this is the
-- inverse of 'permuteFwd'.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems = map snd (sortBy byTarget (zip ps elems))
  where byTarget (p, _) (q, _) = compare p q

-- | Flat offset contributed by one LMAD dimension with stride @s@,
-- rotation @r@ and size @n@ when it is indexed at @i@:
--
--   * zero when the stride is zero (a repeated dimension);
--   * @i * s@ when there is no rotation;
--   * otherwise the rotated position @((i + r) `mod` n) * s@.
flatOneDim :: (Eq num, IntegralExp num) =>
              (num, num, num) -> num -> num
flatOneDim (s, r, n) i =
  if s == 0 then 0
  else if r == 0 then i * s
  else ((i + r) `mod` n) * s

-- | Generalised (rotated) iota.  Builds a single LMAD with offset @off@ and
-- monotonicity @mon@ over @support@, which is a list of @(rotation, size)@
-- pairs, one per dimension, outermost first.
--
-- Only the offset and the rotations are user-specified.  The strides are
-- derived from the sizes as row-major strides, so the innermost stride is
-- one; they are negated when @mon@ is 'Dec'.  Dimension @k@ gets permutation
-- index @k@ (no transposition), and every dimension is tagged with @mon@.
-- Calls 'error' unless @mon@ is 'Inc' or 'Dec'.
makeRotIota :: IntegralExp num =>
               Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota mon off support
  | mon == Inc || mon == Dec =
    let rk = length support
        (rs, ns) = unzip support
        -- Suffix products of the sizes: scan innermost-first, then reverse
        -- back to outermost-first order.
        ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns
        ss = if mon == Inc
             then ss0
             else map (* (-1)) ss0
        -- Identity permutation; the elements are already 'Int', so no
        -- conversion is needed.
        ps = [0 .. rk - 1]
        fi = replicate rk mon
    in LMAD off $ zipWith5 LMADDim ss rs ns ps fi
  | otherwise = error "makeRotIota: requires Inc or Dec"

-- | Check monotonicity of an index function.
--
-- All LMADs must agree on one monotonicity; otherwise the result is
-- 'Unknown'.  An LMAD is 'Inc' (resp. 'Dec') when each of its dimensions
-- satisfies one of two conditions:
--
--   * it has stride zero; or
--   * it is tagged 'Inc' (resp. 'Dec') and, unless @ignore_rots@ is set,
--     its rotation is zero.
ixfunMonotonicityRots :: (Eq num, IntegralExp num) =>
                         Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots ignore_rots (IxFun (lmad :| lmads) _ _)
  | all ((== outer_mon) . lmadMon) lmads = outer_mon
  | otherwise = Unknown
  where
    outer_mon = lmadMon lmad

    lmadMon :: (Eq num, IntegralExp num) => LMAD num -> Monotonicity
    lmadMon (LMAD _ dims)
      | all (monotoneAs Inc) dims = Inc
      | all (monotoneAs Dec) dims = Dec
      | otherwise = Unknown

    monotoneAs :: (Eq num, IntegralExp num) =>
                  Monotonicity -> LMADDim num -> Bool
    monotoneAs mon (LMADDim s r _ _ ldmon)
      | s == 0 = True
      | otherwise = (ignore_rots || r == 0) && mon == ldmon

-- | Generalization (anti-unification)
--
-- Anti-unification of two index functions is supported under the following
-- conditions:
--   0. Both index functions are represented by ONE lmad (assumed common case!)
--   1. The support arrays of the two index functions have the same
--      dimensionality (we can relax this condition if we use a 1D support,
--      as we probably should!)
--   2. The contiguous property and the per-dimension monotonicity are the
--      same (otherwise we might lose important information; this can be
--      relaxed!)
--   3. Most importantly, both index functions correspond to the same
--      permutation (since the permutation is represented by INTs, this
--      restriction cannot be relaxed, unless we move to a gated-LMAD
--      representation!)
--
-- On success, the result holds two things:
--
--   * the generalized index function;
--   * the list of @(PrimExp v, PrimExp v)@ pairs accumulated by
--     'PEG.leastGeneralGeneralization'.  These are presumably the
--     sub-expression pairs that were replaced by existentials.
--
-- The components are generalized in a fixed order, threading that list
-- through each step: LMAD shape, outer shape, strides, rotations, offset.
leastGeneralGeneralization :: Eq v => IxFun (PrimExp v) -> IxFun (PrimExp v) ->
                              Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
leastGeneralGeneralization (IxFun (lmad1 :| []) oshp1 ctg1) (IxFun (lmad2 :| []) oshp2 ctg2) = do
  guard compatible
  (dshp, m1) <- generalizeAll [] (dimsOf ldShape lmad1) (dimsOf ldShape lmad2)
  (oshp, m2) <- generalizeAll m1 oshp1 oshp2
  (dstd, m3) <- generalizeAll m2 (dimsOf ldStride lmad1) (dimsOf ldStride lmad2)
  (drot, m4) <- generalizeAll m3 (dimsOf ldRotate lmad1) (dimsOf ldRotate lmad2)
  (offt, m5) <- PEG.leastGeneralGeneralization m4 (lmadOffset lmad1) (lmadOffset lmad2)
  -- Permutation and monotonicity are equal in both inputs (checked by the
  -- guard), so they are taken from the first one.
  let dims = zipWith5 LMADDim dstd drot dshp (lmadPermutation lmad1) (dimsOf ldMon lmad1)
  pure (IxFun (LMAD offt dims :| []) oshp ctg1, m5)
  where
    compatible =
      length oshp1 == length oshp2
        && ctg1 == ctg2
        && dimsOf ldPerm lmad1 == dimsOf ldPerm lmad2
        && dimsOf ldMon lmad1 == dimsOf ldMon lmad2

    -- Project one field out of every dimension of an LMAD.
    dimsOf field = map field . lmadDims

    -- Pairwise generalization of two lists (truncated to the shorter one),
    -- threading the accumulated pair list left to right.
    generalizeAll m (pe1 : l1) (pe2 : l2) = do
      (e, m') <- PEG.leastGeneralGeneralization m pe1 pe2
      (es, m'') <- generalizeAll m' l1 l2
      pure (e : es, m'')
    generalizeAll m _ _ = Just ([], m)
leastGeneralGeneralization _ _ = Nothing

-- | When comparing index functions as part of the type check in KernelsMem,
-- we may run into problems caused by the simplifier.  Index functions can be
-- generalized over if-then-else expressions, so the simplifier might hoist
-- some of the code out of the if-then-else (computing the offset of an array,
-- for instance).  The type checker then cannot verify that the generalized
-- index function is valid, because some of the existentials are computed
-- somewhere else.  To work around this, we have had to relax the KernelsMem
-- type checker a bit.  Specifically, we have introduced this function to
-- verify whether two index functions are "close enough" that we can assume
-- they match.  We use this instead of `ixfun1 == ixfun2` and hope that it is
-- good enough.
--
-- Concretely, two index functions are close enough when all of these hold:
--
--   * their base shapes have the same rank;
--   * their contiguity flags agree;
--   * they consist of the same number of LMADs;
--   * each corresponding pair of LMADs has the same number of dimensions and
--     the same per-dimension permutation.
--
-- Offsets, strides, sizes, rotations and monotonicity are not compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  and
    [ length (base ixf1) == length (base ixf2)
    , ixfunContig ixf1 == ixfunContig ixf2
    , NE.length lmads1 == NE.length lmads2
    , all (uncurry sameStructure) (NE.zip lmads1 lmads2)
    ]
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    -- Same rank and same permutation; numeric fields are ignored.
    sameStructure lmad1 lmad2 =
      length dims1 == length dims2 && map ldPerm dims1 == map ldPerm dims2
      where
        dims1 = lmadDims lmad1
        dims2 = lmadDims lmad2