{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.IR.Mem.IxFun
( IxFun (..),
Shape,
LMAD (..),
LMADDim (..),
Monotonicity (..),
index,
mkExistential,
iota,
iotaOffset,
permute,
reshape,
coerce,
slice,
flatSlice,
rebase,
shape,
lmadShape,
rank,
linearWithOffset,
rearrangeWithOffset,
isDirect,
isLinear,
substituteInIxFun,
substituteInLMAD,
existentialize,
closeEnough,
equivalent,
hasOneLmad,
permuteInv,
conservativeFlatten,
disjoint,
disjoint2,
disjoint3,
dynamicEqualsLMAD,
)
where
import Control.Category
import Control.Monad
import Control.Monad.State
import Data.Function (on, (&))
import Data.List (elemIndex, partition, sort, sortBy, zip4, zipWith4)
import Data.List.NonEmpty (NonEmpty (..))
import Data.List.NonEmpty qualified as NE
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, isJust, isNothing)
import Data.Traversable
import Futhark.Analysis.AlgSimplify qualified as AlgSimplify
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Prop
import Futhark.IR.Syntax
( DimIndex (..),
FlatDimIndex (..),
FlatSlice (..),
Slice (..),
Type,
dimFix,
flatSliceDims,
flatSliceStrides,
unitSlice,
)
import Futhark.IR.Syntax.Core (Ext (..), VName (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))
-- | Per-dimension extents.
type Shape d = [d]

-- | Per-dimension indices.
type Indices d = [d]

-- | A reordering of dimensions, given as a list of dimension positions.
type Permutation = [Int]
-- | The traversal direction along one dimension (presumably the sign
-- of that dimension's stride).
data Monotonicity
  = -- | Increasing.
    Inc
  | -- | Decreasing.
    Dec
  | -- | Direction not known.
    Unknown
  deriving (Eq, Show)
-- | One dimension of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Stride of this dimension.
    ldStride :: num,
    -- | Extent (number of elements) of this dimension.
    ldShape :: num,
    -- | This dimension's entry in the LMAD's permutation.
    ldPerm :: Int,
    -- | Traversal direction of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Eq, Show)
-- | Total order @Unknown < Dec < Inc@.
instance Ord Monotonicity where
  compare x y = compare (rank x) (rank y)
    where
      rank :: Monotonicity -> Int
      rank Unknown = 0
      rank Dec = 1
      rank Inc = 2
-- | Lexicographic order on (shape, stride, permutation entry,
-- monotonicity), in that priority.
instance Ord num => Ord (LMADDim num) where
  a <= b = key a <= key b
    where
      key (LMADDim stride extent perm mon) = (extent, stride, perm, mon)
-- | An offset together with one 'LMADDim' per dimension.  Indexing
-- adds to the offset the contribution of each index combined with its
-- dimension's stride (via 'flatOneDim').
data LMAD num = LMAD
  { -- | Offset added to every computed position.
    lmadOffset :: num,
    -- | Per-dimension descriptors, in the LMAD's own dimension order.
    lmadDims :: [LMADDim num]
  }
  deriving (Eq, Ord, Show)
-- | An index function: a non-empty chain of 'LMAD's plus a base shape
-- and a contiguity flag.  Indexing applies the first LMAD to the given
-- indices and, for every further LMAD, unflattens the flat result into
-- indices for that LMAD.
data IxFun num = IxFun
  { ixfunLMADs :: NonEmpty (LMAD num),
    -- | Base shape; 'isDirect' derives row-major strides from it.
    base :: Shape num,
    -- | Contiguity flag; slicing keeps it only when the slice
    -- preserves contiguity.
    contiguous :: Bool
  }
  deriving (Eq, Show)
-- | Renders a 'Monotonicity' as its 'Show' text.
instance Pretty Monotonicity where
  pretty m = pretty (show m)
-- | Renders an LMAD as a braced record of its offset followed by
-- bracketed per-dimension lists of strides, shapes, permutation entries
-- and monotonicities.
instance Pretty num => Pretty (LMAD num) where
  pretty (LMAD offset dims) =
    braces $
      semistack
        [ "offset:" <+> group (pretty offset),
          "strides:" <+> column ldStride,
          "shape:" <+> column ldShape,
          "permutation:" <+> column ldPerm,
          "monotonicity:" <+> column ldMon
        ]
    where
      -- One field of every dimension, as a bracketed comma list.
      column field = group . brackets . align . commasep $ map (pretty . field) dims
-- | Renders an index function as a braced record of its base shape,
-- contiguity flag and LMAD chain.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmads oshp cg) =
    braces $
      semistack
        [ "base:" <+> brackets (commasep (map pretty oshp)),
          "contiguous:" <+> contigDoc,
          "LMADs:" <+> brackets (commastack (map pretty (NE.toList lmads)))
        ]
    where
      contigDoc
        | cg = "true"
        | otherwise = "false"
-- | Renames variables in every scalar of the LMAD.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = substituteNames substs <$> lmad
-- | Renames variables in every scalar of the index function.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = substituteNames substs <$> ixfun
-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Renaming is delegated to 'substituteRename'.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun
-- | Free variables of all scalars in the LMAD.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad
-- | Free variables of all scalars in the index function.
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun
-- | Free variables of a dimension's stride and extent; the permutation
-- entry and monotonicity carry none.
instance FreeIn num => FreeIn (LMADDim num) where
  freeIn' (LMADDim stride extent _ _) = freeIn' stride <> freeIn' extent
-- | Mapping is derived from the 'Traversable' instance.
instance Functor LMAD where
  fmap f lmad = fmapDefault f lmad
-- | Mapping is derived from the 'Traversable' instance.
instance Functor IxFun where
  fmap f ixfun = fmapDefault f ixfun
-- | Folding is derived from the 'Traversable' instance.
instance Foldable LMAD where
  foldMap f lmad = foldMapDefault f lmad
-- | Folding is derived from the 'Traversable' instance.
instance Foldable IxFun where
  foldMap f ixfun = foldMapDefault f ixfun
-- | Visits the offset, then each dimension's stride and extent in order.
-- Permutation entries and monotonicities are carried through untouched.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse onDim dims
    where
      onDim (LMADDim stride extent perm mon) =
        (\stride' extent' -> LMADDim stride' extent' perm mon)
          <$> f stride
          <*> f extent
-- | Visits every scalar of every LMAD, then the base shape.  The
-- contiguity flag is carried through untouched.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    (\lmads' oshp' -> IxFun lmads' oshp' cg)
      <$> traverse (traverse f) lmads
      <*> traverse f oshp
-- | Prepends a list to a non-empty list.  The non-empty argument is
-- forced to its constructor before the list is inspected.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
prefix ++@ rest@(_ :| _) =
  case NE.nonEmpty prefix of
    Nothing -> rest
    Just front -> front <> rest
-- | Concatenates two non-empty lists.  The result is the same as '<>',
-- but both arguments are forced to their constructors.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
front@(_ :| _) @++@ (y :| ys) = front <> (y :| ys)
-- | Swaps 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon = case mon of
  Inc -> Dec
  Dec -> Inc
  Unknown -> Unknown
-- | The 'ldPerm' entry of every dimension, in dimension order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]
-- | Replaces each dimension's 'ldPerm' with the matching permutation
-- entry.  As with 'zipWith', if the lengths differ the result keeps
-- only as many dimensions as the shorter of the two.
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith withPerm (lmadDims lmad) perm}
  where
    withPerm dim p = dim {ldPerm = p}
-- | Replaces each dimension's 'ldShape' with the matching extent.  As
-- with 'zipWith', if the lengths differ the result keeps only as many
-- dimensions as the shorter of the two.
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith withExtent (lmadDims lmad) shp}
  where
    withExtent dim extent = dim {ldShape = extent}
-- | Applies a variable-to-expression substitution to the offset and to
-- every stride and extent of the LMAD.  Permutation entries and
-- monotonicities are unchanged.
substituteInLMAD ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  LMAD (TPrimExp t a) ->
  LMAD (TPrimExp t a)
substituteInLMAD tab = fmap (TPrimExp . substituteInPrimExp untypedTab . untyped)
  where
    -- 'substituteInPrimExp' works on untyped expressions.
    untypedTab = fmap untyped tab
-- | Applies a variable-to-expression substitution to every scalar of
-- every LMAD and to the base shape.  The contiguity flag is unchanged.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab = fmap (TPrimExp . substituteInPrimExp untypedTab . untyped)
  where
    -- 'substituteInPrimExp' works on untyped expressions.
    untypedTab = fmap untyped tab
-- | Is the index function a plain row-major view of its base shape?
--
-- All of the following must hold:
--
-- * there is exactly one LMAD and the contiguity flag is set;
-- * 'hasContiguousPerm' holds, and each dimension's 'ldPerm' equals its
--   position;
-- * the offset is zero;
-- * the LMAD has as many dimensions as the base shape;
-- * every dimension has the base extent and the row-major stride
--   implied by the base shape.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  let -- Row-major strides of the base shape: dimension i's stride is the
      -- product of the extents after it.  'drop 1' instead of the
      -- partial 'tail', so an empty base shape cannot error.
      strides_expected = reverse $ scanl (*) 1 (reverse (drop 1 oshp))
   in hasContiguousPerm ixfun
        && length oshp == length dims
        && offset == 0
        && and
          ( zipWith4
              (\(LMADDim s n p _) i d se -> s == se && n == d && p == i)
              dims
              [0 ..]
              oshp
              strides_expected
          )
isDirect _ = False
-- | Does the index function consist of exactly one LMAD?
hasOneLmad :: IxFun num -> Bool
hasOneLmad ixfun = null (NE.tail (ixfunLMADs ixfun))
-- | True for a single-LMAD index function whose permutation is already
-- in ascending order.  Always False when there are several LMADs.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] ->
      let perm = lmadPermutation lmad
       in perm == sort perm
    _ -> False
-- | The shape of the index space, taken from the first LMAD: its
-- per-dimension extents reordered by its permutation with 'permuteFwd'.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun =
  let lmad = NE.head (ixfunLMADs ixfun)
   in permuteFwd (lmadPermutation lmad) (lmadShapeBase lmad)
-- | The shape of an LMAD: its per-dimension extents reordered by its
-- permutation with 'permuteInv'.
--
-- NOTE(review): 'shape' and 'index' reorder the same extents with
-- 'permuteFwd'.  The two presumably agree only when the permutation is
-- its own inverse.  Confirm which direction callers of 'lmadShape'
-- expect.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm extents
  where
    perm = lmadPermutation lmad
    extents = lmadShapeBase lmad
-- | The 'ldShape' of every dimension, in the LMAD's own dimension order
-- (no permutation applied).
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]
-- | Maps a multi-dimensional index to a flat position.  The indices are
-- applied to the first LMAD.  For every further LMAD, the flat result
-- so far is unflattened by that LMAD's permuted extents into indices
-- for it.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun inds =
  case ixfunLMADs ixfun of
    first :| rest -> chain (applyLMAD first inds) rest
  where
    -- Thread the flat position through the remaining LMADs.
    chain flat [] = flat
    chain flat (lmad : lmads) =
      let extents = permuteFwd (lmadPermutation lmad) (lmadShapeBase lmad)
       in chain (applyLMAD lmad (unflattenIndex extents flat)) lmads

    -- Offset plus each index combined with its stride; the indices are
    -- first reordered into the LMAD's dimension order.
    applyLMAD lmad idxs =
      lmadOffset lmad
        + sum
          ( zipWith
              flatOneDim
              (map ldStride (lmadDims lmad))
              (permuteInv (lmadPermutation lmad) idxs)
          )
-- | A contiguous single-LMAD index function over the given shape, which
-- also becomes the base shape, starting at the given offset.  The LMAD
-- is built by 'makeRotIota' with 'Inc'.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns =
  IxFun
    { ixfunLMADs = NE.singleton (makeRotIota Inc o ns),
      base = ns,
      contiguous = True
    }
-- | 'iotaOffset' with a zero offset.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = iotaOffset 0 ns
-- | Builds a single-LMAD index function whose scalars are all 'Ext'
-- indices numbered from @start@:
--
-- * the offset is @Ext start@;
-- * dimension @i@ has stride @Ext (start + 1 + 2*i)@ and extent
--   @Ext (start + 2 + 2*i)@, with permutation entry and monotonicity
--   taken from the @i@-th pair of @perm@;
-- * the base shape uses the next @basis_rank@ indices after the
--   dimensions.
mkExistential :: Int -> [(Int, Monotonicity)] -> Bool -> Int -> IxFun (Ext a)
mkExistential basis_rank perm contig start =
  IxFun (NE.singleton lmad) basis contig
  where
    -- First existential index not used by the offset or any dimension.
    afterDims = start + 1 + 2 * length perm
    basis = map Ext (take basis_rank [afterDims ..])
    lmad = LMAD (Ext start) [mkDim i p mon | (i, (p, mon)) <- zip [0 ..] perm]
    mkDim i p mon =
      LMADDim (Ext (start + 1 + 2 * i)) (Ext (start + 2 + 2 * i)) p mon
-- | Composes a new permutation with the first LMAD's current one.
-- Entry @k@ of the result is the current entry at position
-- @perm_new !! k@.  Other LMADs, the base shape and the contiguity flag
-- are unchanged.  Out-of-range entries in @perm_new@ make '!!' fail.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute ixfun perm_new =
  case ixfunLMADs ixfun of
    lmad :| lmads ->
      let current = lmadPermutation lmad
          composed = [current !! i | i <- perm_new]
       in ixfun {ixfunLMADs = setLMADPermutation composed lmad :| lmads}
-- | Slices an index function through its first LMAD; any further LMADs,
-- the base shape and everything else are kept.
--
-- The slice is first reordered into the LMAD's dimension order, and
-- each dimension is sliced with 'sliceOne'.  Dimensions fixed by the
-- slice are removed from the permutation and the surviving entries are
-- renumbered.  Contiguity is kept only when 'slicePreservesContiguous'
-- allows it.  This implementation always returns 'Just'.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) (Slice is) =
  Just $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    perm = lmadPermutation lmad
    -- The slice, reordered into the LMAD's dimension order.
    is' = permuteInv perm is
    cg' = cg && slicePreservesContiguous lmad (Slice is')
    lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) (zip is' ldims)
    -- Positions, in dimension order, of the dimensions fixed by the slice.
    fixed = [i | (i, d) <- zip [0 ..] is', isJust (dimFix d)]
    -- Drop fixed dimensions from the permutation and shift each
    -- surviving entry down by the number of fixed dimensions below it.
    perm' = [p - length (filter (< p) fixed) | p <- perm, p `notElem` fixed]
-- | One step of the fold in 'sliceOneLMAD': apply a single slice index to the
-- matching LMAD dimension, accumulating into the LMAD being built.
--
-- A fixed index only bumps the offset.  A slice appends one new dimension:
--
--   * a slice of a stride-0 dimension becomes a stride-0 dimension of the
--     slice's count with 'Unknown' monotonicity;
--   * a full unit slice keeps the dimension as is;
--   * a full reversing slice moves the offset to the last element and negates
--     the stride, inverting the monotonicity;
--   * a stride-0 slice becomes a stride-0 dimension starting at the slice base;
--   * any other slice scales the stride, and the monotonicity is kept,
--     inverted or forgotten depending on the sign of the slice stride.
sliceOne ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  (DimIndex num, LMADDim num) ->
  LMAD num
sliceOne (LMAD off dims) (idx, dim@(LMADDim s n p m)) =
  case idx of
    DimFix i ->
      LMAD (off + flatOneDim s i) dims
    DimSlice _ ne _
      | s == 0 -> LMAD off (dims ++ [LMADDim 0 ne p Unknown])
    -- Guards that fail here fall through to the DimSlice alternatives below.
    _
      | idx == unitSlice 0 n -> LMAD off (dims ++ [dim])
      | idx == DimSlice (n - 1) n (-1) ->
          LMAD
            (off + flatOneDim s (n - 1))
            (dims ++ [LMADDim (s * (-1)) n p (invertMonotonicity m)])
    DimSlice b ne 0 ->
      LMAD (off + flatOneDim s b) (dims ++ [LMADDim 0 ne p Unknown])
    DimSlice bs ns ss ->
      let m' = case sgn ss of
            Just 1 -> m
            Just (-1) -> invertMonotonicity m
            _ -> Unknown
       in LMAD (off + s * bs) (dims ++ [LMADDim (ss * s) ns p m'])
-- | Does slicing the LMAD with the given slice keep it contiguous?
--
-- LMAD dimensions with stride 0 are dropped together with their slice
-- entries, and every slice entry is first normalised with 'normIndex' (so
-- count-1 and stride-0 slices are treated as fixed indices).  The remaining
-- entries, outermost first, are accepted when:
--
--   1. no fixed index appears after the first sliced dimension;
--   2. the first sliced dimension has stride 1 or -1;
--   3. every later sliced dimension has stride 1 or -1 and its count equals
--      the size of the LMAD dimension.
slicePreservesContiguous ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
slicePreservesContiguous (LMAD _ dims) (Slice slc) =
  snd $ foldl step (False, True) $ zip slc' dims'
  where
    unitStride d = d == 1 || d == -1

    (dims', slc') =
      unzip
        [ (dim, idx)
          | (dim, idx) <- zip dims (map normIndex slc),
            ldStride dim /= 0
        ]

    -- Accumulator: (have we seen a sliced dimension yet?, still contiguous?)
    step (seen, ok) (idx, LMADDim _ n _ _) =
      case idx of
        DimFix {}
          | seen -> (seen, False)
          | otherwise -> (seen, ok)
        DimSlice _ ne ds
          | seen -> (seen, ok && n == ne && unitStride ds)
          | otherwise -> (True, ok && unitStride ds)
-- | Rewrite slices that select a single element (count 1 or stride 0) as
-- fixed indices at the slice base; every other index is returned unchanged.
normIndex ::
  (Eq num, IntegralExp num) =>
  DimIndex num ->
  DimIndex num
normIndex idx =
  case idx of
    DimSlice b n s | n == 1 || s == 0 -> DimFix b
    _ -> idx
-- | Slice an index function.
--
-- A slice made only of full unit slices returns the index function
-- unchanged.  Otherwise the slice is first folded into the outermost LMAD via
-- 'sliceOneLMAD'.  If that is not possible, the slice is applied to a fresh
-- 'iota' over the outermost LMAD's shape, and the resulting LMAD is pushed on
-- top of the chain.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun (lmad@(LMAD _ _) :| lmads) oshp cg) dim_slices
  | isIdentity = ixfun
  | otherwise =
      case sliceOneLMAD ixfun dim_slices of
        Just ixfun' -> ixfun'
        Nothing -> pushed
  where
    isIdentity = unSlice dim_slices == map (unitSlice 0) (shape ixfun)

    pushed =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"
-- | Apply a flat slice to an index function.
--
-- Fast path: the outermost LMAD has at least one dimension and the index
-- function has a contiguous permutation.  The flat dimensions then replace
-- that LMAD's outermost dimension directly:
--
--   * the offset grows by @new_offset@ times the outer stride;
--   * each flat dimension's stride is scaled by the outer stride;
--   * the permutation is reset to the identity.
--
-- Otherwise a new row-major LMAD over the outermost LMAD's shape is pushed on
-- top of the chain.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice ixfun@(IxFun (lmad :| lmads) oshp cg) flat@(FlatSlice new_offset is) =
  case lmad of
    LMAD offset (dim : dims)
      | hasContiguousPerm ixfun ->
          let s0 = ldStride dim
              -- A dimension whose total stride is exactly 1 is known to be
              -- increasing; otherwise its monotonicity is unknown.
              toDim (FlatDimIndex n s) =
                LMADDim (s0 * s) n 0 (if s0 * s == 1 then Inc else Unknown)
              lmad' =
                setLMADPermutation [0 ..] $
                  LMAD (offset + new_offset * s0) (map toDim is <> dims)
           in IxFun (lmad' :| lmads) oshp cg
    _ ->
      IxFun
        (LMAD (new_offset * base_stride) (new_dims <> tail_dims) :| lmad : lmads)
        oshp
        cg
  where
    -- Row-major layout of the outermost LMAD's shape without its outer
    -- dimension (only demanded on the fallback path).
    tail_shapes = tail $ lmadShape lmad
    base_stride = product tail_shapes
    tail_strides = tail $ scanr (*) 1 tail_shapes
    new_shapes = flatSliceDims flat
    new_strides = map (* base_stride) $ flatSliceStrides flat
    new_dims = zipWith4 LMADDim new_strides new_shapes [0 ..] (repeat Inc)
    tail_dims =
      zipWith4 LMADDim tail_strides tail_shapes [length new_shapes ..] (repeat Inc)
-- | Try to fold a reshape into the outermost LMAD instead of pushing a new
-- LMAD onto the chain.
--
-- Succeeds only when all of the following hold:
--
--   * no dimension of the outermost LMAD (taken in permuted order) has
--     stride 0;
--   * the permutation entries of those dimensions are exactly @[0, 1, ..]@;
--   * the index function has a contiguous permutation;
--   * the index function is contiguous (@cg@);
--   * its monotonicity is 'Inc' or 'Dec'.
--
-- The new outermost LMAD is then a row-major (or, for 'Dec', negated-stride)
-- iota over @newshape@ at the old offset, with the identity permutation.
-- Otherwise the result is 'Nothing'.
reshapeOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshapeOneLMAD ixfun@(IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let perm = lmadPermutation lmad
      mid_dims = take (length dims) $ permuteFwd perm dims
      mon = ixfunMonotonicity ixfun

  guard $
    all ((/= 0) . ldStride) mid_dims
      && consecutive 0 (map ldPerm mid_dims)
      && hasContiguousPerm ixfun
      && cg
      && (mon == Inc || mon == Dec)

  -- Under the conditions above every new dimension is a "support" dimension
  -- in order.  The former permutation/support/repeat bookkeeping always
  -- reduced to this, and its fallback branch (with a partial (!!)) was
  -- unreachable.
  let perm' = [0 .. length newshape - 1]
      lmad' = makeRotIota mon off newshape
  pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oldbase cg
  where
    -- Is the list exactly @[i, i + 1, ..]@ (vacuously true when empty)?
    consecutive i ps = and $ zipWith (==) ps [i ..]
-- | Reshape an index function to @new_shape@.
--
-- 'reshapeOneLMAD' is tried first.  When it fails, a row-major 'iota' LMAD
-- over @new_shape@ is pushed on top of the existing LMAD chain.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
reshape ixfun new_shape =
  case reshapeOneLMAD ixfun new_shape of
    Just ixfun' -> ixfun'
    Nothing -> pushIota ixfun
  where
    pushIota (IxFun (lmad0 :| lmad0s) oshp cg) =
      case iota new_shape of
        IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
        _ -> error "reshape: reached impossible case"
-- | Overwrite the sizes of the outermost LMAD's dimensions with @new_shape@.
--
-- Strides, permutation and monotonicity are left untouched.  Dimensions are
-- paired positionally with @new_shape@, so if the lengths differ the excess
-- entries of the longer list are dropped.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun (lmad :| lmads) oshp cg) new_shape =
  IxFun (resize lmad :| lmads) oshp cg
  where
    resize (LMAD offset dims) =
      LMAD offset [ld {ldShape = d} | (ld, d) <- zip dims new_shape]
-- | Number of dimensions of the outermost LMAD.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank (IxFun (lmad :| _) _ _) = length (lmadDims lmad)
-- | Try to rebase @ixfun@ onto @new_base@ by merging the innermost LMAD of
-- @ixfun@ (the last one in its chain) into the outermost LMAD of @new_base@,
-- instead of concatenating the two chains.
--
-- Returns 'Nothing' unless a set of conservative preconditions holds.  In
-- that case 'rebase' falls back to plain composition.
--
-- The result's chain is:
--
--   * the other LMADs of @ixfun@, in their original order;
--   * then the merged LMAD;
--   * then the remaining LMADs of @new_base@.
--
-- The result's base is @new_base@'s base, since that is what the chain now
-- indexes into.  This matches the fallback path of 'rebase'.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) shp_base cg_base)
  ixfun@(IxFun lmads shp cg) = do
    -- Previously this split used NE.reverse, which left the remaining LMADs
    -- reversed and scrambled chains of three or more LMADs.
    let lmad = NE.last lmads
        lmads' = NE.init lmads
        dims = lmadDims lmad
        perm = lmadPermutation lmad
        perm_base = lmadPermutation lmad_base

    guard $
      -- Core rebase condition.
      base ixfun == shape new_base
        -- Conservative safety conditions: ixfun is contiguous and every
        -- dimension has known monotonicity.
        && cg
        && all ((/= Unknown) . ldMon) dims
        -- At most one of the two may have a non-trivial permutation.
        && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
        -- Permutations must have equal length to be composed, unless ixfun's
        -- permutation is trivial (e.g. ixfun was sliced with fixed indices).
        && (length perm == length perm_base || hasContiguousPerm ixfun)
        -- Row-major ixfun: every dimension is full, except that the innermost
        -- one may be partial as long as its stride is 1.
        && and
          ( zipWith3
              (\sn ld inner -> sn == ldShape ld || (inner && ldStride ld == 1))
              shp
              dims
              (replicate (length dims - 1) False ++ [True])
          )

    -- Compose permutations, reverse strides and adjust the offset as needed.
    let perm_base' =
          if hasContiguousPerm ixfun
            then perm_base
            else map (perm !!) perm_base
        lmad_base' = setLMADPermutation perm_base' lmad_base
        dims_base = lmadDims lmad_base'
        n_fewer_dims = length dims_base - length dims
        (dims_base', offs_contrib) =
          unzip $
            zipWith
              ( \(LMADDim s1 n1 p1 _) (LMADDim _ _ _ m2) ->
                  let (s', off')
                        | m2 == Inc = (s1, 0)
                        | otherwise = (s1 * (-1), s1 * (n1 - 1))
                   in (LMADDim s' n1 (p1 - n_fewer_dims) Inc, off')
              )
              -- new_base may have more dimensions than ixfun (when ixfun's
              -- permutation is trivial); only its innermost ones are used.
              (drop n_fewer_dims dims_base)
              dims
        off_base = lmadOffset lmad_base' + sum offs_contrib
        lmad_base''
          | lmadOffset lmad == 0 = LMAD off_base dims_base'
          | otherwise =
              -- A non-full innermost dimension (stride 1) contributes its
              -- offset relative to the new base.
              setLMADShape
                (lmadShape lmad)
                ( LMAD
                    (off_base + ldStride (last dims_base) * lmadOffset lmad)
                    dims_base'
                )
        lmads'' = lmads' ++@ (lmad_base'' :| lmads_base)
    pure $ IxFun lmads'' shp_base (cg && cg_base)
-- | Rebase @ixfun@ on top of @new_base@.
--
-- 'rebaseNice' is tried first.  Otherwise the two LMAD chains are
-- concatenated (index-function composition).  Before concatenating,
-- @new_base@ is reshaped to @ixfun@'s base shape if its own shape differs.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just ixfun' -> ixfun'
    Nothing -> IxFun (lmads @++@ lmads_base') shp_base' (cg && cg_base)
  where
    (lmads_base', shp_base')
      | base ixfun == shape new_base = (lmads_base, shp_base)
      | otherwise =
          let IxFun lmads_reshaped shp_reshaped _ = reshape new_base shp
           in (lmads_reshaped, shp_reshaped)
-- | If the index function is a single LMAD with a contiguous permutation, is
-- contiguous and is increasing, return its offset scaled by @elem_size@.
-- Otherwise return 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ cg
      | hasContiguousPerm ixfun && cg && ixfunMonotonicity ixfun == Inc ->
          Just (lmadOffset lmad * elem_size)
    _ -> Nothing
-- | For a single-LMAD index function that would be linear if its permutation
-- were ignored, return:
--
--   * the linear offset (scaled by @elem_size@), computed with the
--     permutation reset to the identity;
--   * each permutation entry paired with the permuted base-shape entry.
--
-- Index functions with more than one LMAD, and non-linear cases, give
-- 'Nothing'.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  withPerm <$> linearWithOffset contig_ixfun elem_size
  where
    perm = lmadPermutation lmad
    contig_ixfun =
      IxFun (setLMADPermutation [0 .. length perm - 1] lmad :| []) oshp cg
    withPerm offset = (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing
-- | Is the index function linear with offset zero (element size 1)?
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0
-- | Apply a permutation: position @i@ of the result holds @elems !! (ps !! i)@.
-- Errors if an entry of @ps@ is out of range for @elems@.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]
-- | Undo a permutation: the element at position @i@ of @elems@ is placed at
-- position @ps !! i@.  When @ps@ is a permutation of @[0 .. n-1]@, this is
-- the inverse of 'permuteFwd' with the same @ps@.  The sort is stable, and
-- excess elements of the longer list are dropped.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  [e | (_, e) <- sortBy (\(a, _) (b, _) -> compare a b) (zip ps elems)]
-- | Offset contribution of index @i@ along a dimension with stride @s@.  A
-- zero stride gives a literal @0@ instead of the product @i * s@.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  num ->
  num ->
  num
flatOneDim s i = if s == 0 then 0 else i * s
-- | Build a row-major LMAD over shape @ns@ starting at offset @off@.
--
-- The stride of each dimension is the product of the sizes inside it.
-- Strides are negated when @mon@ is 'Dec'.  The permutation is the identity
-- and every dimension carries monotonicity @mon@.  Any @mon@ other than
-- 'Inc' or 'Dec' is an error.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  num ->
  [num] ->
  LMAD num
makeRotIota mon off ns
  | mon /= Inc && mon /= Dec = error "makeRotIota: requires Inc or Dec"
  | otherwise =
      LMAD off $ zipWith4 LMADDim strides ns [0 .. rk - 1] (replicate rk mon)
  where
    rk = length ns
    -- Keep this exact scanl form: for symbolic num the shape of the
    -- product expressions matters to later equality checks.
    row_major = reverse $ take rk $ scanl (*) 1 $ reverse ns
    strides
      | mon == Inc = row_major
      | otherwise = map (* (-1)) row_major
-- | Monotonicity shared by all LMADs of an index function.
--
-- An LMAD is classified as 'Inc' (resp. 'Dec') when each of its dimensions
-- either has stride 0 or carries that monotonicity.  An LMAD matching
-- neither is 'Unknown'.  An LMAD with no dimensions, or with only stride-0
-- dimensions, classifies as 'Inc'.  The result is the common classification
-- if every LMAD agrees, and 'Unknown' otherwise.
ixfunMonotonicity ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Monotonicity
ixfunMonotonicity (IxFun (lmad :| lmads) _ _)
  | all ((== mon0) . classify) lmads = mon0
  | otherwise = Unknown
  where
    mon0 = classify lmad

    classify ::
      (Eq num, IntegralExp num) =>
      LMAD num ->
      Monotonicity
    classify (LMAD _ dims)
      | all (agrees Inc) dims = Inc
      | all (agrees Dec) dims = Dec
      | otherwise = Unknown

    agrees ::
      (Eq num, IntegralExp num) =>
      Monotonicity ->
      LMADDim num ->
      Bool
    agrees mon (LMADDim s _ _ ldmon) = s == 0 || mon == ldmon
-- | Replace every element visited by the index function's 'Traversable'
-- instance with a fresh existential leaf of type @int64@, numbered
-- @Ext 0@, @Ext 1@, ... in traversal order.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const fresh) ixfun) 0
  where
    -- Emit the current counter as an existential and advance it by one.
    fresh = state $ \i -> (TPrimExp $ LeafExp (Ext i) int64, i + 1)
-- | A loose compatibility check between two index functions.
--
-- Holds when both have bases of the same length, the same number of LMADs,
-- pairwise LMADs with the same rank and the same permutation, and — since
-- @False <= True@ — @ixf2@ is contiguous whenever @ixf1@ is.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  and
    [ length (base ixf1) == length (base ixf2),
      NE.length lmads1 == NE.length lmads2,
      and (NE.zipWith sameLayout lmads1 lmads2),
      contiguous ixf1 <= contiguous ixf2
    ]
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2
    -- Same rank and same permutation; strides, shapes and offsets are ignored.
    sameLayout lmad1 lmad2 =
      let dims1 = lmadDims lmad1
          dims2 = lmadDims lmad2
       in length dims1 == length dims2 && map ldPerm dims1 == map ldPerm dims2
-- | Structural equivalence of two index functions.
--
-- Both must have the same number of LMADs, and each pair of LMADs must agree
-- on rank, permutation, offset and strides. Shapes, monotonicity, the base
-- shape and the contiguity flag are not compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLMAD lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    sameLMAD :: Eq n => LMAD n -> LMAD n -> Bool
    sameLMAD l1 l2 =
      and
        [ ((==) `on` (length . lmadDims)) l1 l2,
          ((==) `on` (map ldPerm . lmadDims)) l1 l2,
          ((==) `on` lmadOffset) l1 l2,
          ((==) `on` (map ldStride . lmadDims)) l1 l2
        ]
-- | Sum over all dimensions of @stride * (shape - 1)@, ignoring the offset.
--
-- NOTE(review): for non-negative strides this is the distance from the first
-- to the last element; with mixed-sign strides the terms can cancel, so it is
-- then not an extent — confirm callers only rely on it in the former case.
flatSpan :: LMAD (TPrimExp Int64 VName) -> TPrimExp Int64 VName
flatSpan (LMAD _ dims) = foldr ((+) . dimSpan) 0 dims
  where
    -- Same right-nested sum as an explicit @\d acc -> dimSpan d + acc@.
    dimSpan dim = ldStride dim * (ldShape dim - 1)
-- | Flatten an LMAD into a single dimension that is meant to cover every
-- element the original LMAD can touch.
--
-- * A zero-dimensional LMAD becomes one dimension of one element (stride 1,
--   permutation 0, 'Inc').
-- * A one-dimensional LMAD is returned unchanged.
-- * Otherwise the new stride is the 'gcd' of all strides, the new shape is
--   @flatSpan + 1@, the offset is kept and the monotonicity becomes
--   'Unknown'. Returns 'Nothing' when 'gcd' cannot determine a common stride.
conservativeFlatten :: LMAD (TPrimExp Int64 VName) -> Maybe (LMAD (TPrimExp Int64 VName))
conservativeFlatten (LMAD offset []) =
  pure $ LMAD offset [LMADDim 1 1 0 Inc]
conservativeFlatten l@(LMAD _ [_]) =
  pure l
conservativeFlatten l@(LMAD offset (dim : dims)) = do
  -- Seed with the first stride and fold over the *remaining* ones. Folding
  -- over all strides would first compute @gcd s0 s0 = abs s0@, and 'gcd'
  -- would then apply 'abs' to that accumulator again. Unless 'abs' simplifies
  -- @abs (abs s)@, two structurally equal symbolic strides would then no
  -- longer compare equal.
  strd <- foldM gcd (ldStride dim) $ map ldStride dims
  pure $ LMAD offset [LMADDim strd (flatSpan l + 1) 0 Unknown]
-- | Very conservative GCD of two symbolic integers.
--
-- Both arguments are first passed through 'abs'. A result is returned only
-- when it is immediately evident:
--
-- * the operands are structurally equal,
-- * either operand equals 1, or
-- * either operand equals 0.
--
-- Otherwise the result is 'Nothing'. No recursion or algebraic reasoning is
-- attempted.
gcd :: TPrimExp Int64 VName -> TPrimExp Int64 VName -> Maybe (TPrimExp Int64 VName)
gcd x y = gcd' (abs x) (abs y)
  where
    gcd' a b
      | a == b = Just a
      | a == 1 || b == 1 = Just 1
      -- gcd(a, 0) = a and gcd(0, b) = b: handle both orders so the result
      -- does not depend on argument order.
      | b == 0 = Just a
      | a == 0 = Just b
      | otherwise = Nothing
-- | Try to show that two LMADs access disjoint memory, using the given
-- less-than facts and non-negative names for symbolic comparisons.
--
-- For two one-dimensional LMADs, they are disjoint if:
--
-- * the 'gcd' of their strides is known and provably does not divide the
--   offset difference, or
-- * 'AlgSimplify.lessThanish' judges that the last element of either
--   dimension lies below the other's offset.
--
-- Any other LMADs are first run through 'conservativeFlatten'. The result is
-- 'False' if either flattening fails.
disjoint ::
  [(VName, PrimExp VName)] ->
  Names ->
  LMAD (TPrimExp Int64 VName) ->
  LMAD (TPrimExp Int64 VName) ->
  Bool
disjoint less_thans non_negatives (LMAD offset1 [dim1]) (LMAD offset2 [dim2]) =
  doesNotDivide (gcd (ldStride dim1) (ldStride dim2)) (offset1 - offset2)
    || endsBefore offset2 dim2 offset1
    || endsBefore offset1 dim1 offset2
  where
    -- Offset of the last element of @dim@ (starting at @off@), compared
    -- against @start@ with 'AlgSimplify.lessThanish'.
    endsBefore off dim start =
      AlgSimplify.lessThanish
        less_thans
        non_negatives
        (off + (ldShape dim - 1) * ldStride dim)
        start

    -- True only when @y mod x@ constant-folds to something provably non-zero.
    doesNotDivide :: Maybe (TPrimExp Int64 VName) -> TPrimExp Int64 VName -> Bool
    doesNotDivide (Just x) y =
      let remainder :: TPrimExp Int64 VName
          remainder =
            TPrimExp $ constFoldPrimExp $ untyped $ Futhark.Util.IntegralExp.mod y x
       in maybe False not $ primBool $ (0 :: TPrimExp Int64 VName) .==. remainder
    doesNotDivide Nothing _ = False
disjoint less_thans non_negatives lmad1 lmad2 =
  maybe False (uncurry (disjoint less_thans non_negatives)) $
    (,) <$> conservativeFlatten lmad1 <*> conservativeFlatten lmad2
-- | Interval-based disjointness check for two LMADs.
--
-- The @scope@ and @asserts@ arguments are accepted but ignored; 'selfOverlap'
-- is called with @()@ for both. Each LMAD is converted to intervals, the
-- interval pairs are sorted by stride complexity (descending), and the
-- offset difference is distributed: its positive terms go into the first
-- LMAD's intervals and its negated terms into the second's.
--
-- The result is 'True' when the distribution succeeds, neither side overlaps
-- itself, and at least one pair of corresponding intervals is reported as
-- not overlapping by 'intervalOverlap'.
disjoint2 ::
  scope ->
  asserts ->
  [(VName, PrimExp VName)] ->
  Names ->
  LMAD (TPrimExp Int64 VName) ->
  LMAD (TPrimExp Int64 VName) ->
  Bool
disjoint2 _ _ less_thans non_negatives lmad1 lmad2 =
  case (,)
    <$> distributeOffset pos_offset sorted1
    <*> distributeOffset (map AlgSimplify.negate neg_offset) sorted2 of
    Just (placed1, placed2) ->
      noSelfOverlap placed1
        && noSelfOverlap placed2
        && any
          (not . uncurry (intervalOverlap less_thans non_negatives))
          (zip placed1 placed2)
    Nothing -> False
  where
    (offset1, intervals1) = lmadToIntervals lmad1
    (offset2, intervals2) = lmadToIntervals lmad2
    (neg_offset, pos_offset) =
      partition AlgSimplify.negated $ offset1 `AlgSimplify.sub` offset2
    (sorted1, sorted2) =
      unzip $
        sortBy (flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride . fst)) $
          intervalPairs intervals1 intervals2
    -- The non-negative names as int64 leaf expressions, shared by both checks.
    nonNegExps = map (flip LeafExp $ IntType Int64) $ namesToList non_negatives
    noSelfOverlap = isNothing . selfOverlap () () less_thans nonNegExps
-- | Try to prove that two LMADs access disjoint memory.
--
-- The first four arguments form the context for symbolic comparisons and are
-- passed to 'selfOverlap' and 'intervalOverlap':
--
-- * the type scope,
-- * a list of asserted facts,
-- * @(name, bound)@ less-than facts, and
-- * expressions known to be non-negative.
--
-- Each LMAD is turned into intervals, sorted by stride complexity, and then
-- repeatedly joined and merged ('joinDims', 'mergeDims') until a fixed point.
-- The paired intervals are handed to a search of depth at most 4. At each
-- step the search distributes the offset and, when a side overlaps itself,
-- splits the offending dimension ('splitDim') or retries with an expanded
-- offset. 'False' means disjointness could not be shown.
disjoint3 ::
  M.Map VName Type ->
  [PrimExp VName] ->
  [(VName, PrimExp VName)] ->
  [PrimExp VName] ->
  LMAD (TPrimExp Int64 VName) ->
  LMAD (TPrimExp Int64 VName) ->
  Bool
disjoint3 scope asserts less_thans non_negatives lmad1 lmad2 =
  search 4 paired1 paired2 $ offset1 `AlgSimplify.sub` offset2
  where
    (offset1, intervals1) = lmadToIntervals lmad1
    (offset2, intervals2) = lmadToIntervals lmad2
    (paired1, paired2) = pairByStride (normalise intervals1) (normalise intervals2)

    -- Descending order of stride complexity.
    byStride :: Interval -> Interval -> Ordering
    byStride = flip AlgSimplify.compareComplexity `on` (AlgSimplify.simplify0 . untyped . stride)

    normalise :: [Interval] -> [Interval]
    normalise = fixPoint (mergeDims . joinDims) . sortBy byStride

    pairByStride :: [Interval] -> [Interval] -> ([Interval], [Interval])
    pairByStride xs ys = unzip $ sortBy (byStride `on` fst) $ intervalPairs xs ys

    search :: Int -> [Interval] -> [Interval] -> AlgSimplify.SofP -> Bool
    search 0 _ _ _ = False
    search fuel is10 is20 offset =
      case ( distributeOffset pos_offset is1,
             distributeOffset (map AlgSimplify.negate neg_offset) is2
           ) of
        (Just is1', Just is2') ->
          case ( selfOverlap scope asserts less_thans non_negatives is1',
                 selfOverlap scope asserts less_thans non_negatives is2'
               ) of
            (Nothing, Nothing) ->
              -- NOTE(review): this compares is1/is2 (before the offset was
              -- distributed), whereas disjoint2 compares the distributed
              -- intervals — confirm this is intended.
              maybe
                False
                (\nn -> not $ all (uncurry (intervalOverlap less_thans nn)) (zip is1 is2))
                $ namesFromList <$> mapM justLeafExp non_negatives
            (Just dim, _) ->
              all
                (\(new_offset, new_is1) -> recurse (joinDims new_is1) (joinDims is2') new_offset)
                (splitDim dim is1')
                || viaExpandedOffset is1
            (_, Just dim) ->
              all
                ( \(new_offset, new_is2) ->
                    recurse (joinDims is1') (joinDims new_is2) $
                      map AlgSimplify.negate new_offset
                )
                (splitDim dim is2')
                || viaExpandedOffset is2
        _ -> False
      where
        (is1, is2) = pairByStride is10 is20
        (neg_offset, pos_offset) = partition AlgSimplify.negated offset
        recurse = search (fuel - 1)
        -- Fallback: fold the offset into the given intervals and retry.
        viaExpandedOffset ivs =
          maybe False (recurse is1 is2) $
            AlgSimplify.simplifySofP' <$> expandOffset offset ivs
-- | Walk the list left to right and fold each pair of neighbouring intervals
-- that have the same stride and lower bound 0 into one. The combined interval
-- keeps the first one's fields, with @numElements@ set to the product of both
-- counts. Other intervals are kept in order.
joinDims :: [Interval] -> [Interval]
joinDims = go
  where
    go (x : y : rest)
      | stride x == stride y,
        lowerBound x == 0,
        lowerBound y == 0 =
          go (x {numElements = numElements x * numElements y} : rest)
      | otherwise = x : go (y : rest)
    go xs = xs
-- | Walk the list right to left and merge an interval @x@ into its left
-- neighbour @y@ when @stride x * numElements x == stride y@ and both have
-- lower bound 0. The merged interval keeps @x@'s fields, with @numElements@
-- set to the product of both counts. The result is in the original order.
mergeDims :: [Interval] -> [Interval]
mergeDims = reverse . go . reverse
  where
    go (x : y : rest)
      | stride x * numElements x == stride y,
        lowerBound x == 0,
        lowerBound y == 0 =
          go (x {numElements = numElements x * numElements y} : rest)
      | otherwise = x : go (y : rest)
    go xs = xs
-- | Split a self-overlapping dimension so that the overlap can be analysed
-- further.
--
-- Precondition: @overlapping_dim0@ occurs in @is@ and is followed by at least
-- one more interval. The interval right after it (@overlapping_dim@ below) is
-- the one being split. Violating the precondition is reported with 'error'.
--
-- There are two cases:
--
-- * The strides of both intervals and the span of @overlapping_dim@ each
--   simplify to a single product, @overlapping_dim@ has lower bound 0, and
--   both of these divisions are exact:
--
--     * @span / stride overlapping_dim0@
--     * @stride overlapping_dim0 / stride overlapping_dim@
--
--   Then @overlapping_dim0@ and @overlapping_dim@ are replaced by two intervals
--   with those element counts, and the offset is empty.
-- * Otherwise two alternatives are returned:
--
--     * one that drops @overlapping_dim@ and carries the offset of its last
--       element;
--     * one that keeps @overlapping_dim@ with one element fewer.
splitDim :: Interval -> [Interval] -> [(AlgSimplify.SofP, [Interval])]
splitDim overlapping_dim0 is
  | [st] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim0,
    [st1] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim,
    [spn] <- AlgSimplify.simplify0 $ untyped $ stride overlapping_dim * numElements overlapping_dim,
    lowerBound overlapping_dim == 0,
    Just big_dim_elems <- AlgSimplify.maybeDivide spn st,
    Just small_dim_elems <- AlgSimplify.maybeDivide st st1 =
      [ ( [],
          -- 'before' always ends with overlapping_dim0, so 'init' is safe and
          -- drops exactly that interval.
          init before
            <> [ Interval 0 (isInt64 $ AlgSimplify.prodToExp big_dim_elems) (stride overlapping_dim0),
                 Interval 0 (isInt64 $ AlgSimplify.prodToExp small_dim_elems) (stride overlapping_dim)
               ]
            <> after
        )
      ]
  | otherwise =
      let shrunk_dim = overlapping_dim {numElements = numElements overlapping_dim - 1}
          -- Offset of the last element of overlapping_dim.
          point_offset =
            AlgSimplify.simplify0 $
              untyped $
                (numElements overlapping_dim - 1 + lowerBound overlapping_dim)
                  * stride overlapping_dim
       in [ (point_offset, before <> after),
            ([], before <> [shrunk_dim] <> after)
          ]
  where
    -- Focus on the interval *after* overlapping_dim0: 'before' is everything
    -- up to and including overlapping_dim0.
    (before, overlapping_dim, after) =
      case elemIndex overlapping_dim0 is >>= (\i -> focusNth (i + 1) is) of
        Just parts -> parts
        Nothing ->
          error "splitDim: the overlapping dimension must occur in the list and be followed by another dimension"
-- | Convert an 'LMAD' into its simplified offset and a list of intervals,
-- one per dimension.
--
-- * The offset is simplified into a sum-of-products with
--   'AlgSimplify.simplify0'.
-- * The dimensions are first reordered with 'permuteInv' using the LMAD's
--   permutation.  Each dimension then becomes @Interval 0 shape stride@,
--   with shape and stride run through 'AlgSimplify.simplify''.
-- * An LMAD with no dimensions is represented by the single interval
--   @Interval 0 1 1@.
lmadToIntervals :: LMAD (TPrimExp Int64 VName) -> (AlgSimplify.SofP, [Interval])
lmadToIntervals lmad@(LMAD offset dims0) =
  (AlgSimplify.simplify0 $ untyped offset, intervals)
  where
    intervals
      -- A rank-0 LMAD still needs one interval; the permutation is not
      -- consulted in that case.
      | null dims0 = [Interval 0 1 1]
      | otherwise = map dimToInterval $ permuteInv (lmadPermutation lmad) dims0

    -- The permutation index and monotonicity of a dimension do not
    -- contribute to its interval.
    dimToInterval :: LMADDim (TPrimExp Int64 VName) -> Interval
    dimToInterval (LMADDim strd shp _ _) =
      Interval 0 (AlgSimplify.simplify' shp) (AlgSimplify.simplify' strd)
-- | Build a (dynamic) boolean expression stating that two 'LMADDim's are
-- equal.
--
-- Stride and shape are compared as expressions with '.==.'.  The
-- permutation index and the monotonicity are plain Haskell values, so they
-- are compared immediately and lifted with 'fromBool'.  The four facts are
-- conjoined with '.&&.'.
dynamicEqualsLMADDim :: Eq num => LMADDim (TPrimExp t num) -> LMADDim (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMADDim dim1 dim2 =
  sameStride .&&. sameShape .&&. samePerm .&&. sameMon
  where
    sameStride = ((.==.) `on` ldStride) dim1 dim2
    sameShape = ((.==.) `on` ldShape) dim1 dim2
    samePerm = fromBool $ ((==) `on` ldPerm) dim1 dim2
    sameMon = fromBool $ ((==) `on` ldMon) dim1 dim2
-- | Build a (dynamic) boolean expression stating that two 'LMAD's are
-- equal: their offsets are equal and their dimensions are pairwise equal
-- according to 'dynamicEqualsLMADDim'.
--
-- LMADs with a different number of dimensions are never equal.  This is
-- decided statically, because 'zip' would otherwise truncate to the
-- shorter dimension list and compare only a common prefix.
dynamicEqualsLMAD :: Eq num => LMAD (TPrimExp t num) -> LMAD (TPrimExp t num) -> TPrimExp Bool num
dynamicEqualsLMAD lmad1 lmad2
  | length dims1 /= length dims2 = fromBool False
  | otherwise =
      lmadOffset lmad1 .==. lmadOffset lmad2
        .&&. foldr
          ((.&&.) . uncurry dynamicEqualsLMADDim)
          true
          (zip dims1 dims2)
  where
    dims1 = lmadDims lmad1
    dims2 = lmadDims lmad2