module Futhark.Internalise.Lambdas
( InternaliseLambda,
internaliseFoldLambda,
internalisePartitionLambda,
)
where
import Futhark.IR.SOACS as I
import Futhark.Internalise.AccurateSizes
import Futhark.Internalise.Monad
import Language.Futhark as E
-- | Callback used to internalise a source-language lambda expression.
--
-- It receives the lambda and the internal types its parameters should
-- have. It yields the internalised parameters, the internalised body,
-- and the list of types returned by that body.
type InternaliseLambda =
  E.Exp ->
  [I.Type] ->
  InternaliseM ([I.LParam SOACS], I.Body SOACS, [I.Type])
-- | Internalise a source lambda whose leading parameters are
-- accumulators and whose remaining parameters are array elements.
--
-- The lambda is internalised with parameter types
-- @acctypes ++ map rowType arrtypes@. Each type returned by the body
-- is then given the array shape of the accumulator type at the same
-- position. The body's result is passed through 'ensureResultShape'
-- against those types. On mismatch it reports
-- "shape of result does not match shape of initial value", located at
-- the lambda's source position.
internaliseFoldLambda ::
  InternaliseLambda ->
  E.Exp ->
  [I.Type] ->
  [I.Type] ->
  InternaliseM (I.Lambda SOACS)
internaliseFoldLambda internaliseLambda lam acctypes arrtypes = do
  -- Array parameters are consumed element-wise, so the lambda sees
  -- their row types.
  let rowtypes = map I.rowType arrtypes
  (params, body, rettype) <- internaliseLambda lam $ acctypes ++ rowtypes
  -- Give each return type the shape of the matching accumulator type.
  -- NOTE(review): 'zip' silently truncates if 'rettype' and 'acctypes'
  -- differ in length.
  let rettype' =
        [ t `I.setArrayShape` I.arrayShape acctype
          | (t, acctype) <- zip rettype acctypes
        ]
  mkLambda params $
    ensureResultShape
      (ErrorMsg [ErrorString "shape of result does not match shape of initial value"])
      (srclocOf lam)
      rettype'
      =<< bodyBind body
-- | Internalise a source lambda used to classify elements into @k@
-- partition classes.
--
-- The lambda is internalised with the row types of @args@ as its
-- parameter types. The first value its body returns is treated as the
-- equivalence class @c@. The resulting lambda returns @k + 2@ 'int64'
-- values instead:
--
-- * the class index;
-- * a one-hot vector of length @k + 1@ with the 1 at that index.
--
-- The comparisons only test @c@ against @0 .. k-1@. Any other value of
-- @c@ is therefore mapped to the catch-all class @k@.
internalisePartitionLambda ::
  InternaliseLambda ->
  Int ->
  E.Exp ->
  [I.SubExp] ->
  InternaliseM (I.Lambda SOACS)
internalisePartitionLambda internaliseLambda k lam args = do
  argtypes <- mapM I.subExpType args
  let rowtypes = map I.rowType argtypes
  (params, body, _) <- internaliseLambda lam rowtypes
  body' <-
    localScope (scopeOfLParams params) $
      lambdaWithIncrement body
  pure $ I.Lambda params body' rettype
  where
    -- Class index followed by one flag per class (including class k).
    rettype :: [I.Type]
    rettype = replicate (k + 2) $ I.Prim int64

    -- Constant result for class i: the index i, then a one-hot vector
    -- of length k + 1 with its 1 at position i.
    result :: Int -> [I.SubExp]
    result i =
      map constant $
        fromIntegral i
          : (replicate i 0 ++ [1 :: Int64] ++ replicate (k - i) 0)

    -- Emit a chain of ifs testing eq_class == i, i+1, ..., k-1.
    -- Once i reaches k nothing has matched, so the class-k result is used.
    mkResult _ i | i >= k = pure $ result i
    mkResult eq_class i = do
      is_i <-
        letSubExp "is_i" $
          BasicOp $
            CmpOp (CmpEq int64) eq_class $
              intConst Int64 $
                toInteger i
      letTupExp' "part_res"
        =<< eIf
          (eSubExp is_i)
          (pure $ resultBody $ result i)
          (resultBody <$> mkResult eq_class (i + 1))

    -- Bind the user's body, take its first result as the class, and
    -- replace the body's result with the class-specific flag vector.
    lambdaWithIncrement :: I.Body SOACS -> InternaliseM (I.Body SOACS)
    lambdaWithIncrement lam_body = runBodyBuilder $ do
      res <- bodyBind lam_body
      eq_class <- case res of
        r : _ -> pure $ resSubExp r
        -- Avoid the partial 'head': fail with a message that names the
        -- source of the problem instead of "Prelude.head: empty list".
        [] -> error "internalisePartitionLambda: partition lambda body produced no result"
      resultBody <$> mkResult eq_class 0