{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.IR.Mem.IxFun
( IxFun (..),
Shape,
LMAD (..),
LMADDim (..),
Monotonicity (..),
index,
mkExistential,
iota,
iotaOffset,
permute,
reshape,
coerce,
slice,
flatSlice,
rebase,
shape,
rank,
linearWithOffset,
rearrangeWithOffset,
isDirect,
isLinear,
substituteInIxFun,
existentialize,
closeEnough,
equivalent,
dynamicEqualsLMAD,
)
where
import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on, (&))
import Data.List (sort, sortBy, zip4, zipWith4)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe (isJust)
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import Futhark.IR.Prop
import Futhark.IR.Syntax
( DimIndex (..),
FlatDimIndex (..),
FlatSlice (..),
Slice (..),
dimFix,
flatSliceDims,
flatSliceStrides,
unitSlice,
)
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (id, mod, (.))
-- | The shape of an index space: a list with one size per dimension.
type Shape num = [num]
-- | A multi-dimensional index: a list with one component per dimension.
type Indices num = [num]
-- | A permutation of dimensions, given as a list of 'Int' positions.
type Permutation = [Int]
-- | Direction tag stored in every 'LMADDim' (field 'ldMon').
-- 'invertMonotonicity' swaps 'Inc' and 'Dec' and leaves 'Unknown'
-- untouched.
data Monotonicity
  = -- | Increasing.
    Inc
  | -- | Decreasing.
    Dec
  | -- | Direction not known.
    Unknown
  deriving (Eq, Show)
-- | The description of a single dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Stride of this dimension; 'indexLMAD' combines it with the
    -- corresponding index component.
    ldStride :: num,
    -- | Size of this dimension (what 'lmadShapeBase' collects).
    ldShape :: num,
    -- | Entry of this dimension in the LMAD's permutation (what
    -- 'lmadPermutation' collects).
    ldPerm :: Int,
    -- | Monotonicity tag of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Eq, Show)
-- | An offset plus a list of per-dimension descriptors.  'indexLMAD'
-- turns a list of indices into @offset + sum of per-dimension terms@.
data LMAD num = LMAD
  { -- | Constant term added to every computed index.
    lmadOffset :: num,
    -- | The dimensions, in the LMAD's own (unpermuted) order.
    lmadDims :: [LMADDim num]
  }
  deriving (Eq, Show)
-- | An index function: a non-empty chain of 'LMAD's together with a
-- base shape and a contiguity flag.  'index' applies the LMADs from
-- the head of the chain onwards.
data IxFun num = IxFun
  { -- | The chain of LMADs; the head is the one applied first.
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The base shape; 'isDirect' compares it against the
    -- dimensions of the LMAD.
    base :: Shape num,
    -- | Contiguity flag; 'isDirect' requires it to be 'True'.
    contiguous :: Bool
  }
  deriving (Eq, Show)
-- | Pretty-print a monotonicity via its 'Show' representation.
instance Pretty Monotonicity where
  pretty mon = pretty (show mon)
-- | Pretty-print an LMAD as a brace-delimited record with one line
-- per property; every per-dimension property is shown as a bracketed,
-- comma-separated list with one entry per dimension.
instance Pretty num => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces $
      semistack
        [ "offset:" <+> group (pretty offset),
          "strides:" <+> perDim ldStride,
          "shape:" <+> perDim ldShape,
          "permutation:" <+> perDim ldPerm,
          "monotonicity:" <+> perDim ldMon
        ]
    where
      -- Project one field out of every dimension and render the list.
      perDim field =
        group . brackets . align . commasep $ map (pretty . field) dims
-- | Pretty-print an index function as a brace-delimited record
-- showing the base shape, the contiguity flag and the LMAD chain.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmads oshp cg) =
    braces $
      semistack
        [ "base:" <+> brackets (commasep (map pretty oshp)),
          "contiguous:" <+> (if cg then "true" else "false"),
          "LMADs:" <+> brackets (commastack (map pretty (NE.toList lmads)))
        ]
-- | Apply the name substitution to every @num@ stored in the LMAD.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad
-- | Apply the name substitution to every @num@ stored in the index
-- function.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun
-- | Renaming an LMAD is delegated to 'substituteRename'.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Renaming an index function is delegated to 'substituteRename'.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun
-- | The free names of an LMAD are the combined free names of all the
-- @num@ values it contains.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad
-- | The free names of an index function are the combined free names
-- of all the @num@ values it contains.
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun
-- | 'fmap' is derived from the 'Traversable' instance by traversing
-- in the 'Identity' applicative.
instance Functor LMAD where
  fmap f lmad = runIdentity (traverse (pure . f) lmad)
-- | 'fmap' is derived from the 'Traversable' instance by traversing
-- in the 'Identity' applicative.
instance Functor IxFun where
  fmap f ixfun = runIdentity (traverse (pure . f) ixfun)
-- | 'foldMap' is derived from the 'Traversable' instance: every
-- element is mapped and emitted with 'tell', and the accumulated
-- output of the writer is the result.
instance Foldable LMAD where
  foldMap f lmad = execWriter (traverse (tell . f) lmad)
-- | 'foldMap' is derived from the 'Traversable' instance: every
-- element is mapped and emitted with 'tell', and the accumulated
-- output of the writer is the result.
instance Foldable IxFun where
  foldMap f ixfun = execWriter (traverse (tell . f) ixfun)
-- | Visit the offset first, then for every dimension (in order) its
-- stride followed by its shape.  Permutation and monotonicity carry
-- no @num@ and are copied unchanged.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse onDim dims
    where
      onDim (LMADDim s n p m) =
        (\s' n' -> LMADDim s' n' p m) <$> f s <*> f n
-- | Visit every LMAD of the chain (in order), then the elements of
-- the base shape.  The contiguity flag is copied unchanged.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    (\lmads' oshp' -> IxFun lmads' oshp' cg)
      <$> traverse (traverse f) lmads
      <*> traverse f oshp
-- | Prepend an ordinary (possibly empty) list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
[] ++@ (n :| ns) = n :| ns
(e : es) ++@ (n :| ns) = e :| es ++ n : ns
-- | Concatenate two non-empty lists; the elements of the left operand
-- come first.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| (xs ++ y : ys)
-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon = case mon of
  Inc -> Dec
  Dec -> Inc
  Unknown -> Unknown
-- | The 'ldPerm' of every dimension of the LMAD, in dimension order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]
-- | Overwrite the 'ldPerm' of each dimension with the corresponding
-- entry of the given permutation.  Dimensions and permutation entries
-- are paired with 'zipWith', so the result has as many dimensions as
-- the shorter of the two lists.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith withPerm (lmadDims lmad) perm}
  where
    withPerm dim p = dim {ldPerm = p}
-- | Overwrite the 'ldShape' of each dimension with the corresponding
-- entry of the given shape.  Dimensions and shape entries are paired
-- with 'zipWith', so the result has as many dimensions as the shorter
-- of the two lists.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith withShape (lmadDims lmad) shp}
  where
    withShape dim n = dim {ldShape = n}
-- | Substitute according to the table in every expression of the
-- LMAD: the offset and the stride and shape of each dimension.
-- Permutation and monotonicity are untouched.  This is exactly what
-- mapping over the 'LMAD' functor does, so it is expressed as 'fmap'.
substituteInLMAD ::
  Ord a =>
  M.Map a (PrimExp a) ->
  LMAD (PrimExp a) ->
  LMAD (PrimExp a)
substituteInLMAD tab lmad = substituteInPrimExp tab <$> lmad
-- | Substitute according to the table in every expression of the
-- index function: all LMADs and the base shape.  The contiguity flag
-- is kept.  The typed expressions are temporarily converted to their
-- untyped form, because that is what 'substituteInPrimExp' and
-- 'substituteInLMAD' operate on.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab ixfun =
  ixfun
    { ixfunLMADs = NE.map onLMAD (ixfunLMADs ixfun),
      base = map onExp (base ixfun)
    }
  where
    -- The substitution table with the type wrappers removed.
    tab' = fmap untyped tab
    onExp e = TPrimExp (substituteInPrimExp tab' (untyped e))
    onLMAD lmad = TPrimExp <$> substituteInLMAD tab' (untyped <$> lmad)
-- | Check whether the index function is the plain row-major layout of
-- its base shape.  This holds only when all of the following are true:
--
-- * there is exactly one LMAD and the contiguity flag is set;
-- * the permutation is sorted (see 'hasContiguousPerm');
-- * the LMAD has as many dimensions as the base shape;
-- * the offset is zero;
-- * dimension @i@ has permutation entry @i@, the size of base
--   dimension @i@, and a stride equal to the product of all inner
--   base dimensions.
--
-- Any other index function yields 'False'.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  hasContiguousPerm ixfun
    && length oshp == length dims
    && offset == 0
    && all dimIsDirect (zip4 dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides of the base shape: the innermost stride is 1
    -- and each outer stride is the product of the sizes inside it.
    -- 'drop 1' (rather than the partial 'tail') keeps this total for
    -- an empty base shape.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    dimIsDirect (LMADDim s n p _, i, d, se) = s == se && n == d && p == i
isDirect _ = False
-- | 'True' exactly when the index function consists of a single LMAD
-- whose permutation is already in ascending order.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] ->
      let perm = lmadPermutation lmad
       in perm == sort perm
    _ -> False
-- | The shape of the index function: the dimension sizes of the first
-- LMAD of the chain, rearranged with 'permuteFwd' according to that
-- LMAD's permutation.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun = permuteFwd (lmadPermutation lmad) (lmadShapeBase lmad)
  where
    lmad = NE.head (ixfunLMADs ixfun)
-- | The dimension sizes of an LMAD, rearranged with 'permuteInv'
-- according to the LMAD's permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm sizes
  where
    perm = lmadPermutation lmad
    sizes = lmadShapeBase lmad
-- | The 'ldShape' of every dimension of the LMAD, in the LMAD's own
-- dimension order (no permutation applied).
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]
-- | Apply the index function to a list of indices.
--
-- The indices are fed to the first LMAD of the chain.  While further
-- LMADs remain, the flat result is unflattened with respect to the
-- (forward-permuted) shape of the next LMAD and fed to that one.  The
-- result of the last LMAD is the final answer.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun = go (ixfunLMADs ixfun)
  where
    go (lmad :| rest) inds =
      let flat = indexLMAD lmad inds
       in case rest of
            [] -> flat
            next : more ->
              let next_shape =
                    permuteFwd (lmadPermutation next) (lmadShapeBase next)
               in go (next :| more) (unflattenIndex next_shape flat)
-- | Apply a single LMAD to a list of indices.  The indices are first
-- rearranged with 'permuteInv' so that they line up with the LMAD's
-- own dimension order; each is then combined with the stride of its
-- dimension through 'flatOneDim', and the sum of these terms is added
-- to the offset.
indexLMAD ::
  (IntegralExp num, Eq num) =>
  LMAD num ->
  Indices num ->
  num
indexLMAD lmad inds =
  lmadOffset lmad + sum (zipWith flatOneDim strides ordered_inds)
  where
    strides = map ldStride (lmadDims lmad)
    ordered_inds = permuteInv (lmadPermutation lmad) inds
-- | An index function with the given offset and shape: a single LMAD
-- built by 'makeRotIota' (with 'Inc' monotonicity), the shape itself
-- as base, and the contiguity flag set.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun lmads ns True
  where
    lmads = makeRotIota Inc o ns :| []
-- | 'iotaOffset' with an offset of zero.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = iotaOffset 0 ns
-- | Build a single-LMAD index function in which every @num@ is an
-- existential ('Ext') placeholder.  With @start@ as the first number,
-- the placeholders are numbered as follows:
--
-- * @start@: the LMAD offset;
-- * @start + 1 + 2*i@ and @start + 2 + 2*i@: stride and shape of
--   dimension @i@ (one dimension per entry of @perm@, which also
--   supplies the permutation entry and monotonicity);
-- * @start + 1 + 2 * length perm@ onwards: the @basis_rank@ elements
--   of the base shape.
--
-- The contiguity flag is taken from @contig@.
mkExistential :: Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a)
mkExistential basis_rank perm contig start =
  IxFun (lmad :| []) basis contig
  where
    lmad =
      LMAD
        (Ext start)
        [ LMADDim (Ext (strideExt i)) (Ext (strideExt i + 1)) p mon
          | (i, (p, mon)) <- zip [0 ..] perm
        ]
    -- Number of the placeholder used for the stride of dimension i;
    -- the shape placeholder is the one right after it.
    strideExt i = start + 1 + i * 2
    basis = [Ext k | k <- take basis_rank [start + 1 + length perm * 2 ..]]
-- | Permute the dimensions of the index function.  Only the first
-- LMAD of the chain is affected: its permutation is replaced by the
-- current permutation indexed by each entry of @perm_new@.  Every
-- entry of @perm_new@ must therefore be a valid position in the
-- current permutation (the lookup uses '!!').
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    current = lmadPermutation lmad
    composed = [current !! i | i <- perm_new]
-- | Slice an index function by modifying only the first LMAD of its
-- chain.  As written, this always produces a 'Just'.
--
-- The slice is first rearranged with 'permuteInv' to match the LMAD's
-- own dimension order.  Then 'sliceOne' is folded over the
-- (slice component, dimension) pairs, starting from an LMAD with the
-- original offset and no dimensions.  Finally the permutation is
-- rebuilt without the dimensions removed by 'DimFix' components, and
-- the contiguity flag is and-ed with 'slicePreservesContiguous'.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) =
  Just $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    perm = lmadPermutation lmad
    is' = permuteInv perm is
    cg' = cg && slicePreservesContiguous lmad (Slice is')
    lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) (zip is' ldims)

    -- Positions (within the rearranged slice) of the fixed components.
    fixed = [i | (i, di) <- zip [0 ..] is', isJust (dimFix di)]
    perm' = concatMap renumber perm

    -- A permutation entry that refers to a fixed position disappears;
    -- any other entry is shifted down by the number of fixed positions
    -- below it.
    renumber :: Int -> [Int]
    renumber p
      | p `elem` fixed = []
      | otherwise = [p - length (filter (< p) fixed)]
-- | Extend the LMAD under construction with the result of applying
-- one slice component to one dimension of the LMAD being sliced.
--
-- A 'DimFix' only moves the offset and adds no dimension.  A
-- 'DimSlice' appends one dimension; the guards below are tried in
-- order, and their order matters:
--
-- 1. the sliced dimension has stride 0: offset unchanged, stride 0,
--    monotonicity 'Unknown';
-- 2. the slice equals @unitSlice 0 n@: the dimension is kept as is;
-- 3. the slice equals @DimSlice (n - 1) n (-1)@: the offset moves to
--    the last element, the stride is negated and the monotonicity is
--    inverted;
-- 4. the slice has stride 0: the offset moves to the start of the
--    slice, stride 0, monotonicity 'Unknown';
-- 5. otherwise: the offset moves to the start of the slice, the
--    strides are multiplied, and the monotonicity follows the sign of
--    the slice stride as reported by 'sgn'.
sliceOne ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  (DimIndex num, LMADDim num) ->
  LMAD num
sliceOne (LMAD off dims) (dmind, dim@(LMADDim s n p m)) =
  case dmind of
    DimFix i ->
      LMAD (off + flatOneDim s i) dims
    DimSlice start len stride
      | s == 0 ->
          push off (LMADDim 0 len p Unknown)
      | dmind == unitSlice 0 n ->
          push off dim
      | dmind == DimSlice (n - 1) n (-1) ->
          push
            (off + flatOneDim s (n - 1))
            (LMADDim (s * (-1)) n p (invertMonotonicity m))
      | stride == 0 ->
          push (off + flatOneDim s start) (LMADDim 0 len p Unknown)
      | otherwise ->
          push (off + s * start) (LMADDim (stride * s) len p (slicedMon stride))
  where
    -- Append a dimension to the LMAD built so far, with a new offset.
    push off' d = LMAD off' (dims ++ [d])
    -- Monotonicity after slicing with the given stride.
    slicedMon stride = case sgn stride of
      Just 1 -> m
      Just (-1) -> invertMonotonicity m
      _ -> Unknown
-- | Decide whether applying a slice to an LMAD keeps the contiguity
-- property.
--
-- Indices are first normalised with 'normIndex', and dimensions whose
-- stride is zero are ignored together with their index.  The remaining
-- (index, dimension) pairs are scanned left to right while remembering
-- whether a 'DimSlice' has been seen yet:
--
-- * a 'DimFix' after a slice makes the result 'False';
-- * the first slice must have stride @1@ or @-1@;
-- * every later slice must additionally cover the whole dimension.
slicePreservesContiguous ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
slicePreservesContiguous (LMAD _ dims) (Slice slc) =
  snd $ foldl step (False, True) $ zip kept_slc kept_dims
  where
    -- Zero-stride dimensions are dropped along with their index.
    (kept_dims, kept_slc) =
      unzip
        [ (d, ix)
          | (d, ix) <- zip dims (map normIndex slc),
            ldStride d /= 0
        ]

    unitStride ds = ds == 1 || ds == (-1)

    -- State is (slice already seen, still contiguous).
    step (found, ok) (ix, LMADDim _ n _ _) =
      case ix of
        DimFix {}
          | found -> (found, False)
          | otherwise -> (found, ok)
        DimSlice _ ne ds
          | found -> (found, ok && (n == ne) && unitStride ds)
          | otherwise -> (True, ok && unitStride ds)
-- | Normalise a dimension index: a 'DimSlice' with exactly one element
-- or with stride zero is rewritten to a 'DimFix' at its start position.
-- Every other index is returned unchanged.
normIndex ::
  (Eq num, IntegralExp num) =>
  DimIndex num ->
  DimIndex num
normIndex ix =
  case ix of
    DimSlice b n s
      | n == 1 || s == 0 -> DimFix b
    _ -> ix
-- | Slice an index function.
--
-- Three outcomes, tried in order:
--
-- 1. If the slice is a unit slice over every dimension of the current
--    shape, the index function is returned unchanged.
-- 2. If 'sliceOneLMAD' succeeds, its result is returned.
-- 3. Otherwise the slice is applied to an 'iota' over the shape of the
--    first LMAD, and the resulting LMAD is pushed in front of the
--    existing ones.  The contiguity flag becomes the conjunction of the
--    old flag and the one from the sliced 'iota'.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (top@(LMAD _ _) :| rest) oshp cg) dim_slices
  | is_identity = ixfun
  | otherwise =
      case sliceOneLMAD ixfun dim_slices of
        Just sliced -> sliced
        Nothing -> stacked
  where
    is_identity =
      unSlice dim_slices == map (unitSlice 0) (shape ixfun)

    stacked =
      case sliceOneLMAD (iota (lmadShape top)) dim_slices of
        Just (IxFun (fresh :| []) _ cg') ->
          IxFun (fresh :| top : rest) oshp (cg && cg')
        _ -> error "slice: reached impossible case"
-- | Apply a flat slice to an index function.
--
-- When the first LMAD has at least one dimension and the index function
-- has a contiguous permutation, the outermost dimension is replaced in
-- place by the dimensions of the flat slice: each flat stride is scaled
-- by the stride of that outermost dimension, and the permutation is
-- reset to @[0 ..]@.
--
-- Otherwise a new LMAD is pushed in front.  It is built from the shape
-- of the current first LMAD with its first dimension removed; the flat
-- slice's offset and strides are scaled by the product of those
-- remaining extents.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (LMAD offset (dim : dims) :| lmads) oshp cg) (FlatSlice new_offset is)
  | hasContiguousPerm ixfun =
      IxFun (setLMADPermutation [0 ..] unpermuted :| lmads) oshp cg
  where
    outer_stride = ldStride dim

    unpermuted =
      LMAD
        (offset + new_offset * outer_stride)
        ([scaled outer_stride i | i <- is] <> dims)

    -- Only a resulting stride of exactly 1 is recorded as increasing.
    scaled s0 (FlatDimIndex n s) =
      let stride = s0 * s
       in LMADDim stride n 0 (if stride == 1 then Inc else Unknown)
flatSlice (IxFun (lmad :| lmads) oshp cg) s@(FlatSlice new_offset _) =
  IxFun (flat_lmad :| lmad : lmads) oshp cg
  where
    -- Extents of every dimension except the first one.
    inner_shape = tail $ lmadShape lmad
    inner_size = product inner_shape
    inner_strides = tail $ scanr (*) 1 inner_shape

    slice_shape = flatSliceDims s
    slice_strides = [st * inner_size | st <- flatSliceStrides s]

    slice_dims =
      zipWith3 (\st n p -> LMADDim st n p Inc) slice_strides slice_shape [0 ..]
    inner_dims =
      zipWith3
        (\st n p -> LMADDim st n p Inc)
        inner_strides
        inner_shape
        [length slice_shape ..]

    flat_lmad = LMAD (new_offset * inner_size) (slice_dims <> inner_dims)
-- | Try to reshape an index function without adding a new LMAD.
--
-- Succeeds only when all of the following hold for the first LMAD:
--
-- * no dimension has stride zero;
-- * after applying the LMAD's permutation with 'permuteFwd', the
--   'ldPerm' fields read @0, 1, 2, ...@;
-- * 'hasContiguousPerm' holds for the index function;
-- * the contiguity flag is set;
-- * 'ixfunMonotonicity' is 'Inc' or 'Dec'.
--
-- On success the first LMAD is replaced by @'makeRotIota' mon off
-- newshape@ (same offset, same monotonicity, new shape) with the
-- identity permutation; the remaining LMADs, the base shape and the
-- contiguity flag are kept.  Otherwise the result is 'Nothing'.
--
-- Earlier revisions of this function carried bookkeeping for "repeat"
-- dimensions and a fallback permutation lookup.  Neither could ever be
-- exercised: the repeat list was never populated, and the fallback
-- branch was guarded by a condition that is always true for the indices
-- it was applied to.  That dead code (including a partial @!!@) has
-- been removed; the result is unchanged.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let dims_perm = permuteFwd (lmadPermutation lmad) dims
      mon = ixfunMonotonicity ixfun

  guard $
    all (\ld -> ldStride ld /= 0) dims_perm
      && and (zipWith (==) (map ldPerm dims_perm) [0 ..])
      && hasContiguousPerm ixfun
      && cg
      && (mon == Inc || mon == Dec)

  -- The guard above guarantees that 'mon' is 'Inc' or 'Dec', which is
  -- what 'makeRotIota' requires.
  let perm' = [0 .. length newshape - 1]
      lmad' = makeRotIota mon off newshape

  pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oldbase cg
-- | Reshape an index function to the given shape.
--
-- First tries 'reshapeOneLMAD'.  If that fails, an 'iota' LMAD over the
-- new shape is pushed in front of the existing LMADs, leaving the base
-- shape and the contiguity flag untouched.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
reshape ixfun@(IxFun (top :| rest) oshp cg) new_shape =
  case reshapeOneLMAD ixfun new_shape of
    Just reshaped -> reshaped
    Nothing ->
      case iota new_shape of
        IxFun (fresh :| []) _ _ -> IxFun (fresh :| top : rest) oshp cg
        _ -> error "reshape: reached impossible case"
-- | Overwrite the extents of the first LMAD's dimensions with the given
-- shape, pairing dimensions and new extents positionally.  Strides,
-- permutation and monotonicity of each dimension are kept, as are the
-- other LMADs, the base shape and the contiguity flag.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun (top :| rest) oshp cg) new_shape =
  IxFun (resize top :| rest) oshp cg
  where
    resize (LMAD offset dims) =
      LMAD offset [ld {ldShape = extent} | (ld, extent) <- zip dims new_shape]
-- | The number of dimensions of the first LMAD of the index function.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank ixfun =
  case ixfun of
    IxFun (LMAD _ first_dims :| _) _ _ -> length first_dims
-- | Try to rebase an index function onto a new base by rewriting a
-- single LMAD instead of concatenating the two LMAD lists.
--
-- The LMAD list of @ixfun@ is reversed; its last LMAD (called @lmad@
-- below) is the one that gets merged with the first LMAD of the new
-- base.  The attempt is abandoned with 'Nothing' unless:
--
-- * the base shape of @ixfun@ equals the shape of @new_base@;
-- * @ixfun@ is flagged contiguous;
-- * no dimension of @lmad@ has 'Unknown' monotonicity;
-- * at least one of the two index functions has a contiguous
--   permutation;
-- * the two permutations have equal length, or @ixfun@ has a
--   contiguous permutation;
-- * every extent of @lmad@ matches the corresponding entry of the
--   shape stored in @ixfun@, except that the last dimension may differ
--   if its stride is 1.
--
-- On success the merged LMAD takes strides and extents from the new
-- base, flipping the stride (and compensating in the offset) for every
-- dimension that @lmad@ does not mark as 'Inc'.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
    guard $
      base ixfun == shape new_base
        && cg
        && all (\ld -> ldMon ld /= Unknown) dims
        && (ixfun_contig || hasContiguousPerm new_base)
        && (length perm == length perm_base || ixfun_contig)
        && and (zipWith3 extentOk shp dims innermost)
    pure $ IxFun (outer_lmads ++@ (rebased :| lmads_base)) shp (cg && cg_base)
    where
      (lmad :| outer_lmads) = NE.reverse lmads
      dims = lmadDims lmad
      perm = lmadPermutation lmad
      perm_base = lmadPermutation lmad_base
      ixfun_contig = hasContiguousPerm ixfun

      -- Flags marking only the last dimension as "inner".
      innermost = replicate (length dims - 1) False ++ [True]
      extentOk sn ld inner =
        sn == ldShape ld || (inner && ldStride ld == 1)

      -- Permutation to install on the base LMAD.
      base_perm
        | ixfun_contig = perm_base
        | otherwise = map (perm !!) perm_base
      permuted_base = setLMADPermutation base_perm lmad_base
      base_dims = lmadDims permuted_base
      n_fewer_dims = length base_dims - length dims

      -- For each retained base dimension: the rewritten dimension and
      -- its contribution to the offset.
      (oriented_dims, offset_parts) =
        unzip $ zipWith orient (drop n_fewer_dims base_dims) dims
      orient (LMADDim s1 n1 p1 _) (LMADDim _ _ _ m2)
        | m2 == Inc = (LMADDim s1 n1 p1' Inc, 0)
        | otherwise = (LMADDim (s1 * (-1)) n1 p1' Inc, s1 * (n1 - 1))
        where
          p1' = p1 - n_fewer_dims

      base_offset = lmadOffset permuted_base + sum offset_parts

      rebased
        | lmadOffset lmad == 0 = LMAD base_offset oriented_dims
        | otherwise =
            setLMADShape
              (lmadShape lmad)
              ( LMAD
                  (base_offset + ldStride (last base_dims) * lmadOffset lmad)
                  oriented_dims
              )
-- | Rebase an index function onto a new base.
--
-- If 'rebaseNice' succeeds its result is used.  Otherwise the LMADs of
-- @ixfun@ are concatenated (with '@++@') in front of the LMADs of the
-- new base.  When the base shape of @ixfun@ differs from the shape of
-- the new base, the new base is first reshaped to the shape stored in
-- @ixfun@.  The contiguity flag is the conjunction of both flags.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun _ _ cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just nice -> nice
    Nothing -> IxFun (lmads @++@ fitted_lmads) fitted_shape (cg && cg_base)
  where
    -- The new base, reshaped only if the shapes do not already agree.
    fitted_base
      | base ixfun == shape new_base = new_base
      | otherwise = reshape new_base shp
    IxFun fitted_lmads fitted_shape _ = fitted_base
-- | If the index function consists of a single LMAD, has a contiguous
-- permutation, is flagged contiguous and is monotonically increasing,
-- return the LMAD's offset multiplied by the given element size.
-- In every other case return 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun@(IxFun (lmad :| []) _ cg) elem_size = do
  guard $ hasContiguousPerm ixfun && cg && ixfunMonotonicity ixfun == Inc
  pure $ lmadOffset lmad * elem_size
linearWithOffset _ _ = Nothing
-- | Like 'linearWithOffset', but ignoring the permutation of a
-- single-LMAD index function.
--
-- The LMAD's permutation is replaced by the identity before calling
-- 'linearWithOffset'.  If that yields an offset, the result pairs the
-- offset with the original permutation zipped against the LMAD's base
-- shape permuted by 'permuteFwd'.  Index functions with more than one
-- LMAD give 'Nothing'.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  (\offset -> (offset, permuted_dims)) <$> linearWithOffset unpermuted elem_size
  where
    perm = lmadPermutation lmad
    identity_perm = [0 .. length perm - 1]
    unpermuted = IxFun (setLMADPermutation identity_perm lmad :| []) oshp cg
    permuted_dims = zip perm (permuteFwd perm (lmadShapeBase lmad))
rearrangeWithOffset _ _ = Nothing
-- | Whether 'linearWithOffset' with element size 1 yields an offset of
-- exactly zero.
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0
-- | Build a list whose @k@-th element is the element of @elems@ at
-- position @ps !! k@.  Indexing is partial: an out-of-range entry in
-- the permutation is a runtime error once that element is demanded.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]
-- | Pair each element with the corresponding permutation entry, sort
-- the pairs by that entry (stably), and return the elements in the
-- resulting order.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  [e | (_, e) <- sortBy (\(i, _) (j, _) -> compare i j) (zip ps elems)]
-- | Offset contribution of index @i@ in a dimension of stride @s@:
-- zero when the stride is zero, otherwise @i * s@.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  num ->
  num ->
  num
flatOneDim s i = if s == 0 then 0 else i * s
-- | Build an LMAD with the given offset and extents.  Its strides are
-- the running products of the extents taken from the last dimension
-- towards the first, starting at 1.
--
-- For 'Inc' the strides are used as they are; for 'Dec' each stride is
-- multiplied by @-1@.  Every dimension gets the identity permutation
-- index and the given monotonicity.  Calling this with 'Unknown' is an
-- error.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  num ->
  [num] ->
  LMAD num
makeRotIota mon off ns =
  case mon of
    Unknown -> error "makeRotIota: requires Inc or Dec"
    _ ->
      LMAD
        off
        [ LMADDim (orient s) n p mon
          | (s, n, p) <- zip3 row_major ns [0 ..]
        ]
  where
    -- One stride per extent: products of all extents to its right.
    row_major = reverse $ take (length ns) $ scanl (*) 1 $ reverse ns
    orient s = if mon == Inc then s else s * (-1)
-- | The monotonicity shared by all LMADs of an index function, or
-- 'Unknown' if they disagree.
--
-- A single LMAD counts as 'Inc' (checked first) or 'Dec' when every
-- dimension either has stride zero or is marked with that
-- monotonicity; otherwise it is 'Unknown'.
ixfunMonotonicity ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Monotonicity
ixfunMonotonicity (IxFun (lmad :| lmads) _ _)
  | all (\l -> lmadMon l == first_mon) lmads = first_mon
  | otherwise = Unknown
  where
    first_mon = lmadMon lmad

    lmadMon ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Monotonicity
    lmadMon (LMAD _ dims)
      | agrees Inc = Inc
      | agrees Dec = Dec
      | otherwise = Unknown
      where
        -- Zero-stride dimensions are compatible with any direction.
        agrees dir =
          all (\ld -> ldStride ld == 0 || dir == ldMon ld) dims
-- | Replace every expression stored in the index function by a fresh
-- existential leaf of type @int64@.  Leaves are numbered from 0 in
-- traversal order; the original expressions are discarded.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const freshExt) ixfun) 0
  where
    -- Emit the leaf for the current counter and advance the counter.
    freshExt = state $ \i -> (TPrimExp (LeafExp (Ext i) int64), i + 1)
-- | A relaxed, purely structural comparison of two index functions.
--
-- The result is 'True' when all of the following hold:
--
-- * the base shapes have the same number of dimensions;
--
-- * both index functions consist of the same number of LMADs;
--
-- * each positional pair of LMADs has the same number of dimensions
--   and the same permutation;
--
-- * if @ixf1@ is contiguous, then so is @ixf2@.
--
-- Offsets, strides and shapes are deliberately not inspected, which
-- is why no 'Eq' constraint on @num@ is needed.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  and
    [ sameBy (length . base),
      sameBy (NE.length . ixfunLMADs),
      and (NE.zipWith similarLMADs (ixfunLMADs ixf1) (ixfunLMADs ixf2)),
      -- Implication: contiguity of the first argument must carry over
      -- to the second, but not the other way around.
      not (contiguous ixf1) || contiguous ixf2
    ]
  where
    -- Do both index functions agree on the given measurement?
    sameBy measure = measure ixf1 == measure ixf2

    -- Two LMADs are similar if they have equally many dimensions and
    -- identical permutations.
    similarLMADs :: LMAD num -> LMAD num -> Bool
    similarLMADs lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in length dims1 == length dims2
            && map ldPerm dims1 == map ldPerm dims2
-- | Decide whether two index functions are equivalent.
--
-- They are equivalent when they contain the same number of LMADs
-- and every positional pair of LMADs agrees on the number of
-- dimensions, the permutation, the offset and the strides.
-- Dimension shapes, monotonicity and the base shape are not
-- compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLMAD lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    -- Pairwise LMAD check; the conditions are evaluated in order and
    -- stop at the first one that fails.
    sameLMAD lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in and
            [ length dims1 == length dims2,
              map ldPerm dims1 == map ldPerm dims2,
              lmadOffset lmad1 == lmadOffset lmad2,
              map ldStride dims1 == map ldStride dims2
            ]
-- | Build a boolean expression stating that two LMAD dimensions are
-- equal.
--
-- Stride and shape are compared symbolically with '.==.', so the
-- comparison becomes part of the resulting expression.  The
-- permutation index and the monotonicity are plain Haskell values;
-- they are compared immediately and the outcome is embedded as a
-- constant via 'fromBool'.
dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  sameStride .&&. sameShape .&&. samePerm .&&. sameMon
  where
    sameStride = ldStride dim1 .==. ldStride dim2
    sameShape = ldShape dim1 .==. ldShape dim2
    samePerm = fromBool $ ldPerm dim1 == ldPerm dim2
    sameMon = fromBool $ ldMon dim1 == ldMon dim2
-- | Build a boolean expression stating that two LMADs are equal.
--
-- The LMADs are equal when their offsets are equal and their
-- dimensions are pairwise equal according to 'dynamicEqualsLMADDim'.
--
-- LMADs with a different number of dimensions are never equal.
-- 'zip' silently drops the surplus dimensions of the longer list, so
-- without the explicit rank guard two LMADs of different rank would
-- be reported as equal whenever their offsets and their common
-- prefix of dimensions match.  When the ranks agree, the produced
-- expression is exactly the offset comparison conjoined with the
-- per-dimension comparisons.
dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  -- The rank is a static property, so a mismatch is decided here
  -- rather than in the generated expression.
  | length dims1 /= length dims2 = fromBool False
  | otherwise =
      lmadOffset lmad1 .==. lmadOffset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip dims1 dims2)
  where
    dims1 = lmadDims lmad1
    dims2 = lmadDims lmad2