{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.IR.Mem.IxFun
( IxFun (..),
LMAD (..),
LMADDim (..),
Monotonicity (..),
index,
iota,
iotaOffset,
permute,
rotate,
reshape,
slice,
flatSlice,
rebase,
shape,
rank,
linearWithOffset,
rearrangeWithOffset,
isDirect,
isLinear,
substituteInIxFun,
leastGeneralGeneralization,
existentialize,
closeEnough,
equivalent,
)
where
import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on, (&))
import Data.List (sort, sortBy, zip4, zip5, zipWith5)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import Data.Traversable (fmapDefault, foldMapDefault)
import Futhark.Analysis.PrimExp
  ( IntExp,
    PrimExp (..),
    TPrimExp (..),
    primExpType,
  )
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import qualified Futhark.Analysis.PrimExp.Generalize as PEG
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimChange (..),
    DimIndex (..),
    FlatDimIndex (..),
    FlatSlice (..),
    ShapeChange,
    Slice (..),
    dimFix,
    flatSliceDims,
    flatSliceStrides,
    unitSlice,
  )
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (id, mod, (.))
-- | The extent of each dimension, one entry per dimension.
type Shape num = [num]
-- | A multidimensional index, one entry per dimension.
type Indices num = [num]
-- | A permutation of dimensions, given as a list of dimension positions
-- (the entries stored in each 'ldPerm').
type Permutation = [Int]
-- | The monotonicity recorded for a dimension (see 'ldMon').
data Monotonicity
  = -- | Increasing.
    Inc
  | -- | Decreasing.
    Dec
  | -- | Not known to be either increasing or decreasing.
    Unknown
  deriving (Eq, Show)
-- | A single dimension of an 'LMAD'.
--
-- 'indexLMAD' passes each dimension's stride, rotation and shape to
-- 'flatOneDim'.
data LMADDim num = LMADDim
  { -- | Stride of this dimension.
    ldStride :: num,
    -- | Rotation of this dimension (printed under @rotates:@).
    ldRotate :: num,
    -- | Number of elements in this dimension; 'lmadShapeBase' collects these.
    ldShape :: num,
    -- | This dimension's entry in the LMAD permutation; 'lmadPermutation'
    -- collects these.
    ldPerm :: Int,
    -- | Monotonicity of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Eq, Show)
-- | An LMAD: a general offset together with a list of dimensions.
--
-- 'indexLMAD' computes a flat index as the offset plus one 'flatOneDim'
-- term per dimension.
data LMAD num = LMAD
  { -- | Offset added to every flat index computed by 'indexLMAD'.
    lmadOffset :: num,
    -- | The dimensions.  'lmadShape' reorders their shapes with 'permuteInv'
    -- of the stored permutation.
    lmadDims :: [LMADDim num]
  }
  deriving (Eq, Show)
-- | An index function, represented as a non-empty sequence of 'LMAD's.
--
-- 'index' maps indices through the first 'LMAD' and threads the resulting
-- flat index through the remaining ones in order.
data IxFun num = IxFun
  { -- | The 'LMAD's; the first one determines the index space ('shape').
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The base shape; 'iotaOffset' sets it to the shape it is given.
    base :: Shape num,
    -- | Contiguity flag; 'isDirect' only accepts index functions where it
    -- is 'True'.
    ixfunContig :: Bool
  }
  deriving (Eq, Show)
-- | Pretty-prints the constructor name exactly as produced by 'show'.
instance Pretty Monotonicity where
  ppr mon = text (show mon)
-- | Renders an 'LMAD' as a 'semisep' list inside 'braces'.  The list holds
-- the offset, followed by one single-line bracketed list per 'LMADDim'
-- field: strides, rotates, shape, permutation, monotonicity.
instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces . semisep $
      [ "offset: " <> oneLine (ppr offset),
        "strides: " <> column ldStride,
        "rotates: " <> column ldRotate,
        "shape: " <> column ldShape,
        "permutation: " <> column ldPerm,
        "monotonicity: " <> column ldMon
      ]
    where
      -- One field projected out of every dimension, printed on one line.
      column field = oneLine . brackets . commasep $ [ppr (field d) | d <- dims]
-- | Renders an 'IxFun' as a 'semisep' list inside 'braces'.  The list holds
-- the base shape, the contiguity flag (as @true@/@false@) and the 'LMAD's
-- stacked with 'commastack'.
instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces . semisep $
      [ "base: " <> brackets (commasep (map ppr oshp)),
        "contiguous: " <> (if cg then "true" else "false"),
        "LMADs: " <> brackets (commastack (map ppr (NE.toList lmads)))
      ]
-- | Applies the name substitution to every @num@ stored in the 'LMAD'.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad
-- | Applies the name substitution to every @num@ stored in the index
-- function (all 'LMAD's and the base shape).
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun
-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun
-- | Combines the free variables of every @num@ in the 'LMAD': the offset,
-- plus the stride, rotation and shape of each dimension.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad
-- | Combines the free variables of every @num@ in the index function:
-- all 'LMAD's and the base shape.
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun
-- | Maps over every @num@ in an 'LMAD'.
--
-- Derived from the 'Traversable' instance via base's 'fmapDefault' rather
-- than a hand-rolled 'Identity' traversal.  'traverse' for 'LMAD' does not
-- use 'fmap' on 'LMAD', so this does not loop.
instance Functor LMAD where
  fmap = fmapDefault
-- | Maps over every @num@ in an 'IxFun'.
--
-- Derived from the 'Traversable' instance via base's 'fmapDefault' rather
-- than a hand-rolled 'Identity' traversal.
instance Functor IxFun where
  fmap = fmapDefault
-- | Folds over every @num@ in an 'LMAD', in 'traverse' order.
--
-- Derived from the 'Traversable' instance via base's 'foldMapDefault'.
-- This replaces a hand-rolled traversal in the lazy 'Writer' monad, which
-- builds the same monoidal result with extra overhead.
instance Foldable LMAD where
  foldMap = foldMapDefault
-- | Folds over every @num@ in an 'IxFun', in 'traverse' order.
--
-- Derived from the 'Traversable' instance via base's 'foldMapDefault'
-- instead of a hand-rolled lazy 'Writer' traversal.
instance Foldable IxFun where
  foldMap = foldMapDefault
-- | Visits the offset first.  Then, for each dimension in order, it visits
-- that dimension's stride, rotation and shape.  The permutation entry and
-- monotonicity of each dimension are carried over unchanged.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse visitDim dims
    where
      visitDim (LMADDim stride rot len perm mon) =
        rebuild <$> f stride <*> f rot <*> f len
        where
          rebuild stride' rot' len' = LMADDim stride' rot' len' perm mon
-- | Visits every 'LMAD' in order, then the base shape.  The contiguity flag
-- is carried over unchanged.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    withContig <$> traverse (traverse f) lmads <*> traverse f oshp
    where
      withContig lmads' oshp' = IxFun lmads' oshp' cg
-- | Prepend a plain list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
-- The as-pattern forces the right operand, just as matching on ':|' does.
es ++@ ne@(_ :| _) = foldr NE.cons ne es
-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
-- The as-pattern forces the right operand, just as matching on ':|' does.
(x :| xs) @++@ rest@(_ :| _) = x :| (xs ++ NE.toList rest)
-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown
-- | The permutation of an 'LMAD': the 'ldPerm' entry of each dimension, in
-- dimension order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]
-- | Overwrite the 'ldPerm' of each dimension with the corresponding entry of
-- @perm@.
--
-- NOTE: the dimensions are zipped with @perm@, so if @perm@ is shorter than
-- the list of dimensions, the result keeps only as many dimensions as
-- @perm@ has entries.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = [dim {ldPerm = p} | (dim, p) <- zip (lmadDims lmad) perm]}
-- | Overwrite the 'ldShape' of each dimension with the corresponding entry
-- of @shp@.
--
-- NOTE: the dimensions are zipped with @shp@, so if @shp@ is shorter than
-- the list of dimensions, the result keeps only as many dimensions as
-- @shp@ has entries.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = [dim {ldShape = s} | (dim, s) <- zip (lmadDims lmad) shp]}
-- | Apply the substitution @tab@ to the offset and to the stride, rotation
-- and shape of every dimension of an 'LMAD'.  Permutation entries and
-- monotonicity are left untouched.
substituteInLMAD ::
  Ord a =>
  M.Map a (PrimExp a) ->
  LMAD (PrimExp a) ->
  LMAD (PrimExp a)
substituteInLMAD tab (LMAD offset dims) = LMAD (subst offset) (map substDim dims)
  where
    subst e = substituteInPrimExp tab e
    substDim (LMADDim s r n p m) = LMADDim (subst s) (subst r) (subst n) p m
-- | Apply the substitution @tab@ to every 'LMAD' and to the base shape of
-- an index function.  The contiguity flag is left untouched.
--
-- The typed expressions are unwrapped with 'untyped', substituted as plain
-- 'PrimExp's, and re-wrapped with 'TPrimExp'.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmads oshp cg) =
  IxFun (NE.map substLMAD lmads) (map substExp oshp) cg
  where
    untypedTab = M.map untyped tab
    substExp e = TPrimExp (substituteInPrimExp untypedTab (untyped e))
    substLMAD lmad = TPrimExp <$> substituteInLMAD untypedTab (untyped <$> lmad)
-- | Is this a row-major array?
--
-- Holds exactly when all of the following are true:
--
-- * the index function consists of a single 'LMAD' and is marked
--   contiguous;
-- * its permutation is in ascending order;
-- * it has one dimension per base-shape entry and a zero offset;
-- * every dimension has rotation 0, a shape equal to the matching base
--   extent, a permutation entry equal to its position, and the row-major
--   stride for the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  hasContiguousPerm ixfun
    && length oshp == length dims
    && offset == 0
    && all dimIsDirect (zip4 dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides for the base shape: the innermost stride is 1 and
    -- each outer stride is the product of all inner extents.  'drop 1'
    -- (rather than the partial 'tail') keeps this total for rank-0 shapes.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    dimIsDirect (LMADDim s r n p _, m, d, se) =
      s == se && r == 0 && n == d && p == m
isDirect _ = False
-- | Does the index function consist of a single 'LMAD' whose permutation is
-- in ascending order?  An index function with more than one 'LMAD' never
-- qualifies.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] -> let perm = lmadPermutation lmad in sort perm == perm
    _ -> False
-- | The index space of the index function: the shape of the first 'LMAD'.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun =
  case ixfunLMADs ixfun of
    lmad :| _ -> lmadShape lmad
-- | Shape of an 'LMAD': its per-dimension shapes reordered by 'permuteInv'
-- of its permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm (lmadShapeBase lmad)
  where
    perm = lmadPermutation lmad
-- | Shape of an 'LMAD' ignoring its permutation: the 'ldShape' of each
-- dimension in stored order.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]
-- | Compute the flat memory index for a complete set of indices.
--
-- The indices are first mapped through the head 'LMAD'.  While further
-- 'LMAD's follow, the resulting flat index is unflattened into the next
-- 'LMAD's base shape, permuted with 'permuteFwd'.  Indexing then continues
-- with that 'LMAD', and the last one yields the final flat index.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun = walk (ixfunLMADs ixfun)
  where
    walk (lmad :| []) inds = indexLMAD lmad inds
    walk (cur :| next : rest) inds =
      let flat = indexLMAD cur inds
          nextShape = permuteFwd (lmadPermutation next) (lmadShapeBase next)
       in walk (next :| rest) (unflattenIndex nextShape flat)
-- | Compute the flat index of @inds@ in a single 'LMAD'.
--
-- The result is the offset plus, for each dimension, 'flatOneDim' applied
-- to that dimension's (stride, rotation, shape) and to the index that
-- 'permuteInv' of the LMAD's permutation places at that dimension.
indexLMAD ::
  (IntegralExp num, Eq num) =>
  LMAD num ->
  Indices num ->
  num
indexLMAD lmad@(LMAD off dims) inds =
  off + sum (zipWith flatOneDim components permutedInds)
  where
    components = [(s, r, n) | LMADDim s r n _ _ <- dims]
    permutedInds = permuteInv (lmadPermutation lmad) inds
-- | Build a single-'LMAD' index function over the shape @ns@ with offset
-- @o@, using 'makeRotIota' with 'Inc' monotonicity.  The base shape is @ns@
-- and the result is marked contiguous.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (lmad :| []) ns True
  where
    -- Each extent is paired with a leading 0 (presumably the per-dimension
    -- rotation expected by 'makeRotIota' -- TODO confirm).
    lmad = makeRotIota Inc o [(0, n) | n <- ns]
-- | 'iotaOffset' specialised to a zero offset.
iota :: IntegralExp num => Shape num -> IxFun num
iota shp = iotaOffset 0 shp
-- | Permute the dimensions of an index function.
--
-- The outermost LMAD receives a new permutation.  Each entry of
-- @perm_new@ is looked up in the LMAD's current permutation, so the two
-- permutations are composed.  The remaining LMADs, the base shape and the
-- contiguity flag are left untouched.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    perm_cur = lmadPermutation lmad
    -- NOTE(review): '!!' is partial; every entry of perm_new is assumed to
    -- be a valid position in perm_cur.
    composed = [perm_cur !! i | i <- perm_new]
-- | Rotate the dimensions of an index function by the given amounts.
--
-- The rotation amounts are reordered with 'permuteInv' of the outermost
-- LMAD's permutation and added to each dimension's rotation.  A dimension
-- with stride zero instead gets rotation 0 and 'Unknown' monotonicity.
-- The other LMADs, the base shape and the contiguity flag are kept as is.
rotate ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  IxFun (LMAD off rotated :| lmads) oshp cg
  where
    alignedOffs = permuteInv (lmadPermutation lmad) offs
    rotated = zipWith rotateDim dims alignedOffs
    rotateDim (LMADDim s r n p f) o
      | s == 0 = LMADDim 0 0 n p Unknown
      | otherwise = LMADDim s (r + o) n p f
-- | Try to apply a slice directly to the outermost LMAD of an index
-- function, without pushing a new LMAD.
--
-- Returns 'Nothing' when 'harmlessRotation' rejects the slice.  Otherwise
-- every dimension is sliced with 'sliceOne' (starting from the LMAD's
-- offset and an empty dimension list).  The permutation is then adjusted
-- to drop the dimensions removed by 'DimFix' entries.  The contiguity
-- flag survives only if 'slicePreservesContiguous' also holds.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) = do
  let perm = lmadPermutation lmad
      -- Reorder the slice with the inverse permutation so that it lines up
      -- with ldims.
      is' = permuteInv perm is
      slc' = Slice is'
  guard $ harmlessRotation lmad slc'
  let cg' = cg && slicePreservesContiguous lmad slc'
      lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) (zip is' ldims)
      -- Positions (in ldims order) of the dimensions removed by a DimFix.
      fixed = [i | (i, d) <- zip [0 .. length is' - 1] is', isJust (dimFix d)]
      perm' = updatePerm perm fixed
  pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    -- Drop every permutation entry listed in @removed@, and shift each
    -- surviving entry down by the number of removed entries below it.
    updatePerm ps removed = concatMap adjust ps
      where
        adjust p = [p - d | d /= -1]
          where
            -- Count of removed positions below p, or -1 if p itself was
            -- removed (once -1 is reached it is never changed again).
            d = foldl step 0 removed
            step n i
              | i == p = -1
              | i > p = n
              | n /= -1 = n + 1
              | otherwise = n
-- | Whether 'sliceOneLMAD' may slice this single LMAD dimension in place.
--
-- This holds for a fixed index, for a dimension with zero stride or zero
-- rotation, and for a slice covering the whole dimension either forwards
-- (@unitSlice 0 n@) or backwards (@DimSlice (n - 1) n (-1)@).
harmlessRotation' ::
  (Eq num, IntegralExp num) =>
  LMADDim num ->
  DimIndex num ->
  Bool
harmlessRotation' _ (DimFix _) = True
harmlessRotation' (LMADDim s r n _ _) dslc =
  s == 0
    || r == 0
    || dslc == DimSlice (n - 1) n (-1)
    || dslc == unitSlice 0 n
-- | Whether every dimension of the LMAD accepts its slice entry according
-- to 'harmlessRotation''.  Dimensions and slice entries are paired by
-- position; any surplus on either side is ignored.
harmlessRotation ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
harmlessRotation (LMAD _ dims) (Slice iss) =
  all (uncurry harmlessRotation') (zip dims iss)
-- | Slice one dimension and add the result to an LMAD under construction.
--
-- The first argument is the LMAD accumulated so far.  The second pairs a
-- slice entry with the original dimension it applies to.  The cases, tried
-- in order, are:
--
--   * a fixed index: no dimension is appended; the indexed element's
--     contribution ('flatOneDim') is added to the offset;
--   * any slice of a zero-stride dimension: a zero-stride dimension of the
--     new size is appended;
--   * the full forward slice: the dimension is appended unchanged;
--   * the full backward slice: the stride is negated, a non-zero rotation
--     @r@ becomes @n - r@, the offset moves to element @n - 1@, and the
--     monotonicity is inverted;
--   * a zero-stride slice: the start element is added to the offset and a
--     zero-stride dimension is appended;
--   * a general slice of an unrotated dimension: @stride * start@ is added
--     to the offset and the strides are multiplied, with monotonicity
--     derived from the sign of the slice stride.
--
-- Any other combination (a rotated dimension sliced some other way) raises
-- an 'error'.  'sliceOneLMAD' only calls this after 'harmlessRotation' has
-- accepted the slice.
sliceOne ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  (DimIndex num, LMADDim num) ->
  LMAD num
sliceOne (LMAD off dims) (dmind, ldim@(LMADDim s r n p m)) =
  case dmind of
    DimFix i ->
      LMAD (off + flatOneDim (s, r, n) i) dims
    DimSlice _ ne _
      | s == 0 -> appendDim off (LMADDim 0 0 ne p Unknown)
    _
      | dmind == unitSlice 0 n -> appendDim off ldim
      | dmind == DimSlice (n - 1) n (-1) ->
          let r' = if r == 0 then 0 else n - r
              off' = off + flatOneDim (s, 0, n) (n - 1)
           in appendDim off' (LMADDim (s * (-1)) r' n p (invertMonotonicity m))
    DimSlice b ne 0 ->
      appendDim (off + flatOneDim (s, r, n) b) (LMADDim 0 0 ne p Unknown)
    DimSlice bs ns ss
      | r == 0 ->
          let m' =
                case sgn ss of
                  Just 1 -> m
                  Just (-1) -> invertMonotonicity m
                  _ -> Unknown
           in appendDim (off + s * bs) (LMADDim (ss * s) 0 ns p m')
    _ -> error "slice: reached impossible case"
  where
    -- New dimensions always go at the end of the accumulated list.
    appendDim o d = LMAD o (dims ++ [d])
-- | Conservatively decide whether a slice keeps the LMAD contiguous.
--
-- Dimensions with stride zero are dropped first, together with their
-- slice entries, and every slice entry is normalised with 'normIndex'.
-- The remaining pairs are then scanned left to right:
--
--   * the first 'DimSlice' must be on an unrotated dimension or cover the
--     whole dimension;
--   * every later 'DimSlice' must cover its whole dimension;
--   * every 'DimSlice' must have stride 1 or -1;
--   * no 'DimFix' may appear after a 'DimSlice'.
slicePreservesContiguous ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
slicePreservesContiguous (LMAD _ dims) (Slice slc) =
  snd $ foldl step (False, True) (zip slc' dims')
  where
    kept =
      [ (ld, d)
        | (ld, d) <- zip dims (map normIndex slc),
          ldStride ld /= 0
      ]
    dims' = map fst kept
    slc' = map snd kept
    unitStride ds = ds == 1 || ds == -1
    -- The accumulator records whether a DimSlice has been seen yet, and
    -- whether contiguity still holds.
    step (found, ok) (slcdim, LMADDim _ r n _ _) =
      case slcdim of
        DimFix {}
          | found -> (found, False)
          | otherwise -> (found, ok)
        DimSlice _ ne ds
          | found -> (found, ok && (n == ne && unitStride ds))
          | otherwise -> (True, ok && ((r == 0 || n == ne) && unitStride ds))
-- | Normalise a slice entry.  A 'DimSlice' whose count (second field) is 1,
-- or whose stride is 0, becomes a 'DimFix' on its start index.  Any other
-- entry is returned unchanged.
normIndex ::
  (Eq num, IntegralExp num) =>
  DimIndex num ->
  DimIndex num
normIndex d =
  case d of
    DimSlice b n s
      | n == 1 || s == 0 -> DimFix b
    _ -> d
-- | Apply a slice to an index function.
--
-- If the slice is a unit slice over every dimension of 'shape', the index
-- function is returned unchanged.  Otherwise 'sliceOneLMAD' first tries to
-- slice the outermost LMAD in place.  If that fails, the slice is applied
-- to a fresh 'iota' of the outermost LMAD's shape.  The resulting LMAD is
-- pushed on top of the existing ones, and the new contiguity flag is the
-- conjunction of both flags.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (lmad@LMAD {} :| lmads) oshp cg) dim_slices
  | unSlice dim_slices == map (unitSlice 0) (shape ixfun) = ixfun
  | otherwise =
      case sliceOneLMAD ixfun dim_slices of
        Just ixfun' -> ixfun'
        Nothing -> viaNewLMAD
  where
    viaNewLMAD =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"
-- | Apply a flat slice to an index function.
--
-- In-place case: the index function has a contiguous permutation
-- ('hasContiguousPerm') and the outermost LMAD's first dimension is
-- unrotated.  That first dimension is then replaced by the flat slice's
-- dimensions.  Their strides, and the slice offset, are scaled by the
-- replaced dimension's stride, and 'setLMADPermutation' is applied with
-- @[0 ..]@.
--
-- Otherwise a new LMAD is pushed on top.  It treats the outermost LMAD's
-- shape as laid out with suffix-product strides: the flat slice's offset
-- and strides are scaled by the product of all but the first dimension,
-- and those inner dimensions are appended after the slice's dimensions.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is)
  | hasContiguousPerm ixfun,
    ldRotate dim == 0 =
      IxFun (inPlace :| lmads) oshp cg
  where
    outer_stride = ldStride dim
    inPlace =
      setLMADPermutation [0 ..] $
        LMAD
          (offset + new_offset * outer_stride)
          (map (flatDim outer_stride) is <> dims)
    -- Each new dimension temporarily gets permutation entry 0; the
    -- setLMADPermutation call above overwrites the permutation.
    flatDim s0 (FlatDimIndex n s) =
      LMADDim (s0 * s) 0 n 0 (if s0 * s == 1 then Inc else Unknown)
flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) =
  IxFun (pushed :| lmad : lmads) oshp cg
  where
    pushed = LMAD (new_offset * base_stride) (new_dims <> tail_dims)
    inner_shape = tail $ lmadShape lmad
    base_stride = product inner_shape
    inner_strides = tail $ scanr (*) 1 inner_shape
    new_shape = flatSliceDims s
    new_strides = map (* base_stride) $ flatSliceStrides s
    -- NOTE(review): these dimensions are always marked 'Inc', even if
    -- 'flatSliceStrides' yields a negative stride; the in-place case is
    -- more conservative and uses 'Unknown' unless the stride is 1.  Confirm
    -- whether flat-slice strides are guaranteed non-negative here.
    new_dims = zipWith5 LMADDim new_strides (repeat 0) new_shape [0 ..] (repeat Inc)
    tail_dims =
      zipWith5 LMADDim inner_strides (repeat 0) inner_shape [length new_shape ..] (repeat Inc)
-- | Try to express a shape change as a coercion of the outermost LMAD, in
-- which only dimension sizes change.
--
-- The shape change is split with 'splitCoercions'.  The result is
-- 'Nothing' if that fails, or if the non-coercion part has more than one
-- element.  With exactly one such element, exactly one dimension must lie
-- between the head and tail coercions.  The new sizes are written into
-- the dimensions in 'permuteFwd' order, and the result is mapped back
-- with 'permuteInv'.
reshapeCoercion ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let perm = lmadPermutation lmad
      permuted = permuteFwd perm dims
      hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      mid_dims = take (length dims - num_coercions) (drop hd_len permuted)
  guard $ case length reshapes of
    0 -> True
    1 -> length mid_dims == 1
    _ -> False
  let resized = zipWith (\ld n -> ld {ldShape = n}) permuted (newDims newshape)
  pure $ IxFun (LMAD off (permuteInv perm resized) :| lmads) oldbase cg
reshapeOneLMAD ::
(Eq num, IntegralExp num) =>
IxFun num ->
ShapeChange num ->
Maybe (IxFun num)
reshapeOneLMAD :: IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD ixfun :: IxFun num
ixfun@(IxFun (lmad :: LMAD num
lmad@(LMAD num
off [LMADDim num]
dims) :| [LMAD num]
lmads) Shape num
oldbase Bool
cg) ShapeChange num
newshape = do
let perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
(ShapeChange num
head_coercions, ShapeChange num
reshapes, ShapeChange num
tail_coercions) <- ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
forall num.
(Eq num, IntegralExp num) =>
ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions ShapeChange num
newshape
let hd_len :: Int
hd_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
head_coercions
num_coercions :: Int
num_coercions = Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
tail_coercions
dims_perm :: [LMADDim num]
dims_perm = Permutation -> [LMADDim num] -> [LMADDim num]
forall a. Permutation -> [a] -> [a]
permuteFwd Permutation
perm [LMADDim num]
dims
mid_dims :: [LMADDim num]
mid_dims = Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
take ([LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
num_coercions) ([LMADDim num] -> [LMADDim num]) -> [LMADDim num] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$ Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
drop Int
hd_len [LMADDim num]
dims_perm
mon :: Monotonicity
mon = Bool -> IxFun num -> Monotonicity
forall num.
(Eq num, IntegralExp num) =>
Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots Bool
True IxFun num
ixfun
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$
(LMADDim num -> Bool) -> [LMADDim num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(LMADDim num
s num
r num
_ Int
_ Monotonicity
_) -> num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
/= num
0 Bool -> Bool -> Bool
&& num
r num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0) [LMADDim num]
mid_dims
Bool -> Bool -> Bool
&&
Int -> Permutation -> Bool
forall a. (Eq a, Num a, Enum a) => a -> [a] -> Bool
consecutive Int
hd_len ((LMADDim num -> Int) -> [LMADDim num] -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm [LMADDim num]
mid_dims)
Bool -> Bool -> Bool
&&
IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun
Bool -> Bool -> Bool
&& Bool
cg
Bool -> Bool -> Bool
&& (Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc Bool -> Bool -> Bool
|| Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Dec)
let rsh_len :: Int
rsh_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
reshapes
diff :: Int
diff = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
iota_shape :: Permutation
iota_shape = [Int
0 .. ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
perm' :: Permutation
perm' =
(Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map
( \Int
i ->
let ind :: Int
ind =
if Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
then Int
i
else Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff
in if (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len) Bool -> Bool -> Bool
&& (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len)
then Int
i
else
let p :: Int
p = LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm ([LMADDim num]
dims [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
ind)
in if Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
then Int
p
else Int
p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
diff
)
Permutation
iota_shape
([(Int, (num, num))]
support_inds, [(Int, num)]
repeat_inds) =
(([(Int, (num, num))], [(Int, num)])
-> (Int, DimChange num, Int)
-> ([(Int, (num, num))], [(Int, num)]))
-> ([(Int, (num, num))], [(Int, num)])
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
( \([(Int, (num, num))]
sup, [(Int, num)]
rpt) (Int
i, DimChange num
shpdim, Int
ip) ->
case (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len, Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len, DimChange num
shpdim) of
(Bool
True, Bool
_, DimCoercion num
n) ->
case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
i of
(LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
(LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool
_, Bool
True, DimCoercion num
n) ->
case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff) of
(LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
(LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool
False, Bool
False, DimChange num
_) ->
((Int
ip, (num
0, DimChange num -> num
forall d. DimChange d -> d
newDim DimChange num
shpdim)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool, Bool, DimChange num)
_ -> String -> ([(Int, (num, num))], [(Int, num)])
forall a. HasCallStack => String -> a
error String
"reshape: reached impossible case"
)
([], [])
([(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)]))
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall a b. (a -> b) -> a -> b
$ [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a. [a] -> [a]
reverse ([(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)])
-> [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a b. (a -> b) -> a -> b
$ Permutation
-> ShapeChange num -> Permutation -> [(Int, DimChange num, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Permutation
iota_shape ShapeChange num
newshape Permutation
perm'
(Permutation
sup_inds, [(num, num)]
support) = [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Int, (num, num))] -> (Permutation, [(num, num)]))
-> [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. (a -> b) -> a -> b
$ ((Int, (num, num)) -> (Int, (num, num)) -> Ordering)
-> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, (num, num)) -> Int)
-> (Int, (num, num))
-> (Int, (num, num))
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, (num, num)) -> Int
forall a b. (a, b) -> a
fst) [(Int, (num, num))]
support_inds
(Permutation
rpt_inds, Shape num
repeats) = [(Int, num)] -> (Permutation, Shape num)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, num)]
repeat_inds
LMAD num
off' [LMADDim num]
dims_sup = Monotonicity -> num -> [(num, num)] -> LMAD num
forall num.
IntegralExp num =>
Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota Monotonicity
mon num
off [(num, num)]
support
repeats' :: [LMADDim num]
repeats' = (num -> LMADDim num) -> Shape num -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (\num
n -> num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
0 num
0 num
n Int
0 Monotonicity
Unknown) Shape num
repeats
dims' :: [LMADDim num]
dims' =
((Int, LMADDim num) -> LMADDim num)
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (Int, LMADDim num) -> LMADDim num
forall a b. (a, b) -> b
snd ([(Int, LMADDim num)] -> [LMADDim num])
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$
((Int, LMADDim num) -> (Int, LMADDim num) -> Ordering)
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, LMADDim num) -> Int)
-> (Int, LMADDim num)
-> (Int, LMADDim num)
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, LMADDim num) -> Int
forall a b. (a, b) -> a
fst) ([(Int, LMADDim num)] -> [(Int, LMADDim num)])
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a b. (a -> b) -> a -> b
$
Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
sup_inds [LMADDim num]
dims_sup [(Int, LMADDim num)]
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. [a] -> [a] -> [a]
++ Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
rpt_inds [LMADDim num]
repeats'
lmad' :: LMAD num
lmad' = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off' [LMADDim num]
dims'
IxFun num -> Maybe (IxFun num)
forall (m :: * -> *) a. Monad m => a -> m a
return (IxFun num -> Maybe (IxFun num)) -> IxFun num -> Maybe (IxFun num)
forall a b. (a -> b) -> a -> b
$ NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm' LMAD num
lmad' LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads) Shape num
oldbase Bool
cg
where
consecutive :: a -> [a] -> Bool
consecutive a
_ [] = Bool
True
consecutive a
i [a
p] = a
i a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
p
consecutive a
i [a]
ps = [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (a -> a -> Bool) -> [a] -> [a] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith a -> a -> Bool
forall a. Eq a => a -> a -> Bool
(==) [a]
ps [a
i, a
i a -> a -> a
forall a. Num a => a -> a -> a
+ a
1 ..]
splitCoercions ::
(Eq num, IntegralExp num) =>
ShapeChange num ->
Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions :: ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions ShapeChange num
newshape' = do
let (ShapeChange num
head_coercions, ShapeChange num
newshape'') = (DimChange num -> Bool)
-> ShapeChange num -> (ShapeChange num, ShapeChange num)
forall a. (a -> Bool) -> [a] -> ([a], [a])
span DimChange num -> Bool
forall d. DimChange d -> Bool
isCoercion ShapeChange num
newshape'
(ShapeChange num
reshapes, ShapeChange num
tail_coercions) = (DimChange num -> Bool)
-> ShapeChange num -> (ShapeChange num, ShapeChange num)
forall a. (a -> Bool) -> [a] -> ([a], [a])
break DimChange num -> Bool
forall d. DimChange d -> Bool
isCoercion ShapeChange num
newshape''
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard ((DimChange num -> Bool) -> ShapeChange num -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all DimChange num -> Bool
forall d. DimChange d -> Bool
isCoercion ShapeChange num
tail_coercions)
(ShapeChange num, ShapeChange num, ShapeChange num)
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
forall (m :: * -> *) a. Monad m => a -> m a
return (ShapeChange num
head_coercions, ShapeChange num
reshapes, ShapeChange num
tail_coercions)
where
isCoercion :: DimChange d -> Bool
isCoercion DimCoercion {} = Bool
True
isCoercion DimChange d
_ = Bool
False
reshape ::
(Eq num, IntegralExp num) =>
IxFun num ->
ShapeChange num ->
IxFun num
reshape :: IxFun num -> ShapeChange num -> IxFun num
reshape IxFun num
ixfun ShapeChange num
new_shape
| Just IxFun num
ixfun' <- IxFun num -> ShapeChange num -> Maybe (IxFun num)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeCoercion IxFun num
ixfun ShapeChange num
new_shape = IxFun num
ixfun'
| Just IxFun num
ixfun' <- IxFun num -> ShapeChange num -> Maybe (IxFun num)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD IxFun num
ixfun ShapeChange num
new_shape = IxFun num
ixfun'
reshape (IxFun (LMAD num
lmad0 :| [LMAD num]
lmad0s) Shape num
oshp Bool
cg) ShapeChange num
new_shape =
case Shape num -> IxFun num
forall num. IntegralExp num => Shape num -> IxFun num
iota (ShapeChange num -> Shape num
forall d. ShapeChange d -> [d]
newDims ShapeChange num
new_shape) of
IxFun (LMAD num
lmad :| []) Shape num
_ Bool
_ -> NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (LMAD num
lmad LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| LMAD num
lmad0 LMAD num -> [LMAD num] -> [LMAD num]
forall a. a -> [a] -> [a]
: [LMAD num]
lmad0s) Shape num
oshp Bool
cg
IxFun num
_ -> String -> IxFun num
forall a. HasCallStack => String -> a
error String
"reshape: reached impossible case"
rank ::
IntegralExp num =>
IxFun num ->
Int
rank :: IxFun num -> Int
rank (IxFun (LMAD num
_ [LMADDim num]
sss :| [LMAD num]
_) Shape num
_ Bool
_) = [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
sss
rebaseNice ::
(Eq num, IntegralExp num) =>
IxFun num ->
IxFun num ->
Maybe (IxFun num)
rebaseNice :: IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice
new_base :: IxFun num
new_base@(IxFun (LMAD num
lmad_base :| [LMAD num]
lmads_base) Shape num
_ Bool
cg_base)
ixfun :: IxFun num
ixfun@(IxFun NonEmpty (LMAD num)
lmads Shape num
shp Bool
cg) = do
let (LMAD num
lmad :| [LMAD num]
lmads') = NonEmpty (LMAD num) -> NonEmpty (LMAD num)
forall a. NonEmpty a -> NonEmpty a
NE.reverse NonEmpty (LMAD num)
lmads
dims :: [LMADDim num]
dims = LMAD num -> [LMADDim num]
forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad
perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
perm_base :: Permutation
perm_base = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad_base
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$
IxFun num -> Shape num
forall a. IxFun a -> [a]
base IxFun num
ixfun Shape num -> Shape num -> Bool
forall a. Eq a => a -> a -> Bool
== IxFun num -> Shape num
forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
shape IxFun num
new_base
Bool -> Bool -> Bool
&& Bool
cg
Bool -> Bool -> Bool
&& (LMADDim num -> Bool) -> [LMADDim num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all ((Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
/= Monotonicity
Unknown) (Monotonicity -> Bool)
-> (LMADDim num -> Monotonicity) -> LMADDim num -> Bool
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. LMADDim num -> Monotonicity
forall num. LMADDim num -> Monotonicity
ldMon) [LMADDim num]
dims
Bool -> Bool -> Bool
&& (IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun Bool -> Bool -> Bool
|| IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
new_base)
Bool -> Bool -> Bool
&& (Permutation -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Permutation -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm_base Bool -> Bool -> Bool
|| IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun)
Bool -> Bool -> Bool
&& [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and
( (num -> LMADDim num -> Bool -> Bool)
-> Shape num -> [LMADDim num] -> [Bool] -> [Bool]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3
(\num
sn LMADDim num
ld Bool
inner -> num
sn num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== LMADDim num -> num
forall num. LMADDim num -> num
ldShape LMADDim num
ld Bool -> Bool -> Bool
|| (Bool
inner Bool -> Bool -> Bool
&& LMADDim num -> num
forall num. LMADDim num -> num
ldStride LMADDim num
ld num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
1))
Shape num
shp
[LMADDim num]
dims
(Int -> Bool -> [Bool]
forall a. Int -> a -> [a]
replicate ([LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1) Bool
False [Bool] -> [Bool] -> [Bool]
forall a. [a] -> [a] -> [a]
++ [Bool
True])
)
let perm_base' :: Permutation
perm_base' =
if IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun
then Permutation
perm_base
else (Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map (Permutation
perm Permutation -> Int -> Int
forall a. [a] -> Int -> a
!!) Permutation
perm_base
lmad_base' :: LMAD num
lmad_base' = Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm_base' LMAD num
lmad_base
dims_base :: [LMADDim num]
dims_base = LMAD num -> [LMADDim num]
forall num. LMAD num -> [LMADDim num]
lmadDims LMAD num
lmad_base'
n_fewer_dims :: Int
n_fewer_dims = [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims_base Int -> Int -> Int
forall a. Num a => a -> a -> a
- [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
([LMADDim num]
dims_base', Shape num
offs_contrib) =
[(LMADDim num, num)] -> ([LMADDim num], Shape num)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(LMADDim num, num)] -> ([LMADDim num], Shape num))
-> [(LMADDim num, num)] -> ([LMADDim num], Shape num)
forall a b. (a -> b) -> a -> b
$
(LMADDim num -> LMADDim num -> (LMADDim num, num))
-> [LMADDim num] -> [LMADDim num] -> [(LMADDim num, num)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
( \(LMADDim num
s1 num
r1 num
n1 Int
p1 Monotonicity
_) (LMADDim num
_ num
r2 num
_ Int
_ Monotonicity
m2) ->
let (num
s', num
off')
| Monotonicity
m2 Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc = (num
s1, num
0)
| Bool
otherwise = (num
s1 num -> num -> num
forall a. Num a => a -> a -> a
* (-num
1), num
s1 num -> num -> num
forall a. Num a => a -> a -> a
* (num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
1))
r' :: num
r'
| Monotonicity
m2 Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc = if num
r2 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 then num
r1 else num
r1 num -> num -> num
forall a. Num a => a -> a -> a
+ num
r2
| num
r1 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
r2
| num
r2 num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
r1
| Bool
otherwise = num
n1 num -> num -> num
forall a. Num a => a -> a -> a
- num
r1 num -> num -> num
forall a. Num a => a -> a -> a
+ num
r2
in (num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
s' num
r' num
n1 (Int
p1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
n_fewer_dims) Monotonicity
Inc, num
off')
)
(Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
drop Int
n_fewer_dims [LMADDim num]
dims_base)
[LMADDim num]
dims
off_base :: num
off_base = LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad_base' num -> num -> num
forall a. Num a => a -> a -> a
+ Shape num -> num
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum Shape num
offs_contrib
lmad_base'' :: LMAD num
lmad_base''
| LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off_base [LMADDim num]
dims_base'
| Bool
otherwise =
Shape num -> LMAD num -> LMAD num
forall num. Shape num -> LMAD num -> LMAD num
setLMADShape
(LMAD num -> Shape num
forall num. (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape LMAD num
lmad)
( num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD
(num
off_base num -> num -> num
forall a. Num a => a -> a -> a
+ LMADDim num -> num
forall num. LMADDim num -> num
ldStride ([LMADDim num] -> LMADDim num
forall a. [a] -> a
last [LMADDim num]
dims_base) num -> num -> num
forall a. Num a => a -> a -> a
* LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad)
[LMADDim num]
dims_base'
)
new_base' :: IxFun num
new_base' = NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (LMAD num
lmad_base'' LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads_base) Shape num
shp Bool
cg_base
IxFun NonEmpty (LMAD num)
lmads_base' Shape num
_ Bool
_ = IxFun num
new_base'
lmads'' :: NonEmpty (LMAD num)
lmads'' = [LMAD num]
lmads' [LMAD num] -> NonEmpty (LMAD num) -> NonEmpty (LMAD num)
forall a. [a] -> NonEmpty a -> NonEmpty a
++@ NonEmpty (LMAD num)
lmads_base'
IxFun num -> Maybe (IxFun num)
forall (m :: * -> *) a. Monad m => a -> m a
return (IxFun num -> Maybe (IxFun num)) -> IxFun num -> Maybe (IxFun num)
forall a b. (a -> b) -> a -> b
$ NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun NonEmpty (LMAD num)
lmads'' Shape num
shp (Bool
cg Bool -> Bool -> Bool
&& Bool
cg_base)
rebase ::
(Eq num, IntegralExp num) =>
IxFun num ->
IxFun num ->
IxFun num
rebase :: IxFun num -> IxFun num -> IxFun num
rebase new_base :: IxFun num
new_base@(IxFun NonEmpty (LMAD num)
lmads_base Shape num
shp_base Bool
cg_base) ixfun :: IxFun num
ixfun@(IxFun NonEmpty (LMAD num)
lmads Shape num
shp Bool
cg)
| Just IxFun num
ixfun' <- IxFun num -> IxFun num -> Maybe (IxFun num)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> IxFun num -> Maybe (IxFun num)
rebaseNice IxFun num
new_base IxFun num
ixfun = IxFun num
ixfun'
| Bool
otherwise =
let (NonEmpty (LMAD num)
lmads_base', Shape num
shp_base') =
if IxFun num -> Shape num
forall a. IxFun a -> [a]
base IxFun num
ixfun Shape num -> Shape num -> Bool
forall a. Eq a => a -> a -> Bool
== IxFun num -> Shape num
forall num. (Eq num, IntegralExp num) => IxFun num -> Shape num
shape IxFun num
new_base
then (NonEmpty (LMAD num)
lmads_base, Shape num
shp_base)
else
let IxFun NonEmpty (LMAD num)
lmads' Shape num
shp_base'' Bool
_ = IxFun num -> ShapeChange num -> IxFun num
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> IxFun num
reshape IxFun num
new_base (ShapeChange num -> IxFun num) -> ShapeChange num -> IxFun num
forall a b. (a -> b) -> a -> b
$ (num -> DimChange num) -> Shape num -> ShapeChange num
forall a b. (a -> b) -> [a] -> [b]
map num -> DimChange num
forall d. d -> DimChange d
DimCoercion Shape num
shp
in (NonEmpty (LMAD num)
lmads', Shape num
shp_base'')
in NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (NonEmpty (LMAD num)
lmads NonEmpty (LMAD num) -> NonEmpty (LMAD num) -> NonEmpty (LMAD num)
forall a. NonEmpty a -> NonEmpty a -> NonEmpty a
@++@ NonEmpty (LMAD num)
lmads_base') Shape num
shp_base' (Bool
cg Bool -> Bool -> Bool
&& Bool
cg_base)
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity :: IxFun num -> Monotonicity
ixfunMonotonicity = Bool -> IxFun num -> Monotonicity
forall num.
(Eq num, IntegralExp num) =>
Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots Bool
False
linearWithOffset ::
(Eq num, IntegralExp num) =>
IxFun num ->
num ->
Maybe num
linearWithOffset :: IxFun num -> num -> Maybe num
linearWithOffset ixfun :: IxFun num
ixfun@(IxFun (LMAD num
lmad :| []) Shape num
_ Bool
cg) num
elem_size
| IxFun num -> Bool
forall a. IxFun a -> Bool
hasContiguousPerm IxFun num
ixfun Bool -> Bool -> Bool
&& Bool
cg Bool -> Bool -> Bool
&& IxFun num -> Monotonicity
forall num. (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity IxFun num
ixfun Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc =
num -> Maybe num
forall a. a -> Maybe a
Just (num -> Maybe num) -> num -> Maybe num
forall a b. (a -> b) -> a -> b
$ LMAD num -> num
forall num. LMAD num -> num
lmadOffset LMAD num
lmad num -> num -> num
forall a. Num a => a -> a -> a
* num
elem_size
linearWithOffset IxFun num
_ num
_ = Maybe num
forall a. Maybe a
Nothing
rearrangeWithOffset ::
(Eq num, IntegralExp num) =>
IxFun num ->
num ->
Maybe (num, [(Int, num)])
rearrangeWithOffset :: IxFun num -> num -> Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (LMAD num
lmad :| []) Shape num
oshp Bool
cg) num
elem_size = do
let perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
perm_contig :: Permutation
perm_contig = [Int
0 .. Permutation -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Permutation
perm Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
num
offset <-
IxFun num -> num -> Maybe num
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
linearWithOffset
(NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm_contig LMAD num
lmad LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| []) Shape num
oshp Bool
cg)
num
elem_size
(num, [(Int, num)]) -> Maybe (num, [(Int, num)])
forall (m :: * -> *) a. Monad m => a -> m a
return (num
offset, Permutation -> Shape num -> [(Int, num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
perm (Permutation -> Shape num -> Shape num
forall a. Permutation -> [a] -> [a]
permuteFwd Permutation
perm (LMAD num -> Shape num
forall num. (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase LMAD num
lmad)))
rearrangeWithOffset IxFun num
_ num
_ = Maybe (num, [(Int, num)])
forall a. Maybe a
Nothing
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear :: IxFun num -> Bool
isLinear = (Maybe num -> Maybe num -> Bool
forall a. Eq a => a -> a -> Bool
== num -> Maybe num
forall a. a -> Maybe a
Just num
0) (Maybe num -> Bool)
-> (IxFun num -> Maybe num) -> IxFun num -> Bool
forall k (cat :: k -> k -> *) (b :: k) (c :: k) (a :: k).
Category cat =>
cat b c -> cat a b -> cat a c
. (IxFun num -> num -> Maybe num) -> num -> IxFun num -> Maybe num
forall a b c. (a -> b -> c) -> b -> a -> c
flip IxFun num -> num -> Maybe num
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
linearWithOffset num
1
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd Permutation
ps [a]
elems = (Int -> a) -> Permutation -> [a]
forall a b. (a -> b) -> [a] -> [b]
map ([a]
elems [a] -> Int -> a
forall a. [a] -> Int -> a
!!) Permutation
ps
permuteInv :: Permutation -> [a] -> [a]
permuteInv :: Permutation -> [a] -> [a]
permuteInv Permutation
ps [a]
elems = ((Int, a) -> a) -> [(Int, a)] -> [a]
forall a b. (a -> b) -> [a] -> [b]
map (Int, a) -> a
forall a b. (a, b) -> b
snd ([(Int, a)] -> [a]) -> [(Int, a)] -> [a]
forall a b. (a -> b) -> a -> b
$ ((Int, a) -> (Int, a) -> Ordering) -> [(Int, a)] -> [(Int, a)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, a) -> Int) -> (Int, a) -> (Int, a) -> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, a) -> Int
forall a b. (a, b) -> a
fst) ([(Int, a)] -> [(Int, a)]) -> [(Int, a)] -> [(Int, a)]
forall a b. (a -> b) -> a -> b
$ Permutation -> [a] -> [(Int, a)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
ps [a]
elems
flatOneDim ::
(Eq num, IntegralExp num) =>
(num, num, num) ->
num ->
num
flatOneDim :: (num, num, num) -> num -> num
flatOneDim (num
s, num
r, num
n) num
i
| num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
0
| num
r num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0 = num
i num -> num -> num
forall a. Num a => a -> a -> a
* num
s
| Bool
otherwise = ((num
i num -> num -> num
forall a. Num a => a -> a -> a
+ num
r) num -> num -> num
forall e. IntegralExp e => e -> e -> e
`mod` num
n) num -> num -> num
forall a. Num a => a -> a -> a
* num
s
makeRotIota ::
IntegralExp num =>
Monotonicity ->
num ->
[(num, num)] ->
LMAD num
makeRotIota :: Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota Monotonicity
mon num
off [(num, num)]
support
| Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc Bool -> Bool -> Bool
|| Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Dec =
let rk :: Int
rk = [(num, num)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(num, num)]
support
([num]
rs, [num]
ns) = [(num, num)] -> ([num], [num])
forall a b. [(a, b)] -> ([a], [b])
unzip [(num, num)]
support
ss0 :: [num]
ss0 = [num] -> [num]
forall a. [a] -> [a]
reverse ([num] -> [num]) -> [num] -> [num]
forall a b. (a -> b) -> a -> b
$ Int -> [num] -> [num]
forall a. Int -> [a] -> [a]
take Int
rk ([num] -> [num]) -> [num] -> [num]
forall a b. (a -> b) -> a -> b
$ (num -> num -> num) -> num -> [num] -> [num]
forall b a. (b -> a -> b) -> b -> [a] -> [b]
scanl num -> num -> num
forall a. Num a => a -> a -> a
(*) num
1 ([num] -> [num]) -> [num] -> [num]
forall a b. (a -> b) -> a -> b
$ [num] -> [num]
forall a. [a] -> [a]
reverse [num]
ns
ss :: [num]
ss =
if Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc
then [num]
ss0
else (num -> num) -> [num] -> [num]
forall a b. (a -> b) -> [a] -> [b]
map (num -> num -> num
forall a. Num a => a -> a -> a
* (-num
1)) [num]
ss0
ps :: Permutation
ps = (Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map Int -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral [Int
0 .. Int
rk Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
fi :: [Monotonicity]
fi = Int -> Monotonicity -> [Monotonicity]
forall a. Int -> a -> [a]
replicate Int
rk Monotonicity
mon
in num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off ([LMADDim num] -> LMAD num) -> [LMADDim num] -> LMAD num
forall a b. (a -> b) -> a -> b
$ (num -> num -> num -> Int -> Monotonicity -> LMADDim num)
-> [num]
-> [num]
-> [num]
-> Permutation
-> [Monotonicity]
-> [LMADDim num]
forall a b c d e f.
(a -> b -> c -> d -> e -> f)
-> [a] -> [b] -> [c] -> [d] -> [e] -> [f]
zipWith5 num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim [num]
ss [num]
rs [num]
ns Permutation
ps [Monotonicity]
fi
| Bool
otherwise = String -> LMAD num
forall a. HasCallStack => String -> a
error String
"makeRotIota: requires Inc or Dec"
-- | Classify the monotonicity shared by all LMADs of an index function.
--
-- A single LMAD is classified as 'Inc' (resp. 'Dec') when every one of its
-- dimensions either has a zero stride, or carries that monotonicity and --
-- unless @ignore_rots@ is set -- has a zero rotation.  Any other LMAD is
-- 'Unknown'.
--
-- The result is the classification of the first LMAD when every remaining LMAD
-- has the same classification, and 'Unknown' otherwise.
ixfunMonotonicityRots ::
  (Eq num, IntegralExp num) =>
  Bool ->
  IxFun num ->
  Monotonicity
ixfunMonotonicityRots ignore_rots (IxFun (lmad :| lmads) _ _)
  | all ((== firstMon) . lmadMon) lmads = firstMon
  | otherwise = Unknown
  where
    firstMon = lmadMon lmad

    -- Classification of one LMAD; 'Inc' is tried before 'Dec'.
    lmadMon ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Monotonicity
    lmadMon (LMAD _ dims)
      | all (dimHasMon Inc) dims = Inc
      | all (dimHasMon Dec) dims = Dec
      | otherwise = Unknown

    -- A zero-stride dimension is compatible with any monotonicity; otherwise
    -- the dimension's own monotonicity must match, and its rotation must be
    -- zero unless rotations are being ignored.
    dimHasMon ::
      (Eq num, IntegralExp num) =>
      Monotonicity ->
      LMADDim num ->
      Bool
    dimHasMon wanted (LMADDim stride rot _ _ dimMon)
      | stride == 0 = True
      | otherwise = (ignore_rots || rot == 0) && wanted == dimMon
-- | Compute the least general generalization of two index functions.
--
-- Only index functions consisting of exactly one LMAD are handled.  Both
-- arguments must agree on the rank of their base shape, on the contiguity
-- flag, on the permutation of their LMAD dimensions and on the per-dimension
-- monotonicity; otherwise (and for multi-LMAD index functions) the result is
-- 'Nothing'.
--
-- The LMAD dimension shapes, the base shapes, the strides, the rotations and
-- finally the offsets are generalized element-wise with
-- 'PEG.leastGeneralGeneralization', in exactly that order, threading its pair
-- list from one call to the next (starting from the empty list).  The final
-- pair list is returned alongside the generalized index function.
leastGeneralGeneralization ::
  Eq v =>
  IxFun (PrimExp v) ->
  IxFun (PrimExp v) ->
  Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
leastGeneralGeneralization
  (IxFun (lmad1 :| []) oshp1 ctg1)
  (IxFun (lmad2 :| []) oshp2 ctg2) = do
    guard $
      length oshp1 == length oshp2
        && ctg1 == ctg2
        && dimField ldPerm lmad1 == dimField ldPerm lmad2
        && dimField ldMon lmad1 == dimField ldMon lmad2
    let (dshp, m1) =
          generalize [] (dimField ldShape lmad1) (dimField ldShape lmad2)
        (oshp, m2) = generalize m1 oshp1 oshp2
        (dstd, m3) =
          generalize m2 (dimField ldStride lmad1) (dimField ldStride lmad2)
        (drot, m4) =
          generalize m3 (dimField ldRotate lmad1) (dimField ldRotate lmad2)
        (offt, m5) =
          PEG.leastGeneralGeneralization m4 (lmadOffset lmad1) (lmadOffset lmad2)
        -- Permutation and monotonicity were checked equal by the guard, so
        -- the first LMAD's values are used for the result.
        lmad_dims =
          zipWith5
            LMADDim
            dstd
            drot
            dshp
            (lmadPermutation lmad1)
            (dimField ldMon lmad1)
    pure (IxFun (LMAD offt lmad_dims :| []) oshp ctg1, m5)
    where
      -- Project one field out of every dimension of an LMAD.
      dimField f lmad = map f (lmadDims lmad)

      -- Generalize two lists element-wise (truncating to the shorter one, as
      -- 'zip' would), threading the pair list from left to right.  This is a
      -- pure, linear replacement for a never-failing monadic fold that
      -- appended to its accumulator with (++).
      generalize ::
        Eq v =>
        [(PrimExp v, PrimExp v)] ->
        [PrimExp v] ->
        [PrimExp v] ->
        ([PrimExp (Ext v)], [(PrimExp v, PrimExp v)])
      generalize m (pe1 : pes1) (pe2 : pes2) =
        let (e, m') = PEG.leastGeneralGeneralization m pe1 pe2
            (es, m'') = generalize m' pes1 pes2
         in (e : es, m'')
      generalize m _ _ = ([], m)
leastGeneralGeneralization _ _ = Nothing
-- | True iff the list is exactly @[0, 1, .., n-1]@ (the identity permutation).
-- The empty list counts as sequential.
isSequential :: [Int] -> Bool
isSequential xs = and (zipWith (==) xs [0 ..])
-- | Replace an expression by an existential leaf.
--
-- The expression is appended to the state list.  The result is a
-- @'LeafExp' ('Ext' i)@ with the expression's primitive type, where @i@ is the
-- length of the state before the append, i.e. the position of the expression
-- in the updated list.
existentializeExp :: TPrimExp t v -> State [TPrimExp t v] (TPrimExp t (Ext v))
existentializeExp e = do
  idx <- gets length
  modify (<> [e])
  let leafType = primExpType (untyped e)
  pure $ TPrimExp $ LeafExp (Ext idx) leafType
-- | Existentialize a simple index function.
--
-- Applies only to an index function made of a single LMAD whose contiguity flag
-- is 'True', where additionally:
--
--   * every LMAD dimension has a zero rotation,
--   * the LMAD shape has the same rank as the base shape, and
--   * the dimension permutation is the identity.
--
-- In that case, the base shape, then the offset, then each dimension's stride
-- followed by its shape are replaced, in that order, by existentials (see
-- 'existentializeExp'), with the replaced expressions appended to the state.
-- Rotations are kept, merely lifted with 'Free'.  Any other index function
-- yields 'Nothing' and leaves the state untouched.
existentialize ::
  (IntExp t, Eq v, Pretty v) =>
  IxFun (TPrimExp t v) ->
  State [TPrimExp t v] (Maybe (IxFun (TPrimExp t (Ext v))))
existentialize (IxFun (lmad :| []) oshp True)
  | all ((== 0) . ldRotate) dims,
    length (lmadShape lmad) == length oshp,
    isSequential (map ldPerm dims) = do
      oshp' <- mapM existentializeExp oshp
      offset' <- existentializeExp (lmadOffset lmad)
      dims' <- mapM existentializeDim dims
      pure $ Just $ IxFun (LMAD offset' dims' :| []) oshp' True
  where
    dims = lmadDims lmad

    -- Stride first, then shape: the order determines the existential indices.
    existentializeDim ::
      LMADDim (TPrimExp t v) ->
      State [TPrimExp t v] (LMADDim (TPrimExp t (Ext v)))
    existentializeDim (LMADDim str rot shp perm mon) = do
      str' <- existentializeExp str
      shp' <- existentializeExp shp
      pure $ LMADDim str' (fmap Free rot) shp' perm mon
existentialize _ = pure Nothing
-- | A cheap, purely structural similarity check between two index functions.
--
-- Holds when both have base shapes of the same rank and the same number of
-- LMADs, and each pair of corresponding LMADs has the same number of
-- dimensions and the same dimension permutation.  Offsets, strides, shapes,
-- rotations and monotonicities are not compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  sameLength (base ixf1) (base ixf2)
    && NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith similar lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    sameLength xs ys = length xs == length ys

    similar lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in sameLength dims1 dims2 && map ldPerm dims1 == map ldPerm dims2
-- | Check whether two index functions are equivalent.
--
-- Both must have the same number of LMADs, and each pair of corresponding
-- LMADs must agree on the number of dimensions, the dimension permutation, the
-- offset, the strides and the rotations.  Dimension shapes, monotonicities and
-- the base shape are not compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLMAD lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    -- Compare one projected field across two lists of dimensions.
    sameOn f xs ys = map f xs == map f ys

    sameLMAD lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in length dims1 == length dims2
            && sameOn ldPerm dims1 dims2
            && lmadOffset lmad1 == lmadOffset lmad2
            && sameOn ldStride dims1 dims2
            && sameOn ldRotate dims1 dims2