{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}

-- | This module contains a representation for the index function based on
-- linear-memory accessor descriptors (LMADs); see the work of Zhu,
-- Hoeflinger and Padua.
module Futhark.IR.Mem.IxFun
  ( IxFun (..),
    LMAD (..),
    LMADDim (..),
    Monotonicity (..),
    index,
    iota,
    iotaOffset,
    permute,
    rotate,
    reshape,
    slice,
    rebase,
    shape,
    rank,
    linearWithOffset,
    rearrangeWithOffset,
    isDirect,
    isLinear,
    substituteInIxFun,
    leastGeneralGeneralization,
    existentialize,
    closeEnough,
  )
where

import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on)
import Data.List (sort, sortBy, zip4, zip5, zipWith5)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import Futhark.Analysis.PrimExp
  ( IntExp,
    PrimExp (..),
    TPrimExp (..),
    primExpType,
  )
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import qualified Futhark.Analysis.PrimExp.Generalize as PEG
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimChange (..),
    DimIndex (..),
    ShapeChange,
    Slice,
    dimFix,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (id, mod, (.))

-- | The extents of an array, one entry per dimension.
type Shape d = [d]

-- | A multi-dimensional position, one index per dimension.
type Indices d = [d]

-- | A permutation of dimensions, written as a list of dimension positions.
type Permutation = [Int]

-- | Monotonicity of an LMAD dimension: monotonously increasing,
-- monotonously decreasing, or unknown.
data Monotonicity
  = -- | Monotonously increasing.
    Inc
  | -- | Monotonously decreasing.
    Dec
  | -- | Monotonicity is not known.
    Unknown
  deriving (Show, Eq)

-- | The descriptor of a single dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Stride of this dimension.
    ldStride :: num,
    -- | Rotation factor of this dimension.
    ldRotate :: num,
    -- | Number of elements (extent) of this dimension.
    ldShape :: num,
    -- | This dimension's entry in the LMAD's permutation.
    ldPerm :: Int,
    -- | Monotonicity of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Show, Eq)

-- | LMAD's representation consists of a general offset and for each dimension a
-- stride, rotate factor, number of elements (or shape), permutation, and
-- monotonicity. Note that the permutation is not strictly necessary in that the
-- permutation can be performed directly on LMAD dimensions, but then it is
-- difficult to extract the permutation back from an LMAD.
--
-- LMAD algebra is closed under composition w.r.t. operators such as
-- permute, index and slice.  However, other operations, such as
-- reshape, cannot always be represented inside the LMAD algebra.
--
-- It follows that the general representation of an index function is a list of
-- LMADs, in which each following LMAD in the list implicitly corresponds to an
-- irregular reshaping operation.
--
-- However, we expect that the common case is when the index function is one
-- LMAD -- we call this the "nice" representation.
--
-- Finally, the list of LMADs is kept in an @IxFun@ together with the shape of
-- the original array, and a bit to indicate whether the index function is
-- contiguous, i.e., if we instantiate all the points of the current index
-- function, do we get a contiguous memory interval?
--
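-- By definition, the LMAD denotes the set of points (simplified), where @o@
-- is the offset, @k@ the number of dimensions, and @s_j@, @r_j@ and @n_j@
-- the stride, rotation and shape of dimension @j@:
--
--   \{ o + \Sigma_{j=0}^{k-1} ((i_j+r_j) `mod` n_j)*s_j,
--      \forall i_j such that 0<=i_j<n_j, j=0..k-1 \}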
data LMAD num = LMAD
  { -- | The general offset of the LMAD.
    lmadOffset :: num,
    -- | One descriptor per dimension.
    lmadDims :: [LMADDim num]
  }
  deriving (Show, Eq)

-- | An index function is a mapping from a multidimensional array
-- index space (the domain) to a one-dimensional memory index space.
-- Essentially, it explains where the element at position @[i,j,p]@ of
-- some array is stored inside the flat one-dimensional array that
-- constitutes its memory.  For example, we can use this to
-- distinguish row-major and column-major representations.
--
-- An index function is represented as a sequence of 'LMAD's.
data IxFun num = IxFun
  { -- | The LMADs making up the index function.  Each LMAD after the first
    -- implicitly corresponds to an irregular reshape that could not be
    -- expressed within the LMAD algebra.
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The shape of the original array.
    base :: Shape num,
    -- | ignoring permutations, is the index function contiguous?
    ixfunContig :: Bool
  }
  deriving (Show, Eq)

-- Monotonicity is rendered by its constructor name, via 'show'.
instance Pretty Monotonicity where
  ppr mon = text (show mon)

-- Renders an LMAD as a braced, semicolon-separated record.  Each
-- per-dimension attribute is shown as one bracketed list across all
-- dimensions.
instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces . semisep $
      [ "offset: " <> oneLine (ppr offset),
        "strides: " <> perDim (ppr . ldStride),
        "rotates: " <> perDim (ppr . ldRotate),
        "shape: " <> perDim (ppr . ldShape),
        "permutation: " <> perDim (ppr . ldPerm),
        "monotonicity: " <> perDim (ppr . ldMon)
      ]
    where
      -- Render one already-pretty-printed attribute of every dimension as a
      -- single-line, bracketed, comma-separated list.
      perDim render = oneLine $ brackets $ commasep $ map render dims

-- Renders an index function as a braced, semicolon-separated record of its
-- base shape, contiguity flag and LMADs.
instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces . semisep $
      [ "base: " <> listOf (map ppr oshp),
        "contiguous: " <> if cg then "true" else "false",
        "LMADs: " <> listOf (map ppr (NE.toList lmads))
      ]
    where
      -- A bracketed, comma-separated list of documents.
      listOf docs = brackets (commasep docs)

-- Name substitution is applied to every @num@ stored in the LMAD.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad

-- Name substitution is applied to every @num@ stored in the index function
-- (all LMADs and the base shape).
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun

-- Renaming is delegated to 'substituteRename', based on the Substitute instance.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad

-- Renaming is delegated to 'substituteRename', based on the Substitute instance.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun

-- The free variables of an LMAD are those of all the @num@s it contains.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad

-- The free variables of an index function are those of all the @num@s it
-- contains (all LMADs and the base shape).
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun

-- Maps over the offset and over each dimension's stride, rotation and
-- shape.  Permutation and monotonicity are copied unchanged.
instance Functor LMAD where
  fmap f (LMAD offset dims) = LMAD (f offset) (map onDim dims)
    where
      onDim (LMADDim s r n p m) = LMADDim (f s) (f r) (f n) p m

-- Maps over every LMAD and over the base shape.  The contiguity flag is
-- copied unchanged.
instance Functor IxFun where
  fmap f (IxFun lmads oshp cg) = IxFun (fmap (fmap f) lmads) (map f oshp) cg

-- Folds the offset first, then each dimension's stride, rotation and shape,
-- in dimension order.  This is the same order in which 'traverse' visits
-- them.
instance Foldable LMAD where
  foldMap f (LMAD offset dims) = f offset <> foldMap onDim dims
    where
      onDim dim = f (ldStride dim) <> f (ldRotate dim) <> f (ldShape dim)

-- Folds all LMADs (in list order) first, then the base shape.  This is the
-- same order in which 'traverse' visits them.
instance Foldable IxFun where
  foldMap f (IxFun lmads oshp _) = foldMap (foldMap f) lmads <> foldMap f oshp

-- Visits the offset, then each dimension's stride, rotation and shape.
-- Permutation and monotonicity are carried through without effects.
instance Traversable LMAD where
  traverse f (LMAD offset dims) = LMAD <$> f offset <*> traverse onDim dims
    where
      onDim (LMADDim s r n p m) =
        (\s' r' n' -> LMADDim s' r' n' p m) <$> f s <*> f r <*> f n

-- Visits every LMAD (in list order), then the base shape.  The contiguity
-- flag is carried through without effects.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    rebuild <$> traverse (traverse f) lmads <*> traverse f oshp
    where
      rebuild lmads' oshp' = IxFun lmads' oshp' cg

-- | Prepend a (possibly empty) list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
prefix ++@ rest = foldr (NE.<|) rest prefix

-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ ys = x :| (xs ++ NE.toList ys)

-- | Swap increasing and decreasing monotonicity.  'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown

-- | The permutation of an LMAD, read from the 'ldPerm' of each dimension in
-- 'lmadDims' order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]

-- | Overwrite the 'ldPerm' of each dimension with the corresponding entry of
-- the given permutation.  Dimensions and entries are paired positionally.
-- If the lengths differ, the surplus of the longer list is dropped.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = [dim {ldPerm = p} | (dim, p) <- zip (lmadDims lmad) perm]}

-- | Overwrite the 'ldShape' of each dimension with the corresponding extent
-- of the given shape.  Dimensions and extents are paired positionally.  If
-- the lengths differ, the surplus of the longer list is dropped.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = [dim {ldShape = s} | (dim, s) <- zip (lmadDims lmad) shp]}

-- | Substitute names with PrimExps, as given by the map, in an LMAD.
--
-- The substitution is applied to the offset and to the stride, rotation and
-- shape of every dimension.  Permutation and monotonicity are left
-- untouched, exactly as the LMAD 'Functor' instance does.
substituteInLMAD ::
  Ord a =>
  M.Map a (PrimExp a) ->
  LMAD (PrimExp a) ->
  LMAD (PrimExp a)
substituteInLMAD tab lmad = substituteInPrimExp tab <$> lmad

-- | Substitute names with PrimExps, as given by the map, in an index
-- function.
--
-- The substitution is applied to every LMAD and to the base shape.  The
-- contiguity flag is preserved.  Typed expressions are untyped for the
-- substitution and re-wrapped afterwards.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab ixfun =
  ixfun
    { ixfunLMADs = NE.map substLMAD (ixfunLMADs ixfun),
      base = map substExp (base ixfun)
    }
  where
    untypedTab = fmap untyped tab
    substExp e = TPrimExp (substituteInPrimExp untypedTab (untyped e))
    substLMAD lmad = fmap TPrimExp (substituteInLMAD untypedTab (fmap untyped lmad))

-- | Is this a row-major array?
--
-- All of the following must hold:
--
-- * the index function is a single LMAD flagged as contiguous;
-- * the permutation is ascending;
-- * the offset is zero;
-- * for every dimension, there is no rotation, the extent equals the base
--   shape's extent, the permutation entry equals the dimension's position,
--   and the stride is the row-major stride implied by the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  hasContiguousPerm ixfun
    && length oshp == length dims
    && offset == 0
    && all isRowMajorDim (zip4 dims [0 ..] oshp strides_expected)
  where
    -- The row-major stride of a dimension is the product of the extents of
    -- all later dimensions.  'drop 1' (instead of the partial 'tail') keeps
    -- this total even for an empty base shape.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    -- Checks one dimension against its position, base extent and expected
    -- stride.
    isRowMajorDim (LMADDim s r n p _, pos, extent, stride) =
      s == stride && r == 0 && n == extent && p == pos
isDirect _ = False

-- | Does the index function consist of a single LMAD whose permutation is in
-- ascending order?  Index functions with more than one LMAD always yield
-- 'False'.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] ->
      let perm = lmadPermutation lmad
       in perm == sort perm
    _ -> False

-- | Shape of an index function, which is the shape of its first LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun = lmadShape (NE.head (ixfunLMADs ixfun))

-- | Shape of an LMAD: the permutation-free shape ('lmadShapeBase') with
-- 'permuteInv' applied, using the LMAD's permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm unpermuted
  where
    perm = lmadPermutation lmad
    unpermuted = lmadShapeBase lmad

-- | Shape of an LMAD, ignoring permutations.  This is the 'ldShape' of each
-- dimension, in 'lmadDims' order.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]

-- | Compute the flat memory index for a complete set of array indices.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index :: IxFun num -> Indices num -> num
index = NonEmpty (LMAD num) -> Indices num -> num
forall num.
(IntegralExp num, Eq num) =>
NonEmpty (LMAD num) -> Indices num -> num
indexFromLMADs (NonEmpty (LMAD num) -> Indices num -> num)
-> (IxFun num -> NonEmpty (LMAD num))
-> IxFun num
-> Indices num
-> num
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. IxFun num -> NonEmpty (LMAD num)
forall num. IxFun num -> NonEmpty (LMAD num)
ixfunLMADs
  where
    indexFromLMADs ::
      (IntegralExp num, Eq num) =>
      NonEmpty (LMAD num) ->
      Indices num ->
      num
    indexFromLMADs :: NonEmpty (LMAD num) -> Indices num -> num
indexFromLMADs (LMAD num
lmad :| []) Indices num
inds = LMAD num -> Indices num -> num
forall num.
(IntegralExp num, Eq num) =>
LMAD num -> Indices num -> num
indexLMAD LMAD num
lmad Indices num
inds
    indexFromLMADs (LMAD num
lmad1 :| LMAD num
lmad2 : [LMAD num]
lmads) Indices num
inds =
      let i_flat :: num
i_flat = LMAD num -> Indices num -> num
forall num.
(IntegralExp num, Eq num) =>
LMAD num -> Indices num -> num
indexLMAD LMAD num
lmad1 Indices num
inds
          new_inds :: Indices num
new_inds = Indices num -> num -> Indices num
forall num. IntegralExp num => [num] -> num -> [num]
unflattenIndex (Permutation -> Indices num -> Indices num
forall a. Permutation -> [a] -> [a]
permuteFwd (LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad2) (Indices num -> Indices num) -> Indices num -> Indices num
forall a b. (a -> b) -> a -> b
$ LMAD num -> Indices num
forall num. (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase LMAD num
lmad2) num
i_flat
       in NonEmpty (LMAD num) -> Indices num -> num
forall num.
(IntegralExp num, Eq num) =>
NonEmpty (LMAD num) -> Indices num -> num
indexFromLMADs (LMAD num
lmad2 LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads) Indices num
new_inds
    indexLMAD ::
      (IntegralExp num, Eq num) =>
      LMAD num ->
      Indices num ->
      num
    indexLMAD :: LMAD num -> Indices num -> num
indexLMAD lmad :: LMAD num
lmad@(LMAD num
off [LMADDim num]
dims) Indices num
inds =
      let prod :: num
prod =
            Indices num -> num
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum (Indices num -> num) -> Indices num -> num
forall a b. (a -> b) -> a -> b
$
              ((num, num, num) -> num -> num)
-> [(num, num, num)] -> Indices num -> Indices num
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
                (num, num, num) -> num -> num
forall num.
(Eq num, IntegralExp num) =>
(num, num, num) -> num -> num
flatOneDim
                ((LMADDim num -> (num, num, num))
-> [LMADDim num] -> [(num, num, num)]
forall a b. (a -> b) -> [a] -> [b]
map (\(LMADDim num
s num
r num
n Int
_ Monotonicity
_) -> (num
s, num
r, num
n)) [LMADDim num]
dims)
                (Permutation -> Indices num -> Indices num
forall a. Permutation -> [a] -> [a]
permuteInv (LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad) Indices num
inds)
       in num
off num -> num -> num
forall a. Num a => a -> a -> a
+ num
prod

-- | Like 'iota', but the single LMAD of the result starts at offset @o@.
--
-- The LMAD is built by 'makeRotIota' with monotonicity 'Inc', pairing
-- every dimension size in @ns@ with 0 (its rotation).  The index function
-- records @ns@ as its base shape and is marked contiguous ('True').
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (lmad :| []) ns True
  where
    lmad = makeRotIota Inc o [(0, n) | n <- ns]

-- | A fresh index function for an array of shape @ns@, starting at
-- offset 0; equivalent to @'iotaOffset' 0@.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = iotaOffset 0 ns

-- | Permute the dimensions of an index function.
--
-- The requested permutation @perm_new@ is composed with the current
-- permutation of the first LMAD: entry @i@ of the resulting permutation is
-- the current permutation's entry at position @perm_new !! i@.  Only the
-- first LMAD changes; further LMADs, the base shape and the contiguity
-- flag are kept as they are.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    perm_cur = lmadPermutation lmad
    -- NOTE: uses (!!), so an entry of perm_new outside the bounds of
    -- perm_cur is a runtime error.
    composed = [perm_cur !! i | i <- perm_new]

-- | Rotate an index function by the per-dimension amounts @offs@.
--
-- Only the first LMAD is modified.  The offsets are reordered with
-- 'permuteInv' and the LMAD's permutation so that they line up with the
-- LMAD's stored dimensions.  Each dimension with a non-zero stride has its
-- rotation increased by the corresponding offset.  A stride-0 dimension is
-- instead reset to rotation 0 with 'Unknown' monotonicity.
rotate ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  IxFun (LMAD off rotated :| lmads) oshp cg
  where
    offs_stored_order = permuteInv (lmadPermutation lmad) offs
    rotated = zipWith rotateDim dims offs_stored_order
    rotateDim (LMADDim s r n p f) o
      | s == 0 = LMADDim 0 0 n p Unknown
      | otherwise = LMADDim s (r + o) n p f

-- | Handle the case where a slice can stay within a single LMAD.
--
-- The slice is first reordered with 'permuteInv' and the first LMAD's
-- permutation, so that it lines up with the LMAD's stored dimensions.  The
-- result is 'Nothing' unless every dimension passes @harmlessDim@: its
-- index is fixed, its stride or rotation is 0, or it is sliced fully
-- (forwards or reversed).  Otherwise every dimension of the first LMAD is
-- sliced in turn by @sliceOne@.  Fixed indices are folded into the offset
-- and their dimensions dropped, and also removed from the permutation.
-- Sliced dimensions are kept with adjusted stride, rotation, size and
-- monotonicity.  The contiguity flag stays 'True' only if it was 'True'
-- before and @slicePreservesContiguous@ accepts the slice.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) is
  | harmlessRotation lmad is' =
      Just $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  | otherwise = Nothing
  where
    perm = lmadPermutation lmad
    is' = permuteInv perm is
    cg' = cg && slicePreservesContiguous lmad is'
    lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) (zip is' ldims)
    -- Positions (in stored dimension order) fixed by the slice.  These
    -- dimensions disappear, so they must be removed from the permutation.
    fixed_positions = [k | (k, d) <- zip [0 ..] is', isJust (dimFix d)]
    perm' = updatePerm perm fixed_positions

    -- Drop every permutation entry that names a fixed position, and
    -- renumber the survivors by subtracting the number of fixed positions
    -- below them.
    updatePerm :: Permutation -> [Int] -> Permutation
    updatePerm ps fixed =
      [p - length (filter (< p) fixed) | p <- ps, p `notElem` fixed]

    -- A dimension is harmless to slice if its index is fixed, its stride or
    -- rotation is 0, or the slice covers the whole dimension, either in
    -- reverse or forwards.
    harmlessDim ::
      (Eq num, IntegralExp num) =>
      LMADDim num ->
      DimIndex num ->
      Bool
    harmlessDim _ (DimFix _) = True
    harmlessDim (LMADDim 0 _ _ _ _) _ = True
    harmlessDim (LMADDim _ 0 _ _ _) _ = True
    harmlessDim (LMADDim _ _ n _ _) dslc =
      dslc == DimSlice (n - 1) n (-1) || dslc == unitSlice 0 n

    harmlessRotation ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    harmlessRotation (LMAD _ dims) iss =
      all (uncurry harmlessDim) (zip dims iss)

    -- Slice one (stored-order) dimension and append the result to the LMAD
    -- accumulated so far.  Equations are tried top to bottom; a failing
    -- guard falls through to the next equation.
    --
    -- XXX: TODO: what happens to r on a negative-stride slice; is there
    -- such a case?
    sliceOne ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      (DimIndex num, LMADDim num) ->
      LMAD num
    -- Fixed index: contributes to the offset, and the dimension is dropped.
    sliceOne (LMAD off acc) (DimFix i, LMADDim s r n _ _) =
      LMAD (off + flatOneDim (s, r, n) i) acc
    -- Stride-0 dimension: only the new size matters.
    sliceOne (LMAD off acc) (DimSlice _ ne _, LMADDim 0 _ _ p _) =
      LMAD off (acc ++ [LMADDim 0 0 ne p Unknown])
    -- Full forward slice: the dimension is kept unchanged.
    sliceOne (LMAD off acc) (slc, dim@(LMADDim _ _ n _ _))
      | slc == unitSlice 0 n = LMAD off (acc ++ [dim])
    -- Full reversed slice: start at the last element, negate the stride,
    -- mirror the rotation and invert the monotonicity.
    sliceOne (LMAD off acc) (slc, LMADDim s r n p m)
      | slc == DimSlice (n - 1) n (-1) =
          LMAD
            (off + flatOneDim (s, 0, n) (n - 1))
            (acc ++ [LMADDim (s * (-1)) r_rev n p (invertMonotonicity m)])
      where
        r_rev = if r == 0 then 0 else n - r
    -- Stride-0 slice: every element is the one at index b.
    sliceOne (LMAD off acc) (DimSlice b ne 0, LMADDim s r n p _) =
      LMAD (off + flatOneDim (s, r, n) b) (acc ++ [LMADDim 0 0 ne p Unknown])
    -- General strided slice of an unrotated dimension.
    sliceOne (LMAD off acc) (DimSlice bs ns ss, LMADDim s 0 _ p m) =
      LMAD (off + s * bs) (acc ++ [LMADDim (ss * s) 0 ns p m'])
      where
        m' = case sgn ss of
          Just 1 -> m
          Just (-1) -> invertMonotonicity m
          _ -> Unknown
    sliceOne _ _ = error "slice: reached impossible case"

    -- Does slicing keep a contiguous LMAD contiguous?
    slicePreservesContiguous ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Slice num ->
      Bool
    slicePreservesContiguous (LMAD _ dims) slc =
      snd $ foldl step (False, True) (zip slc' dims')
      where
        -- Drop the LMAD dimensions with stride 0, together with their slice
        -- entries.  If the LMAD was contiguous in memory, these dimensions
        -- cannot affect the contiguity of the result.  The slice is
        -- normalised first, so size-1 and stride-0 slices count as fixed.
        (dims', slc') =
          unzip
            [ (d, i)
              | (d, i) <- zip dims (map normIndex slc),
                ldStride d /= 0
            ]
        unitStride ds = ds == 1 || ds == -1
        -- Walk the dimensions from the outside in, checking that:
        --   1. all fixed dimensions come before all sliced ones;
        --   2. the outermost sliced dimension has stride +/-1 and is either
        --      unrotated or full;
        --   3. every inner sliced dimension is full with stride +/-1.
        -- The first component records whether a sliced dimension has
        -- already been seen.
        step (found, res) (slcdim, LMADDim _ r n _ _) =
          case (slcdim, found) of
            (DimFix {}, True) -> (found, False)
            (DimFix {}, False) -> (found, res)
            (DimSlice _ ne ds, False) ->
              (True, res && ((r == 0 || n == ne) && unitStride ds))
            (DimSlice _ ne ds, True) ->
              (found, res && (n == ne && unitStride ds))

    -- Size-1 and stride-0 slices select a single element, so rewrite them as
    -- fixed indices.
    normIndex ::
      (Eq num, IntegralExp num) =>
      DimIndex num ->
      DimIndex num
    normIndex (DimSlice b 1 _) = DimFix b
    normIndex (DimSlice b _ 0) = DimFix b
    normIndex d = d

-- | Slice an index function.
--
-- If the slice is the identity (a full unit slice of every dimension of
-- 'shape'), the index function is returned unchanged.  Otherwise the
-- slice is applied in place to the first LMAD when 'sliceOneLMAD' accepts
-- it.  Failing that, the slice is applied to a fresh 'iota' of the first
-- LMAD's shape, and the resulting LMAD is pushed in front of the existing
-- ones.  The contiguity flag becomes the conjunction of the old flag and
-- the flag of the sliced iota.
--
-- Calls 'error' on an empty slice.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice _ [] = error "slice: empty slice"
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  | isIdentitySlice = ixfun
  | otherwise =
      case sliceOneLMAD ixfun dim_slices of
        Just ixfun' -> ixfun'
        Nothing -> pushSlicedIota
  where
    -- Avoid identity slicing.
    isIdentitySlice = dim_slices == map (unitSlice 0) (shape ixfun)
    pushSlicedIota =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"

-- | Handle the simple case of a reshape that only changes dimension sizes.
-- This covers a shape change made up entirely of coercions (as classified
-- by 'splitCoercions'), and also one whose only non-coercion part is a
-- single entry covering a single LMAD dimension.
--
-- On success, the first LMAD keeps its offset and the stride, rotation,
-- permutation and monotonicity of every dimension; only the dimension
-- sizes are replaced by 'newDims' of the shape change.  Further LMADs, the
-- base shape and the contiguity flag are unchanged.  Returns 'Nothing' if
-- 'splitCoercions' rejects the shape change or more reshaping is needed.
reshapeCoercion ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let perm = lmadPermutation lmad
      hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      -- Reorder the stored dimensions with 'permuteFwd', so that they are
      -- paired positionally with the entries of @newshape@ below.
      dims' = permuteFwd perm dims
      -- The dimensions not covered by the leading/trailing coercions.
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims'
      num_rshps = length reshapes
  -- Either nothing but coercions, or one reshape of exactly one dimension.
  guard (num_rshps == 0 || (num_rshps == 1 && length mid_dims == 1))
  let resized = zipWith (\ld n -> ld {ldShape = n}) dims' (newDims newshape)
      -- Undo the reordering to get back to stored dimension order.
      dims'' = permuteInv perm resized
  return $ IxFun (LMAD off dims'' :| lmads) oldbase cg

-- | Handle the case where a reshape operation can stay inside a single LMAD.
--
-- There are four conditions that all must hold for the result of a reshape
-- operation to remain in the one-LMAD domain:
--
--   (1) the permutation of the underlying LMAD must leave unchanged
--       the LMAD dimensions that were *not* reshape coercions.
--   (2) the repetition of dimensions of the underlying LMAD must
--       refer only to the coerced-dimensions of the reshape operation.
--   (3) similarly, the rotated dimensions must refer only to
--       dimensions that are coerced by the reshape operation.
--   (4) finally, the underlying memory is contiguous (and monotonic).
--
-- If any of these conditions does not hold, then the reshape operation will
-- conservatively add a new LMAD to the list, leading to a representation that
-- provides fewer opportunities for further analysis.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeOneLMAD :: IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD ixfun :: IxFun num
ixfun@(IxFun (lmad :: LMAD num
lmad@(LMAD num
off [LMADDim num]
dims) :| [LMAD num]
lmads) Shape num
oldbase Bool
cg) ShapeChange num
newshape = do
  let perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
  (ShapeChange num
head_coercions, ShapeChange num
reshapes, ShapeChange num
tail_coercions) <- ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
forall num.
(Eq num, IntegralExp num) =>
ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions ShapeChange num
newshape
  let hd_len :: Int
hd_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
head_coercions
      num_coercions :: Int
num_coercions = Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
tail_coercions
      dims_perm :: [LMADDim num]
dims_perm = Permutation -> [LMADDim num] -> [LMADDim num]
forall a. Permutation -> [a] -> [a]
permuteFwd Permutation
perm [LMADDim num]
dims
      mid_dims :: [LMADDim num]
mid_dims = Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
take ([LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
num_coercions) ([LMADDim num] -> [LMADDim num]) -> [LMADDim num] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$ Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
drop Int
hd_len [LMADDim num]
dims_perm
      -- Ignore rotates, as we only care about not having rotates in the
      -- dimensions that aren't coercions (@mid_dims@), which we check
      -- separately.
      mon :: Monotonicity
mon = Bool -> IxFun num -> Monotonicity
forall num.
(Eq num, IntegralExp num) =>
Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots Bool
True IxFun num
ixfun

  Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$
    -- checking conditions (2) and (3)
    (LMADDim num -> Bool) -> [LMADDim num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(LMADDim num
s num
r num
_ Int
_ Monotonicity
_) -> num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
/= num
0 Bool -> Bool -> Bool
&& num
r num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0) [LMADDim num]
mid_dims
      Bool -> Bool -> Bool
&&
      -- checking condition (1)
      Int -> Permutation -> Bool
forall a. (Eq a, Num a, Enum a) => a -> [a] -> Bool
consecutive Int
hd_len ((LMADDim num -> Int) -> [LMADDim num] -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm [LMADDim num]
mid_dims)
      Bool -> Bool -> Bool
&&
      -- checking condition (4)
      IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun
      Bool -> Bool -> Bool
&& Bool
cg
      Bool -> Bool -> Bool
&& (Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc Bool -> Bool -> Bool
|| Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Dec)

  -- make new permutation
  let rsh_len :: Int
rsh_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
reshapes
      diff :: Int
diff = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
      iota_shape :: Permutation
iota_shape = [Int
0 .. ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
      perm' :: Permutation
perm' =
        (Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map
          ( \Int
i ->
              let ind :: Int
ind =
                    if Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
                      then Int
i
                      else Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff
               in if (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len) Bool -> Bool -> Bool
&& (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len)
                    then Int
i -- already checked mid_dims not affected
                    else
                      let p :: Int
p = LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm ([LMADDim num]
dims [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
ind)
                       in if Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
                            then Int
p
                            else Int
p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
diff
          )
          Permutation
iota_shape
      -- split the dimensions
      ([(Int, (num, num))]
support_inds, [(Int, num)]
repeat_inds) =
        (([(Int, (num, num))], [(Int, num)])
 -> (Int, DimChange num, Int)
 -> ([(Int, (num, num))], [(Int, num)]))
-> ([(Int, (num, num))], [(Int, num)])
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
          ( \([(Int, (num, num))]
sup, [(Int, num)]
rpt) (Int
i, DimChange num
shpdim, Int
ip) ->
              case (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len, Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len, DimChange num
shpdim) of
                (Bool
True, Bool
_, DimCoercion num
n) ->
                  case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
i of
                    (LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
                    (LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
                (Bool
_, Bool
True, DimCoercion num
n) ->
                  case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff) of
                    (LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
                    (LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
                (Bool
False, Bool
False, DimChange num
_) ->
                  ((Int
ip, (num
0, DimChange num -> num
forall d. DimChange d -> d
newDim DimChange num
shpdim)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
                -- already checked that the reshaped
                -- dims cannot be rotates
                (Bool, Bool, DimChange num)
_ -> String -> ([(Int, (num, num))], [(Int, num)])
forall a. HasCallStack => String -> a
error String
"reshape: reached impossible case"
          )
          ([], [])
          ([(Int, DimChange num, Int)]
 -> ([(Int, (num, num))], [(Int, num)]))
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall a b. (a -> b) -> a -> b
$ [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a. [a] -> [a]
reverse ([(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)])
-> [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a b. (a -> b) -> a -> b
$ Permutation
-> ShapeChange num -> Permutation -> [(Int, DimChange num, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Permutation
iota_shape ShapeChange num
newshape Permutation
perm'

      (Permutation
sup_inds, [(num, num)]
support) = [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Int, (num, num))] -> (Permutation, [(num, num)]))
-> [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. (a -> b) -> a -> b
$ ((Int, (num, num)) -> (Int, (num, num)) -> Ordering)
-> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, (num, num)) -> Int)
-> (Int, (num, num))
-> (Int, (num, num))
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, (num, num)) -> Int
forall a b. (a, b) -> a
fst) [(Int, (num, num))]
support_inds
      (Permutation
rpt_inds, Shape num
repeats) = [(Int, num)] -> (Permutation, Shape num)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, num)]
repeat_inds
      LMAD num
off' [LMADDim num]
dims_sup = Monotonicity -> num -> [(num, num)] -> LMAD num
forall num.
IntegralExp num =>
Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota Monotonicity
mon num
off [(num, num)]
support
      repeats' :: [LMADDim num]
repeats' = (num -> LMADDim num) -> Shape num -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (\num
n -> num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
0 num
0 num
n Int
0 Monotonicity
Unknown) Shape num
repeats
      dims' :: [LMADDim num]
dims' =
        ((Int, LMADDim num) -> LMADDim num)
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (Int, LMADDim num) -> LMADDim num
forall a b. (a, b) -> b
snd ([(Int, LMADDim num)] -> [LMADDim num])
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$
          ((Int, LMADDim num) -> (Int, LMADDim num) -> Ordering)
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, LMADDim num) -> Int)
-> (Int, LMADDim num)
-> (Int, LMADDim num)
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, LMADDim num) -> Int
forall a b. (a, b) -> a
fst) ([(Int, LMADDim num)] -> [(Int, LMADDim num)])
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a b. (a -> b) -> a -> b
$
            Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
sup_inds [LMADDim num]
dims_sup [(Int, LMADDim num)]
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. [a] -> [a] -> [a]
++ Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
rpt_inds [LMADDim num]
repeats'
      lmad' :: LMAD num
lmad' = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off' [LMADDim num]
dims'
  IxFun num -> Maybe (IxFun num)
forall (m :: * -> *) a. Monad m => a -> m a
return (IxFun num -> Maybe (IxFun num)) -> IxFun num -> Maybe (IxFun num)
forall a b. (a -> b) -> a -> b
$ NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm' LMAD num
lmad' LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads) Shape num
oldbase Bool
cg
  where
    consecutive :: a -> [a] -> Bool
consecutive a
_ [] = Bool
True
    consecutive a
i [a
p] = a
i a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
p
    consecutive a
i [a]
ps = [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (a -> a -> Bool) -> [a] -> [a] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith a -> a -> Bool
forall a. Eq a => a -> a -> Bool
(==) [a]
ps [a
i, a
i a -> a -> a
forall a. Num a => a -> a -> a
+ a
1 ..]

-- | Split a shape change into three parts: the leading run of coercions, the
-- following run of non-coercions, and the trailing coercions.  Yields
-- 'Nothing' if a non-coercion appears after the trailing coercions start.
splitCoercions ::
  (Eq num, IntegralExp num) =>
  ShapeChange num ->
  Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions newshape' =
  if all isCoercion tail_coercions
    then Just (head_coercions, reshapes, tail_coercions)
    else Nothing
  where
    (head_coercions, rest) = span isCoercion newshape'
    (reshapes, tail_coercions) = break isCoercion rest

    isCoercion :: DimChange d -> Bool
    isCoercion dim_change = case dim_change of
      DimCoercion {} -> True
      _ -> False

-- | Reshape an index function.
--
-- A pure coercion is tried first, then a rewrite of the first LMAD.  If
-- neither applies, a fresh row-major LMAD for the new shape is pushed in
-- front of the existing ones.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  IxFun num
reshape ixfun@(IxFun lmads oshp cg) new_shape
  | Just coerced <- reshapeCoercion ixfun new_shape = coerced
  | Just single <- reshapeOneLMAD ixfun new_shape = single
  | otherwise =
      case iota (newDims new_shape) of
        IxFun (fresh :| []) _ _ -> IxFun (fresh NE.<| lmads) oshp cg
        _ -> error "reshape: reached impossible case"

-- | The number of dimensions in the domain of the input function, i.e. the
-- number of dimensions of its first LMAD.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank (IxFun lmads _ _) = length $ lmadDims $ NE.head lmads

-- | Handle the case where a rebase operation can stay within m + n - 1 LMADs,
-- where m is the number of LMADs in the index function, and n is the number of
-- LMADs in the new base.  If both index functions have only one LMAD, this
-- means that we stay within the single-LMAD domain.
--
-- We can often stay in that domain if the original ixfun is essentially a
-- slice, e.g. `x[i, (k1,m,s1), (k2,n,s2)] = orig`.
--
-- Returns 'Nothing' when any of the conservative conditions checked below
-- does not hold (including the case where the base LMAD has no dimensions
-- but the ixfun has a non-zero offset), so that the caller can fall back to a
-- general rebase.
--
-- XXX: TODO: handle repetitions in both lmads.
--
-- How to handle repeated dimensions in the original?
--
--   (a) Shave them off of the last lmad of original
--   (b) Compose the result from (a) with the first
--       lmad of the new base
--   (c) apply a repeat operation on the result of (b).
--
-- However, I strongly suspect that for in-place update what we need is actually
-- the INVERSE of the rebase function, i.e., given an index function new-base
-- and another one orig, compute the index function ixfun0 such that:
--
--   new-base == rebase ixfun0 ixfun, or equivalently:
--   new-base == ixfun o ixfun0
--
-- because then I can go bottom up and compose with ixfun0 all the index
-- functions corresponding to the memory block associated with ixfun.
--
-- NOTE(review): the result uses @shp@ (the base of @ixfun@, which the guard
-- requires to equal @shape new_base@) as its base, whereas the general case
-- of 'rebase' uses the base of @new_base@.  Confirm which one is intended.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
    -- The last LMAD of @ixfun@ is fused with the first LMAD of @new_base@;
    -- all other LMADs are kept as they are.
    let (lmad :| lmads') = NE.reverse lmads
        dims = lmadDims lmad
        perm = lmadPermutation lmad
        perm_base = lmadPermutation lmad_base

    guard $
      -- Core rebase condition.
      base ixfun == shape new_base
        -- Conservative safety conditions: ixfun is contiguous and has known
        -- monotonicity for all dimensions.
        && cg
        && all ((/= Unknown) . ldMon) dims
        -- XXX: We should be able to handle some basic cases where both index
        -- functions have non-trivial permutations.
        && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
        -- We need the permutations to be of the same size if we want to compose
        -- them.  They don't have to be of the same size if the ixfun has a trivial
        -- permutation.  Supporting this latter case allows us to rebase when ixfun
        -- has been created by slicing with fixed dimensions.
        && (length perm == length perm_base || hasContiguousPerm ixfun)
        -- To not have to worry about ixfun having non-1 strides, we also check that
        -- it is a row-major array (modulo permutation, which is handled
        -- separately).  Accept a non-full innermost dimension.  XXX: Maybe this can
        -- be less conservative?
        && and
          ( zipWith3
              (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1))
              shp
              dims
              (replicate (length dims - 1) False ++ [True])
          )

    -- Compose permutations, reverse strides and adjust offset if necessary.
    let perm_base' =
          if hasContiguousPerm ixfun
            then perm_base
            else map (perm !!) perm_base
        lmad_base' = setLMADPermutation perm_base' lmad_base
        dims_base = lmadDims lmad_base'
        n_fewer_dims = length dims_base - length dims
        (dims_base', offs_contrib) =
          unzip $
            zipWith
              ( \(LMADDim s1 r1 n1 p1 _) (LMADDim _ r2 _ _ m2) ->
                  -- A non-increasing ixfun dimension negates the base stride
                  -- and contributes @s1 * (n1 - 1)@ to the offset.
                  let (s', off')
                        | m2 == Inc = (s1, 0)
                        | otherwise = (s1 * (-1), s1 * (n1 - 1))
                      -- Combine the rotations of the two dimensions.
                      r'
                        | m2 == Inc = if r2 == 0 then r1 else r1 + r2
                        | r1 == 0 = r2
                        | r2 == 0 = n1 - r1
                        | otherwise = n1 - r1 + r2
                   in (LMADDim s' r' n1 (p1 - n_fewer_dims) Inc, off')
              )
              -- If @dims@ is morally a slice, it might have fewer dimensions than
              -- @dims_base@.  Drop extraneous outer dimensions.
              (drop n_fewer_dims dims_base)
              dims
        off_base = lmadOffset lmad_base' + sum offs_contrib

    lmad_base'' <-
      if lmadOffset lmad == 0
        then Just $ LMAD off_base dims_base'
        else -- If the innermost dimension of the ixfun was not full (but still
        -- had a stride of 1), add its offset relative to the new base.  This
        -- needs the last dimension of the base; without one we give up
        -- instead of crashing on a partial 'last'.
        case reverse dims_base of
          [] -> Nothing
          inner_base : _ ->
            Just $
              setLMADShape
                (lmadShape lmad)
                ( LMAD
                    (off_base + ldStride inner_base * lmadOffset lmad)
                    dims_base'
                )

    return $ IxFun (lmads' ++@ (lmad_base'' :| lmads_base)) shp (cg && cg_base)

-- | Rebase an index function on top of a new base.
--
-- 'rebaseNice' is tried first.  Otherwise the LMADs of @ixfun@ are placed in
-- front of those of @new_base@.  This refers to index function composition,
-- which is always safe.  If the base of @ixfun@ differs from the shape of
-- @new_base@, @new_base@ is first coerced to that base shape.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun _ _ cg_base) ixfun@(IxFun lmads shp cg)
  | Just rebased <- rebaseNice new_base ixfun = rebased
  | otherwise = IxFun (lmads @++@ lmads_base') shp_base' (cg && cg_base)
  where
    -- @new_base@, seen with the shape that @ixfun@ expects of its base.
    fitted_base
      | base ixfun == shape new_base = new_base
      | otherwise = reshape new_base $ map DimCoercion shp
    IxFun lmads_base' shp_base' _ = fitted_base

-- | Monotonicity of an index function, where a dimension with a non-zero
-- stride and a non-zero rotation is never considered monotonic.
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity = ixfunMonotonicityRots ignore_rots
  where
    -- Rotations are taken into account, not ignored.
    ignore_rots = False

-- | The offset, multiplied by @elem_size@, at which the memory support of the
-- index function starts, provided that this support is contiguous and
-- row-major (no transpositions, rotations, etc.).
--
-- Concretely, all of the following must hold:
--
-- * the index function consists of a single LMAD,
-- * 'hasContiguousPerm' holds,
-- * the contiguity flag is set, and
-- * it is monotonically increasing, with rotations taken into account.
--
-- Otherwise the result is 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ contiguous
      | hasContiguousPerm ixfun,
        contiguous,
        ixfunMonotonicity ixfun == Inc ->
        Just $ lmadOffset lmad * elem_size
    _ -> Nothing

-- | Like 'linearWithOffset', but the single LMAD may be permuted.
--
-- The permutation is disregarded for the linearity check.  It is returned
-- together with the offset, each permutation entry paired with the
-- corresponding element of the base shape as reordered by 'permuteFwd'.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  withPerm <$> linearWithOffset unpermuted elem_size
  where
    -- The contiguity flag @cg@ describes contiguity *ignoring permutations*,
    -- so linearity is checked on the same LMAD with the identity permutation.
    perm = lmadPermutation lmad
    identity_perm = [0 .. length perm - 1]
    unpermuted = IxFun (setLMADPermutation identity_perm lmad :| []) oshp cg
    withPerm offset = (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing

-- | Is this a row-major array starting at offset zero?  That is, does
-- 'linearWithOffset' (with an element size of one) report offset zero?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0

-- | Apply a permutation: position @i@ of the result holds @elems !! (ps !! i)@.
-- Every index in @ps@ must be a valid position in @elems@.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]

-- | Apply the inverse of a permutation: element @j@ of @elems@ is placed at
-- position @ps !! j@ of the result.  For a valid permutation @ps@ this undoes
-- 'permuteFwd'.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  [x | (_, x) <- sortBy byTarget (zip ps elems)]
  where
    byTarget (i, _) (j, _) = compare i j

-- | Flat contribution of index @i@ in one dimension described by @(s, r, n)@:
-- zero when @s@ is zero, otherwise the index (shifted by @r@ modulo @n@
-- when @r@ is non-zero) multiplied by @s@.  Presumably @(s, r, n)@ are the
-- stride, rotation and size of an LMAD dimension.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  (num, num, num) ->
  num ->
  num
flatOneDim (s, r, n) i =
  if s == 0
    then 0
    else (if r == 0 then i else (i + r) `mod` n) * s

-- | Generalised iota with a user-specified offset and per-dimension support.
--
-- Each element of @support@ is a @(rotation, size)@ pair for one dimension.
-- The strides are not user-specified: they are the row-major strides of the
-- sizes (last dimension has stride 1), negated when the monotonicity is
-- 'Dec'.  The permutation is the identity, and every dimension is tagged with
-- @mon@.  Only 'Inc' and 'Dec' are accepted; any other monotonicity is an
-- error.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  num ->
  [(num, num)] ->
  LMAD num
makeRotIota mon off support
  | mon == Inc || mon == Dec =
    let rk = length support
        (rs, ns) = unzip support
        -- Row-major strides: each dimension's stride is the product of the
        -- sizes of all dimensions after it.
        ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns
        ss
          | mon == Inc = ss0
          | otherwise = map (* (-1)) ss0
        -- Identity permutation; the indices are already 'Int's.
        ps = [0 .. rk - 1]
        fi = replicate rk mon
     in LMAD off $ zipWith5 LMADDim ss rs ns ps fi
  | otherwise = error "makeRotIota: requires Inc or Dec"

-- | Check monotonicity of an index function.
--
-- The result is the monotonicity shared by all of its LMADs, or 'Unknown' if
-- they disagree.  An LMAD is 'Inc' (resp. 'Dec') when every dimension has
-- stride zero, or is tagged 'Inc' (resp. 'Dec') and — unless @ignore_rots@ is
-- set — has zero rotation.
ixfunMonotonicityRots ::
  (Eq num, IntegralExp num) =>
  Bool ->
  IxFun num ->
  Monotonicity
ixfunMonotonicityRots ignore_rots (IxFun (first_lmad :| rest_lmads) _ _)
  | any ((/= mon0) . lmadMon) rest_lmads = Unknown
  | otherwise = mon0
  where
    mon0 = lmadMon first_lmad

    lmadMon :: (Eq n, IntegralExp n) => LMAD n -> Monotonicity
    lmadMon (LMAD _ dims)
      | all (dimIs Inc) dims = Inc
      | all (dimIs Dec) dims = Dec
      | otherwise = Unknown

    -- A zero-stride dimension is compatible with any monotonicity.
    dimIs :: (Eq n, IntegralExp n) => Monotonicity -> LMADDim n -> Bool
    dimIs want (LMADDim stride rot _ _ have) =
      stride == 0 || ((ignore_rots || rot == 0) && want == have)

-- | Generalization (anti-unification) of two index functions.
--
-- Anti-unification is supported only under the following conditions:
--
--   0. Both index functions consist of exactly ONE LMAD (assumed to be the
--      common case).
--   1. Their base shapes have the same rank (this could be relaxed by using
--      a 1D support, as we probably should).
--   2. They agree on the contiguity flag and on the per-dimension
--      monotonicity (otherwise we might lose important information; this can
--      be relaxed).
--   3. Most importantly, their LMADs have the same permutation (permutations
--      are plain 'Int's, so this restriction cannot be relaxed without moving
--      to a gated-LMAD representation).
--
-- On success, returns the generalized index function together with the
-- substitution list produced by 'PEG.leastGeneralGeneralization'.  That list is
-- threaded through the dimension sizes, the base shape, the strides, the
-- rotations and finally the offset, in exactly that order.
leastGeneralGeneralization ::
  Eq v =>
  IxFun (PrimExp v) ->
  IxFun (PrimExp v) ->
  Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
leastGeneralGeneralization (IxFun (lmad1 :| []) oshp1 ctg1) (IxFun (lmad2 :| []) oshp2 ctg2) = do
  guard compatible
  -- The order of these steps fixes the substitution order; keep it.
  (dshp, m1) <- generalize [] (dimsOf ldShape lmad1) (dimsOf ldShape lmad2)
  (oshp, m2) <- generalize m1 oshp1 oshp2
  (dstd, m3) <- generalize m2 (dimsOf ldStride lmad1) (dimsOf ldStride lmad2)
  (drot, m4) <- generalize m3 (dimsOf ldRotate lmad1) (dimsOf ldRotate lmad2)
  let (offt, m5) = PEG.leastGeneralGeneralization m4 (lmadOffset lmad1) (lmadOffset lmad2)
      -- Permutation and monotonicity are equal in both inputs (checked above).
      new_dims =
        zipWith5 LMADDim dstd drot dshp (lmadPermutation lmad1) (dimsOf ldMon lmad1)
  pure (IxFun (LMAD offt new_dims :| []) oshp ctg1, m5)
  where
    compatible =
      length oshp1 == length oshp2
        && ctg1 == ctg2
        && dimsOf ldPerm lmad1 == dimsOf ldPerm lmad2
        && dimsOf ldMon lmad1 == dimsOf ldMon lmad2

    -- Project one field out of every dimension of an LMAD.
    dimsOf f = map f . lmadDims

    -- Anti-unify two lists pointwise (up to the shorter length), left to
    -- right, threading the substitution list through each step.
    generalize subst (pe1 : pes1) (pe2 : pes2) = do
      let (e, subst') = PEG.leastGeneralGeneralization subst pe1 pe2
      (es, subst'') <- generalize subst' pes1 pes2
      pure (e : es, subst'')
    generalize subst _ _ = pure ([], subst)
leastGeneralGeneralization _ _ = Nothing

-- | Is the list exactly @[0, 1, 2, ...]@, i.e. an identity permutation?
isSequential :: [Int] -> Bool
isSequential xs = and (zipWith (==) xs [0 ..])

-- | Replace an expression by an existential leaf @Ext i@ of the same primitive
-- type.  Here @i@ is the number of expressions recorded in the state so far,
-- and the replaced expression is appended to that state.
existentializeExp :: TPrimExp t v -> State [TPrimExp t v] (TPrimExp t (Ext v))
existentializeExp e = do
  seen <- get
  put (seen ++ [e])
  pure $ TPrimExp $ LeafExp (Ext (length seen)) (primExpType (untyped e))

-- | Replace the base shape, the offset, and every dimension's stride and size
-- of an index function by fresh existentials.  The replaced expressions are
-- recorded, in order, in the state.  Rotations are kept, lifted with 'Free'.
--
-- This only produces 'Just' when all of the following hold:
--
-- * the index function consists of a single LMAD,
-- * it is contiguous,
-- * no dimension is rotated,
-- * the LMAD uses the identity permutation, and
-- * the LMAD has the same rank as the base shape.
--
-- Otherwise it returns 'Nothing' without touching the state.
existentialize ::
  (IntExp t, Eq v, Pretty v) =>
  IxFun (TPrimExp t v) ->
  State [TPrimExp t v] (Maybe (IxFun (TPrimExp t (Ext v))))
existentialize (IxFun (lmad :| []) oshp True)
  | all ((== 0) . ldRotate) (lmadDims lmad),
    length (lmadShape lmad) == length oshp,
    isSequential (map ldPerm $ lmadDims lmad) = do
    -- The order of these calls determines the existential numbering:
    -- base shape first, then the offset, then stride and size per dimension.
    oshp' <- mapM existentializeExp oshp
    lmadOffset' <- existentializeExp $ lmadOffset lmad
    lmadDims' <- mapM existentializeLMADDim $ lmadDims lmad
    let lmad' = LMAD lmadOffset' lmadDims'
    return $ Just $ IxFun (lmad' :| []) oshp' True
  where
    existentializeLMADDim ::
      LMADDim (TPrimExp t v) ->
      State [TPrimExp t v] (LMADDim (TPrimExp t (Ext v)))
    existentializeLMADDim (LMADDim str rot shp perm mon) = do
      stride' <- existentializeExp str
      shape' <- existentializeExp shp
      -- Rotations are known to be zero here (see the guard), so they are
      -- kept as-is rather than existentialized.
      return $ LMADDim stride' (fmap Free rot) shape' perm mon
existentialize _ = return Nothing

-- | When comparing index functions as part of the type check in KernelsMem,
-- we may run into problems caused by the simplifier. As index functions can be
-- generalized over if-then-else expressions, the simplifier might hoist some of
-- the code from inside the if-then-else (computing the offset of an array, for
-- instance), but now the type checker cannot verify that the generalized index
-- function is valid, because some of the existentials are computed somewhere
-- else. To work around this, we've had to relax the KernelsMem type-checker
-- a bit, specifically, we've introduced this function to verify whether two
-- index functions are "close enough" that we can assume that they match. We use
-- this instead of `ixfun1 == ixfun2` and hope that it's good enough.
--
-- Concretely, two index functions are close enough when they have the same
-- number of base dimensions, the same number of LMADs, and each pair of
-- corresponding LMADs has the same number of dimensions with identical
-- permutation entries.  Offsets, strides, rotations, shapes and monotonicity
-- are not compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  sameBaseRank
    && sameLMADCount
    && and (NE.zipWith lmadsMatch lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    sameBaseRank = length (base ixf1) == length (base ixf2)

    -- Checked before zipping: 'NE.zipWith' silently truncates to the shorter
    -- list, so without this a prefix match would be accepted.
    sameLMADCount = NE.length lmads1 == NE.length lmads2

    -- Only the rank and the per-dimension permutation of an LMAD are compared.
    lmadsMatch :: LMAD num -> LMAD num -> Bool
    lmadsMatch lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in length dims1 == length dims2
            && map ldPerm dims1 == map ldPerm dims2