{-# LANGUAGE TypeFamilies #-}
module Futhark.Analysis.HORep.SOAC
(
SOAC (..),
Futhark.ScremaForm (..),
inputs,
setInputs,
lambda,
setLambda,
typeOf,
width,
NotSOAC (..),
fromExp,
toExp,
toSOAC,
Input (..),
varInput,
inputTransforms,
identInput,
isVarInput,
isVarishInput,
addTransform,
addInitialTransforms,
inputArray,
inputRank,
inputType,
inputRowType,
transformRows,
transposeInput,
applyTransforms,
ArrayTransforms,
noTransforms,
nullTransforms,
(|>),
(<|),
viewf,
ViewF (..),
viewl,
ViewL (..),
ArrayTransform (..),
transformFromExp,
transformToExp,
soacToStream,
)
where
import Data.Foldable as Foldable
import Data.Maybe
import Data.Sequence qualified as Seq
import Futhark.Construct hiding (toExp)
import Futhark.IR hiding
( Iota,
Rearrange,
Replicate,
Reshape,
typeOf,
)
import Futhark.IR qualified as Futhark
import Futhark.IR.SOACS.SOAC
( HistOp (..),
ScremaForm (..),
scremaType,
)
import Futhark.IR.SOACS.SOAC qualified as Futhark
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util.Pretty (pretty)
import Futhark.Util.Pretty qualified as PP
-- | A single array transformation that can be attached to an 'Input'
-- without being materialised as a separate statement.  The 'Certs' of a
-- transform are the certificates the generated statement is bound under
-- (see 'applyTransform').
data ArrayTransform
  = -- | Permute the dimensions of the array by the given permutation.
    Rearrange Certs [Int]
  | -- | Reshape the whole array to the given shape.
    Reshape Certs ReshapeKind Shape
  | -- | Replace the outermost dimension by the given shape, keeping all
    -- inner dimensions.
    ReshapeOuter Certs ReshapeKind Shape
  | -- | Keep the outermost dimension and replace all inner dimensions by
    -- the given shape.
    ReshapeInner Certs ReshapeKind Shape
  | -- | Replicate the array, adding the given shape as outer dimensions.
    Replicate Certs Shape
  deriving (Show, Eq, Ord)
-- | Rename variables in the certificates and shapes of a transform.
-- Permutations contain no names and are left untouched.
instance Substitute ArrayTransform where
  substituteNames substs (Rearrange cs xs) =
    Rearrange (substituteNames substs cs) xs
  substituteNames substs (Reshape cs k ses) =
    Reshape (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (ReshapeOuter cs k ses) =
    ReshapeOuter (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (ReshapeInner cs k ses) =
    ReshapeInner (substituteNames substs cs) k (substituteNames substs ses)
  substituteNames substs (Replicate cs se) =
    Replicate (substituteNames substs cs) (substituteNames substs se)
-- | An ordered sequence of 'ArrayTransform's; the leftmost transform is
-- applied first (see 'applyTransforms').
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform)
  deriving (Eq, Ord, Show)
-- | Concatenate two transform lists.  The transforms of the right operand
-- are pushed one at a time onto the end of the left operand with '|>', so
-- transforms meeting at the seam may be fused.
instance Semigroup ArrayTransforms where
  ts1 <> ts2 = case viewf ts2 of
    t :< ts2' -> (ts1 |> t) <> ts2'
    EmptyF -> ts1
-- | The empty transform list is the identity.
instance Monoid ArrayTransforms where
  mempty = noTransforms
-- | Rename variables in every transform of the list.
instance Substitute ArrayTransforms where
  substituteNames substs (ArrayTransforms ts) =
    ArrayTransforms $ substituteNames substs <$> ts
-- | The empty transformation list.
noTransforms :: ArrayTransforms
noTransforms = ArrayTransforms Seq.empty
-- | Is this the empty transformation list?
nullTransforms :: ArrayTransforms -> Bool
nullTransforms (ArrayTransforms s) = Seq.null s
-- | Split a transformation list into its first transform and the rest,
-- or 'EmptyF' if the list is empty.
viewf :: ArrayTransforms -> ViewF
viewf (ArrayTransforms s) = case Seq.viewl s of
  t Seq.:< s' -> t :< ArrayTransforms s'
  Seq.EmptyL -> EmptyF
-- | A view of the first transform of an 'ArrayTransforms' (see 'viewf').
data ViewF
  = EmptyF
  | ArrayTransform :< ArrayTransforms
-- | Split a transformation list into everything but its last transform,
-- and the last transform, or 'EmptyL' if the list is empty.
viewl :: ArrayTransforms -> ViewL
viewl (ArrayTransforms s) = case Seq.viewr s of
  s' Seq.:> t -> ArrayTransforms s' :> t
  Seq.EmptyR -> EmptyL
-- | A view of the last transform of an 'ArrayTransforms' (see 'viewl').
data ViewL
  = EmptyL
  | ArrayTransforms :> ArrayTransform
-- | Append a transform to the end of the transformation list.  If it can
-- be fused with the current last transform ('combineTransforms'), the two
-- are merged; a merged transform that is an identity
-- ('identityTransform') is dropped.
(|>) :: ArrayTransforms -> ArrayTransform -> ArrayTransforms
ts |> t = addTransform' extract add swapPair t ts
  where
    -- The new transform is applied after the existing last one, so the
    -- pair (existing, new) is swapped into (later, earlier) order.
    swapPair :: (ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)
    swapPair = uncurry (flip (,))
    -- Split off the last transform, if any.
    extract :: ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)
    extract ts' = case viewl ts' of
      EmptyL -> Nothing
      ts'' :> t' -> Just (t', ts'')
    -- Append at the end of the sequence.
    add :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
    add t' (ArrayTransforms ts') = ArrayTransforms $ ts' Seq.|> t'
-- | Prepend a transform to the start of the transformation list.  If it
-- can be fused with the current first transform ('combineTransforms'),
-- the two are merged; a merged transform that is an identity
-- ('identityTransform') is dropped.
(<|) :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
(<|) = addTransform' extract add id
  where
    -- Split off the first transform, if any.  The pair (existing, new) is
    -- already in (later, earlier) order, hence 'id' above.
    extract :: ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)
    extract ts' = case viewf ts' of
      EmptyF -> Nothing
      t' :< ts'' -> Just (t', ts'')
    -- Insert at the front of the sequence.
    add :: ArrayTransform -> ArrayTransforms -> ArrayTransforms
    add t' (ArrayTransforms ts') = ArrayTransforms $ t' Seq.<| ts'
-- | Shared implementation of '|>' and '<|'.
--
-- @extract@ splits off the transform adjacent to the insertion point,
-- @add@ inserts a transform at that point, and @swap@ reorders the pair
-- @(existing, new)@ into the argument order expected by
-- 'combineTransforms'.  If the new transform combines with its neighbour,
-- the neighbour is removed and the combined transform is inserted
-- recursively, so chains of combinable transforms collapse; a combined
-- identity transform is dropped altogether.  Otherwise the new transform
-- is simply added.
addTransform' ::
  (ArrayTransforms -> Maybe (ArrayTransform, ArrayTransforms)) ->
  (ArrayTransform -> ArrayTransforms -> ArrayTransforms) ->
  ((ArrayTransform, ArrayTransform) -> (ArrayTransform, ArrayTransform)) ->
  ArrayTransform ->
  ArrayTransforms ->
  ArrayTransforms
addTransform' extract add swap t ts =
  fromMaybe (t `add` ts) $ do
    (t', ts') <- extract ts
    combined <- uncurry combineTransforms $ swap (t', t)
    Just $
      if identityTransform combined
        then ts'
        else addTransform' extract add swap combined ts'
-- | Does this transform leave the array unchanged?  Only permutations are
-- recognised: a 'Rearrange' is an identity when every listed position maps
-- to itself.  Every other transform is conservatively reported as 'False'.
identityTransform :: ArrayTransform -> Bool
identityTransform (Rearrange _ perm) =
  Foldable.and $ zipWith (==) perm [0 ..]
identityTransform _ = False
-- | Try to fuse two adjacent transforms into one.  The first argument is
-- the transform applied later, the second the one applied earlier.
-- Currently only two permutations can be fused; their certificates are
-- merged.  Returns 'Nothing' for any other pair.
combineTransforms :: ArrayTransform -> ArrayTransform -> Maybe ArrayTransform
combineTransforms (Rearrange cs2 perm2) (Rearrange cs1 perm1) =
  Just $ Rearrange (cs1 <> cs2) $ perm2 `rearrangeCompose` perm1
combineTransforms _ _ = Nothing
-- | If the expression is a rearrange, reshape, or replicate of a variable,
-- return that variable and the equivalent 'ArrayTransform', carrying the
-- given certificates.  Any other expression yields 'Nothing'.
transformFromExp :: Certs -> Exp rep -> Maybe (VName, ArrayTransform)
transformFromExp cs (BasicOp (Futhark.Rearrange perm v)) =
  Just (v, Rearrange cs perm)
transformFromExp cs (BasicOp (Futhark.Reshape k shape v)) =
  Just (v, Reshape cs k shape)
transformFromExp cs (BasicOp (Futhark.Replicate shape (Var v))) =
  Just (v, Replicate cs shape)
transformFromExp _ _ = Nothing
-- | Express an 'ArrayTransform' applied to the given array as an
-- expression, paired with the certificates it must be bound under.  The
-- array's type is looked up where the transform only describes part of the
-- result (partial permutations and outer/inner reshapes).
transformToExp :: (Monad m, HasScope rep m) => ArrayTransform -> VName -> m (Certs, Exp rep)
transformToExp (Replicate cs n) ia =
  pure (cs, BasicOp $ Futhark.Replicate n (Var ia))
transformToExp (Rearrange cs perm) ia = do
  r <- arrayRank <$> lookupType ia
  -- Extend the permutation with the identity on the trailing dimensions
  -- it does not mention, so that it covers the full rank of the array.
  pure (cs, BasicOp $ Futhark.Rearrange (perm ++ [length perm .. r - 1]) ia)
transformToExp (Reshape cs k shape) ia =
  pure (cs, BasicOp $ Futhark.Reshape k shape ia)
transformToExp (ReshapeOuter cs k shape) ia = do
  shape' <- reshapeOuter shape 1 . arrayShape <$> lookupType ia
  pure (cs, BasicOp $ Futhark.Reshape k shape' ia)
transformToExp (ReshapeInner cs k shape) ia = do
  shape' <- reshapeInner shape 1 . arrayShape <$> lookupType ia
  pure (cs, BasicOp $ Futhark.Reshape k shape' ia)
-- | One input to a SOAC: an array variable, its type before any
-- transformation, and the transformations to apply to it (the resulting
-- type is computed by 'inputType').
data Input = Input ArrayTransforms VName Type
  deriving (Show, Eq, Ord)
-- | Rename variables in the transforms, the array name, and the type.
instance Substitute Input where
  substituteNames substs (Input ts v t) =
    Input
      (substituteNames substs ts)
      (substituteNames substs v)
      (substituteNames substs t)
-- | Create an untransformed 'Input' from a variable, looking up its type in
-- the current scope.
varInput :: (HasScope t f) => VName -> f Input
varInput v = withType <$> lookupType v
  where
    withType = Input (ArrayTransforms Seq.empty) v
-- | Create an untransformed 'Input' from an 'Ident', using its name and
-- type.
identInput :: Ident -> Input
identInput v = Input (ArrayTransforms Seq.empty) (identName v) (identType v)
-- | If the input has no transformations, return its array variable.
isVarInput :: Input -> Maybe VName
isVarInput (Input ts v _)
  | nullTransforms ts = Just v
isVarInput _ = Nothing
-- | Like 'isVarInput', but additionally looks through leading
-- 'ReshapeCoerce' reshapes to a one-dimensional shape that carry no
-- certificates.
isVarishInput :: Input -> Maybe VName
isVarishInput (Input ts v t)
  | nullTransforms ts = Just v
  | Reshape cs ReshapeCoerce (Shape [_]) :< ts' <- viewf ts,
    cs == mempty =
      isVarishInput $ Input ts' v t
isVarishInput _ = Nothing
-- | Append a transformation to the end of the input's transformation list,
-- fusing it with the previous one where possible.
addTransform :: ArrayTransform -> Input -> Input
addTransform tr (Input trs a t) = Input (trs |> tr) a t
-- | Prepend a list of transformations, so they are applied before the
-- input's existing ones.
addInitialTransforms :: ArrayTransforms -> Input -> Input
addInitialTransforms ts (Input ots a t) = Input (ts <> ots) a t
-- | Emit a statement applying the transform to the array and return the
-- name of the result.  The statement is bound under the transform's
-- certificates, and its name is derived from the kind of transform.
applyTransform :: (MonadBuilder m) => ArrayTransform -> VName -> m VName
applyTransform tr ia = do
  (cs, e) <- transformToExp tr ia
  certifying cs $ letExp s e
  where
    s = case tr of
      Replicate {} -> "replicate"
      Rearrange {} -> "rearrange"
      Reshape {} -> "reshape"
      ReshapeOuter {} -> "reshape_outer"
      ReshapeInner {} -> "reshape_inner"
-- | Emit statements applying every transform in order, starting from the
-- given array, and return the name of the final result.
applyTransforms :: (MonadBuilder m) => ArrayTransforms -> VName -> m VName
applyTransforms (ArrayTransforms ts) a = foldlM (flip applyTransform) a ts
-- | Materialise each input by emitting its transformations, returning the
-- names of the resulting arrays in input order.
inputsToSubExps :: (MonadBuilder m) => [Input] -> m [VName]
inputsToSubExps = mapM f
  where
    f (Input ts a _) = applyTransforms ts a
-- | The underlying (untransformed) array variable of an input.
inputArray :: Input -> VName
inputArray (Input _ v _) = v
-- | The transformations attached to an input.
inputTransforms :: Input -> ArrayTransforms
inputTransforms (Input ts _ _) = ts
-- | The type of an input after applying all of its transformations, in
-- order, to the type of the underlying array.
inputType :: Input -> Type
inputType (Input (ArrayTransforms ts) _ at) =
  -- Strict left fold: avoids building a chain of thunks over the transforms.
  Foldable.foldl' transformType at ts
  where
    transformType :: Type -> ArrayTransform -> Type
    transformType t (Replicate _ shape) =
      arrayOfShape t shape
    transformType t (Rearrange _ perm) =
      rearrangeType perm t
    transformType t (Reshape _ _ shape) =
      t `setArrayShape` shape
    transformType t (ReshapeOuter _ _ shape) =
      -- Replace the outermost dimension, keep the rest.
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (shapeDims shape ++ drop 1 oldshape)
    transformType t (ReshapeInner _ _ shape) =
      -- Keep the outermost dimension, replace the rest.
      let Shape oldshape = arrayShape t
       in t `setArrayShape` Shape (take 1 oldshape ++ shapeDims shape)
-- | The 'rowType' of the transformed input type ('inputType').
inputRowType :: Input -> Type
inputRowType = rowType . inputType
-- | The 'arrayRank' of the transformed input type ('inputType').
inputRank :: Input -> Int
inputRank = arrayRank . inputType
-- | Apply the given transformations to every row of an input, rather than to
-- the input as a whole, by adding suitably adjusted transformations to it.
--
-- Only 'Rearrange', 'Reshape' and 'Replicate' are supported; any other
-- transformation is reported with 'error'.
transformRows :: ArrayTransforms -> Input -> Input
transformRows (ArrayTransforms ts) =
  flip (Foldable.foldl transformRows') ts
  where
    transformRows' inp (Rearrange cs perm) =
      -- Keep the outer dimension in place and shift the row permutation.
      addTransform (Rearrange cs (0 : map (+ 1) perm)) inp
    transformRows' inp (Reshape cs k shape) =
      addTransform (ReshapeInner cs k shape) inp
    transformRows' inp (Replicate cs n)
      | inputRank inp == 1 =
          -- [d0] --replicate--> [n.., d0] --rearrange--> [d0, n..]
          Rearrange mempty (repRank : [0 .. repRank - 1])
            `addTransform` (Replicate cs n `addTransform` inp)
      | otherwise =
          -- [d0, d1, rest..] --swap--> [d1, d0, rest..]
          --   --replicate--> [n.., d1, d0, rest..]
          --   --rearrange--> [d0, n.., d1, rest..]
          Rearrange
            mempty
            ((repRank + 1) : [0 .. repRank - 1] ++ repRank : [repRank + 2 .. repRank + inputRank inp - 1])
            `addTransform` ( Replicate cs n
                               `addTransform` (Rearrange mempty (1 : 0 : [2 .. inputRank inp - 1]) `addTransform` inp)
                           )
      where
        -- Number of dimensions the replication adds; the permutations above
        -- must account for all of them, not just one.
        repRank = length (shapeDims n)
    transformRows' inp nts =
      error $ "transformRows: Cannot transform this yet:\n" ++ show nts ++ "\n" ++ show inp
-- | Add a 'Rearrange' whose permutation is @transposeIndex k n@ applied to
-- the identity permutation over the input's dimensions.
transposeInput :: Int -> Int -> Input -> Input
transposeInput k n inp =
  addTransform (Rearrange mempty $ transposeIndex k n [0 .. inputRank inp - 1]) inp
-- | A SOAC whose array arguments are 'Input's, i.e. arrays that may carry
-- pending 'ArrayTransforms' (materialised by 'toSOAC').
data SOAC rep
  = -- | Width, lambda, accumulator neutral elements, and inputs.
    Stream SubExp (Lambda rep) [SubExp] [Input]
  | -- | Width, lambda, inputs, and destinations as (shape, Int, array) triples.
    Scatter SubExp (Lambda rep) [Input] [(Shape, Int, VName)]
  | -- | Width, form, and inputs.
    Screma SubExp (ScremaForm rep) [Input]
  | -- | Width, histogram operators, bucket function, and inputs.
    Hist SubExp [HistOp rep] (Lambda rep) [Input]
  deriving (Eq, Show)
-- | Renders an input as its array name wrapped in its transformations; the
-- first transformation in the sequence ends up innermost.
instance PP.Pretty Input where
  pretty (Input (ArrayTransforms ts) arr _) = foldl wrap (pretty arr) ts
    where
      wrap e (Rearrange cs perm) =
        call "rearrange" cs [PP.apply (map pretty perm), e]
      wrap e (Reshape cs ReshapeArbitrary shape) =
        call "reshape" cs [pretty shape, e]
      wrap e (ReshapeOuter cs ReshapeArbitrary shape) =
        call "reshape_outer" cs [pretty shape, e]
      wrap e (ReshapeInner cs ReshapeArbitrary shape) =
        call "reshape_inner" cs [pretty shape, e]
      wrap e (Reshape cs ReshapeCoerce shape) =
        call "coerce" cs [pretty shape, e]
      wrap e (ReshapeOuter cs ReshapeCoerce shape) =
        call "coerce_outer" cs [pretty shape, e]
      wrap e (ReshapeInner cs ReshapeCoerce shape) =
        call "coerce_inner" cs [pretty shape, e]
      wrap e (Replicate cs ne) =
        call "replicate" cs [pretty ne, e]

      -- Operator name, then certificates, then the applied arguments.
      call name cs args = name <> pretty cs <> PP.apply args
-- | Pretty-prints a SOAC by delegating to the printer for the corresponding
-- core SOAC, passing the 'Input's as the array arguments.
instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
  pretty (Screma w form arrs) = Futhark.ppScrema w arrs form
  pretty (Hist len ops bucket_fun imgs) = Futhark.ppHist len imgs ops bucket_fun
  pretty (Stream w lam nes arrs) = Futhark.ppStream w arrs nes lam
  pretty (Scatter w lam arrs dests) = Futhark.ppScatter w arrs lam dests
-- | The array inputs of a SOAC.
inputs :: SOAC rep -> [Input]
inputs (Stream _ _ _ arrs) = arrs
inputs (Scatter _ _ ivs _) = ivs
inputs (Screma _ _ arrs) = arrs
inputs (Hist _ _ _ inps) = inps
-- | Replace the array inputs of a SOAC.
--
-- 'Stream' and 'Scatter' recompute their width from the new inputs (see
-- 'newWidth'); 'Screma' and 'Hist' keep their existing width.
-- NOTE(review): confirm the asymmetry for Screma/Hist is intentional.
setInputs :: [Input] -> SOAC rep -> SOAC rep
setInputs arrs (Stream w lam nes _) =
  Stream (newWidth arrs w) lam nes arrs
setInputs arrs (Scatter w lam _ dests) =
  Scatter (newWidth arrs w) lam arrs dests
setInputs arrs (Screma w form _) =
  Screma w form arrs
setInputs arrs (Hist w ops lam _) =
  Hist w ops lam arrs
-- | The width implied by a list of inputs: the outermost size of the first
-- input's transformed type, or the given fallback when there are no inputs.
newWidth :: [Input] -> SubExp -> SubExp
newWidth [] w = w
newWidth (inp : _) _ = arraySize 0 $ inputType inp
-- | The main lambda of a SOAC (for 'Screma', the map lambda of its form).
lambda :: SOAC rep -> Lambda rep
lambda (Stream _ lam _ _) = lam
lambda (Scatter _ lam _ _) = lam
lambda (Screma _ (ScremaForm _ _ lam) _) = lam
lambda (Hist _ _ lam _) = lam
-- | Replace the lambda returned by 'lambda', leaving everything else intact.
setLambda :: Lambda rep -> SOAC rep -> SOAC rep
setLambda lam (Stream w _ nes arrs) =
  Stream w lam nes arrs
setLambda lam (Scatter len _ ivs dests) =
  Scatter len lam ivs dests
setLambda lam (Screma w (ScremaForm scan red _) arrs) =
  Screma w (ScremaForm scan red lam) arrs
setLambda lam (Hist w ops _ inps) =
  Hist w ops lam inps
-- | The result types of a SOAC.
typeOf :: SOAC rep -> [Type]
typeOf (Stream w lam nes _) =
  -- The first @length nes@ lambda results are accumulators; the rest are
  -- per-chunk arrays whose outer dimension becomes the full width @w@.
  accTypes ++ map (\t -> arrayOf (stripArray 1 t) (Shape [w]) NoUniqueness) chunkTypes
  where
    (accTypes, chunkTypes) = splitAt (length nes) (lambdaReturnType lam)
typeOf (Scatter _w lam _ivs dests) =
  zipWith arrayOfShape val_ts ws
  where
    (ws, ns, _) = unzip3 dests
    -- Skip the index results: for each destination, its Int times the rank
    -- of its shape; the remaining results are the written values.
    indexes = sum $ zipWith (*) ns $ map length ws
    val_ts = drop indexes $ lambdaReturnType lam
typeOf (Screma w form _) =
  scremaType w form
typeOf (Hist _ ops _ _) =
  [ t `arrayOfShape` histShape op
  | op <- ops,
    t <- lambdaReturnType (histOp op)
  ]
-- | The width (outer iteration count) stored in a SOAC.
width :: SOAC rep -> SubExp
width (Stream w _ _ _) = w
width (Scatter w _ _ _) = w
width (Screma w _ _) = w
width (Hist w _ _ _) = w
-- | Convert a SOAC to an expression, materialising its inputs (see 'toSOAC').
toExp ::
  (MonadBuilder m, Op (Rep m) ~ Futhark.SOAC (Rep m)) =>
  SOAC (Rep m) ->
  m (Exp (Rep m))
toExp soac = Op <$> toSOAC soac
-- | Convert to the core SOAC representation.
--
-- Each input's pending transformations are applied (via 'inputsToSubExps'),
-- which may add statements to the builder.
toSOAC :: (MonadBuilder m) => SOAC (Rep m) -> m (Futhark.SOAC (Rep m))
toSOAC (Stream w lam nes inps) =
  Futhark.Stream w <$> inputsToSubExps inps <*> pure nes <*> pure lam
toSOAC (Scatter w lam ivs dests) =
  Futhark.Scatter w <$> inputsToSubExps ivs <*> pure lam <*> pure dests
toSOAC (Screma w form arrs) =
  Futhark.Screma w <$> inputsToSubExps arrs <*> pure form
toSOAC (Hist w ops lam arrs) =
  Futhark.Hist w <$> inputsToSubExps arrs <*> pure ops <*> pure lam
-- | Why 'fromExp' could not convert an expression to a 'SOAC'.
data NotSOAC
  = -- | The expression is not one of the SOACs handled by 'fromExp'.
    NotSOAC
  deriving (Show)
-- | Convert an expression to the 'Input'-based SOAC representation.
--
-- 'Stream', 'Scatter', 'Screma' and 'Hist' are converted, wrapping each array
-- argument with 'varInput'; any other expression yields @Left NotSOAC@.
fromExp ::
  (Op rep ~ Futhark.SOAC rep, HasScope rep m) =>
  Exp rep ->
  m (Either NotSOAC (SOAC rep))
fromExp (Op (Futhark.Stream w arrs nes lam)) =
  Right . Stream w lam nes <$> traverse varInput arrs
fromExp (Op (Futhark.Scatter w ivs lam dests)) =
  Right <$> (Scatter w lam <$> traverse varInput ivs <*> pure dests)
fromExp (Op (Futhark.Screma w arrs form)) =
  Right . Screma w form <$> traverse varInput arrs
fromExp (Op (Futhark.Hist w arrs ops lam)) =
  Right . Hist w ops lam <$> traverse varInput arrs
fromExp _ = pure $ Left NotSOAC
soacToStream ::
( HasScope rep m,
MonadFreshNames m,
Buildable rep,
BuilderOps rep,
Op rep ~ Futhark.SOAC rep
) =>
SOAC rep ->
m (SOAC rep, [Ident])
soacToStream :: forall rep (m :: * -> *).
(HasScope rep m, MonadFreshNames m, Buildable rep, BuilderOps rep,
Op rep ~ SOAC rep) =>
SOAC rep -> m (SOAC rep, [Ident])
soacToStream SOAC rep
soac = do
Param Type
chunk_param <- forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"chunk" forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
let chvar :: SubExp
chvar = VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
chunk_param
(Lambda rep
lam, [Input]
inps) = (forall rep. SOAC rep -> Lambda rep
lambda SOAC rep
soac, forall rep. SOAC rep -> [Input]
inputs SOAC rep
soac)
w :: SubExp
w = forall rep. SOAC rep -> SubExp
width SOAC rep
soac
Lambda rep
lam' <- forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda rep
lam
let arrrtps :: [Type]
arrrtps = forall rep. SubExp -> Lambda rep -> [Type]
mapType SubExp
w Lambda rep
lam
loutps :: [Type]
loutps = [forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- forall a b. (a -> b) -> [a] -> [b]
map forall u. TypeBase Shape u -> TypeBase Shape u
rowType [Type]
arrrtps]
lintps :: [Type]
lintps = [forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
arrayOfRow Type
t SubExp
chvar | Type
t <- forall a b. (a -> b) -> [a] -> [b]
map Input -> Type
inputRowType [Input]
inps]
[Param Type]
strm_inpids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inp") [Type]
lintps
case SOAC rep
soac of
Screma SubExp
_ ScremaForm rep
form [Input]
_
| Just Lambda rep
_ <- forall rep. ScremaForm rep -> Maybe (Lambda rep)
Futhark.isMapSOAC ScremaForm rep
form -> do
[Ident]
strm_resids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps
let insoac :: SOAC rep
insoac =
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) forall a b. (a -> b) -> a -> b
$
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
lam'
insstm :: Stm rep
insstm = forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet [Ident]
strm_resids forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op SOAC rep
insoac
strmbdy :: Body rep
strmbdy = forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
strm_inpids
strmlam :: Lambda rep
strmlam = forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [Param Type]
strmpar Body rep
strmbdy [Type]
loutps
forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [] [Input]
inps, [])
| Just ([Scan rep]
scans, Lambda rep
_) <- forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
Futhark.isScanomapSOAC ScremaForm rep
form,
Futhark.Scan Lambda rep
scan_lam [SubExp]
nes <- forall rep. Buildable rep => [Scan rep] -> Scan rep
Futhark.singleScan [Scan rep]
scans -> do
let scan_arr_ts :: [Type]
scan_arr_ts = forall a b. (a -> b) -> [a] -> [b]
map (forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` SubExp
chvar) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
scan_lam
accrtps :: [Type]
accrtps = forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
scan_lam
[Param Type]
inpacc_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
Lambda rep
maplam <- forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
[SubExp] -> Lambda rep -> m (Lambda rep)
mkMapPlusAccLam (forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. Param dec -> VName
paramName) [Param Type]
inpacc_ids) Lambda rep
scan_lam
let strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
Lambda rep
strmlam <- forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. (a, b) -> a
fst forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (a, Stms rep)
runBuilder forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type]
strmpar forall a b. (a -> b) -> a -> b
$ do
([VName]
scan0_ids, [VName]
map_resids) <-
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [Type]
scan_arr_ts)) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar (forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids) forall a b. (a -> b) -> a -> b
$
forall rep. [Scan rep] -> Lambda rep -> ScremaForm rep
Futhark.scanomapSOAC [forall rep. Lambda rep -> [SubExp] -> Scan rep
Futhark.Scan Lambda rep
scan_lam [SubExp]
nes] Lambda rep
lam'
SubExp
outszm1id <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"outszm1" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
BinOp -> SubExp -> SubExp -> BasicOp
BinOp
(IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowUndef)
(VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param Type
chunk_param)
(forall v. IsValue v => v -> SubExp
constant (Int64
1 :: Int64))
VName
empty_arr <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"empty_arr" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp
(IntType -> CmpOp
CmpSlt IntType
Int64)
SubExp
outszm1id
(forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64))
let indexLast :: VName -> BuilderT rep (State VNameSource) (Exp rep)
indexLast VName
arr = do
Type
arr_t <- forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. BasicOp -> Exp rep
BasicOp forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
arr forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_t [forall d. d -> DimIndex d
DimFix SubExp
outszm1id]
[VName]
lastel_ids <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"lastel"
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
(forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
empty_arr)
(forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM [SubExp]
nes)
(forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> BuilderT rep (State VNameSource) (Exp rep)
indexLast [VName]
scan0_ids)
Body rep
addlelbdy <-
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
scan_lam forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [VName]
lastel_ids
let (Stms rep
addlelstm, Result
addlelres) = (forall rep. Body rep -> Stms rep
bodyStms Body rep
addlelbdy, forall rep. Body rep -> Result
bodyResult Body rep
addlelbdy)
[VName]
strm_resids <-
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"strm_res" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall rep. Op rep -> Exp rep
Op forall a b. (a -> b) -> a -> b
$
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma SubExp
chvar [VName]
scan0_ids (forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda rep
maplam)
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms rep
addlelstm
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ Result
addlelres forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName]
strm_resids forall a. [a] -> [a] -> [a]
++ [VName]
map_resids)
forall (f :: * -> *) a. Applicative f => a -> f a
pure
( forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [SubExp]
nes [Input]
inps,
forall a b. (a -> b) -> [a] -> [b]
map forall dec. Typed dec => Param dec -> Ident
paramIdent [Param Type]
inpacc_ids
)
| Just ([Reduce rep]
reds, Lambda rep
_) <- forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
Futhark.isRedomapSOAC ScremaForm rep
form,
Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes <- forall rep. Buildable rep => [Reduce rep] -> Reduce rep
Futhark.singleReduce [Reduce rep]
reds -> do
let accrtps :: [Type]
accrtps = forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda rep
lam
loutps' :: [Type]
loutps' = forall a. Int -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [Type]
loutps
foldlam :: Lambda rep
foldlam = Lambda rep
lam'
[Ident]
strm_resids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"res") [Type]
loutps'
[Param Type]
inpacc_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"inpacc") [Type]
accrtps
[Ident]
acc0_ids <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (m :: * -> *).
MonadFreshNames m =>
String -> Type -> m Ident
newIdent String
"acc0") [Type]
accrtps
let insoac :: SOAC rep
insoac =
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Futhark.Screma
SubExp
chvar
(forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
strm_inpids)
forall a b. (a -> b) -> a -> b
$ forall rep. [Reduce rep] -> Lambda rep -> ScremaForm rep
Futhark.redomapSOAC [forall rep. Commutativity -> Lambda rep -> [SubExp] -> Reduce rep
Futhark.Reduce Commutativity
comm Lambda rep
lamin [SubExp]
nes] Lambda rep
foldlam
insstm :: Stm rep
insstm = forall rep. Buildable rep => [Ident] -> Exp rep -> Stm rep
mkLet ([Ident]
acc0_ids forall a. [a] -> [a] -> [a]
++ [Ident]
strm_resids) forall a b. (a -> b) -> a -> b
$ forall rep. Op rep -> Exp rep
Op SOAC rep
insoac
Body rep
addaccbdy <-
forall (m :: * -> *) rep.
(MonadFreshNames m, Buildable rep) =>
Lambda rep -> [SubExp] -> m (Body rep)
mkPlusBnds Lambda rep
lamin forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var forall a b. (a -> b) -> a -> b
$
forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
acc0_ids
let (Stms rep
addaccstm, Result
addaccres) = (forall rep. Body rep -> Stms rep
bodyStms Body rep
addaccbdy, forall rep. Body rep -> Result
bodyResult Body rep
addaccbdy)
strmbdy :: Body rep
strmbdy =
forall rep. Buildable rep => Stms rep -> Result -> Body rep
mkBody (forall rep. Stm rep -> Stms rep
oneStm Stm rep
insstm forall a. Semigroup a => a -> a -> a
<> Stms rep
addaccstm) forall a b. (a -> b) -> a -> b
$
Result
addaccres forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> SubExpRes
subExpRes forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var forall b c a. (b -> c) -> (a -> b) -> a -> c
. Ident -> VName
identName) [Ident]
strm_resids
strmpar :: [Param Type]
strmpar = Param Type
chunk_param forall a. a -> [a] -> [a]
: [Param Type]
inpacc_ids forall a. [a] -> [a] -> [a]
++ [Param Type]
strm_inpids
strmlam :: Lambda rep
strmlam = forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [Param Type]
strmpar Body rep
strmbdy ([Type]
accrtps forall a. [a] -> [a] -> [a]
++ [Type]
loutps')
forall (f :: * -> *) a. Applicative f => a -> f a
pure (forall rep. SubExp -> Lambda rep -> [SubExp] -> [Input] -> SOAC rep
Stream SubExp
w Lambda rep
strmlam [SubExp]
nes [Input]
inps, [])
SOAC rep
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure (SOAC rep
soac, [])
where
-- | Partially apply the operator lambda @plus@ to the values @accs@.
--
-- The first @length accs@ parameters of @plus@ are removed from the
-- parameter list. Instead, each is bound at the top of the body by a
-- statement @param = acc@. If @accs@ is longer than the parameter list,
-- the surplus values are ignored (the pairing is done with 'zipWith').
-- The remaining parameters become the parameters of the new lambda, and
-- the return type is unchanged.
--
-- The resulting lambda is passed through 'renameLambda'. Presumably this
-- gives it fresh names so it can coexist with @plus@; confirm against
-- the 'renameLambda' documentation.
mkMapPlusAccLam ::
  (MonadFreshNames m, Buildable rep) =>
  [SubExp] ->
  Lambda rep ->
  m (Lambda rep)
mkMapPlusAccLam accs plus =
  renameLambda (Lambda kept_params specialised_body (lambdaReturnType plus))
  where
    -- Leading parameters become let-bound copies of the accumulators;
    -- the rest stay as genuine lambda parameters.
    (fixed_params, kept_params) = splitAt (length accs) (lambdaParams plus)

    bindFixed p acc = mkLet [paramIdent p] (BasicOp (SubExp acc))

    old_body = lambdaBody plus

    -- Same body decoration and result as before; only the statement list
    -- gains the parameter bindings in front.
    specialised_body =
      old_body
        { bodyStms =
            stmsFromList (zipWith bindFixed fixed_params accs)
              <> bodyStms old_body
        }
-- | Inline the operator lambda @plus@ applied to the arguments @accels@,
-- returning the instantiated body.
--
-- A copy of @plus@ is first obtained via 'renameLambda'. Each of its
-- parameters is then bound to the corresponding element of @accels@ by
-- a statement @param = arg@ prepended to the body's statements. The
-- pairing uses 'zip', so any surplus parameters or arguments are ignored.
-- The body's decoration and result are left as they were in the renamed
-- lambda.
mkPlusBnds ::
  (MonadFreshNames m, Buildable rep) =>
  Lambda rep ->
  [SubExp] ->
  m (Body rep)
mkPlusBnds plus accels = bindArgs <$> renameLambda plus
  where
    bindArgs fresh =
      let fresh_body = lambdaBody fresh
          arg_stms =
            [ mkLet [paramIdent p] (BasicOp (SubExp se))
            | (p, se) <- zip (lambdaParams fresh) accels
            ]
       in fresh_body {bodyStms = stmsFromList arg_stms <> bodyStms fresh_body}