{-# LANGUAGE TypeFamilies #-}
module Futhark.Construct
(
module Futhark.Builder,
letSubExp,
letExp,
letTupExp,
letTupExp',
letInPlace,
eSubExp,
eParam,
eMatch',
eMatch,
eIf,
eIf',
eBinOp,
eUnOp,
eCmpOp,
eConvOp,
eSignum,
eCopy,
eBody,
eLambda,
eBlank,
eAll,
eAny,
eDimInBounds,
eOutOfBounds,
eIndex,
eLast,
asIntZ,
asIntS,
resultBody,
resultBodyM,
insertStmsM,
buildBody,
buildBody_,
mapResult,
foldBinOp,
binOpLambda,
cmpOpLambda,
mkLambda,
sliceDim,
fullSlice,
fullSliceNum,
isFullSlice,
sliceAt,
instantiateShapes,
instantiateShapes',
removeExistentials,
simpleMkLetNames,
ToExp (..),
toSubExp,
)
where
import Control.Monad
import Control.Monad.Identity
import Control.Monad.State
import Data.List (foldl', sortOn, transpose)
import Data.Map.Strict qualified as M
import Futhark.Builder
import Futhark.IR
import Futhark.Util (maybeNth)
-- | Bind an expression to a fresh name derived from the description and
-- return it as a 'SubExp'.  An expression that is already a bare
-- 'SubExp' is returned directly, without introducing a new binding.
letSubExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m SubExp
letSubExp desc e =
  case e of
    BasicOp (SubExp se) -> pure se
    _ -> Var <$> letExp desc e
-- | Bind a single-valued expression to a fresh name derived from the
-- description and return that name.  A bare variable is returned as-is.
--
-- Calls 'error' if the expression's type does not have exactly one
-- component (the binding is created before this check is made).
letExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m VName
letExp _ (BasicOp (SubExp (Var v))) = pure v
letExp desc e = do
  ext_ts <- expExtType e
  names <- replicateM (length ext_ts) (newVName desc)
  letBindNames names e
  case names of
    [name] -> pure name
    _ -> error $ "letExp: tuple-typed expression given:\n" ++ prettyString e
-- | Bind the expression to a temporary (named @desc ++ "_tmp"@).  Then
-- write that temporary into @src@ at the given slice using an 'Unsafe'
-- 'Update', and bind the updated array to a name derived from @desc@.
letInPlace ::
  (MonadBuilder m) =>
  String ->
  VName ->
  Slice SubExp ->
  Exp (Rep m) ->
  m VName
letInPlace desc src slice e =
  letSubExp (desc ++ "_tmp") e
    >>= letExp desc . BasicOp . Update Unsafe src slice
-- | Bind an expression of any arity to fresh names, one per component of
-- its type, all derived from the given description.  A bare variable is
-- returned as a singleton list without introducing a new binding.
letTupExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [VName]
letTupExp name e =
  case e of
    BasicOp (SubExp (Var v)) -> pure [v]
    _ -> do
      arity <- length <$> expExtType e
      names <- replicateM arity $ newVName name
      letBindNames names e
      pure names
-- | Like 'letTupExp', but returns the bound names as 'SubExp's.  A bare
-- 'SubExp' expression is returned as a singleton list unchanged.
letTupExp' ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [SubExp]
letTupExp' name e =
  case e of
    BasicOp (SubExp se) -> pure [se]
    _ -> fmap Var <$> letTupExp name e
-- | Wrap a 'SubExp' as an expression.
eSubExp ::
  (MonadBuilder m) =>
  SubExp ->
  m (Exp (Rep m))
eSubExp se = pure $ BasicOp $ SubExp se
-- | An expression that refers to the variable named by a parameter.
eParam ::
  (MonadBuilder m) =>
  Param t ->
  m (Exp (Rep m))
eParam p = eSubExp $ Var $ paramName p
-- | Drop every scrutinee that no case inspects, i.e. whose pattern
-- column is 'Nothing' (a wildcard) in every case.  The matching entry
-- is also removed from each case's pattern.
--
-- The list of cases is always preserved, one output case per input
-- case.  This matters when /all/ columns are wildcards: rebuilding the
-- rows by transposing an empty column list would yield no rows and
-- silently discard every case.  The match would then fall through to
-- the default body instead of taking the first, always-matching case.
--
-- Assumes each case pattern has one entry per scrutinee.  If there are
-- no cases, no scrutinee is considered interesting and all are dropped.
removeRedundantScrutinees :: [SubExp] -> [Case b] -> ([SubExp], [Case b])
removeRedundantScrutinees ses cases =
  (keepInteresting ses, map prune cases)
  where
    -- One flag per scrutinee position: does any case constrain it?
    interesting :: [Bool]
    interesting = map (any (/= Nothing)) $ transpose $ map casePat cases

    keepInteresting :: [a] -> [a]
    keepInteresting xs = [x | (x, True) <- zip xs interesting]

    -- Filter the pattern row in place rather than transposing columns
    -- back into rows, so the case itself is never lost.
    prune (Case pat body) = Case (keepInteresting pat) body
-- | Construct a 'Match' expression from scrutinees, cases, a default
-- body and a 'MatchSort'.
--
-- The branch bodies are run via 'insertStmsM'.  Their result types are
-- combined with 'generaliseExtTypes' into a single return type.
--
-- Each branch gets extra leading results taken from 'shapeExtMapping'
-- of that type against the branch's concrete result types, in
-- ascending index order.  The 'MatchDec' type is prefixed with one
-- @i64@ per element of 'shapeContext'.
--
-- Scrutinees that no case inspects are removed.
eMatch' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eMatch' ses cases_m defbody_m sort = do
  cases <- traverse (traverse insertStmsM) cases_m
  defbody <- insertStmsM defbody_m
  def_ts <- bodyExtType defbody
  case_ts <- traverse (bodyExtType . caseBody) cases
  let ts = foldl' generaliseExtTypes def_ts case_ts
      ctx_ts = replicate (length $ shapeContext ts) (Prim int64)
  cases_ctx <- traverse (traverse $ withContext ts) cases
  defbody_ctx <- withContext ts defbody
  let (ses', cases') = removeRedundantScrutinees ses cases_ctx
  pure $ Match ses' cases' defbody_ctx $ MatchDec (ctx_ts ++ ts) sort
  where
    -- Prepend the values from 'shapeExtMapping' to the body result.
    -- 'M.elems' yields them in ascending key order.
    withContext ts (Body _ stms res) = do
      res_ts <- extendedScope (traverse subExpResType res) (scopeOf stms)
      let ctx = M.elems $ shapeExtMapping ts res_ts
      mkBodyM stms $ subExpsRes ctx ++ res
-- | 'eMatch'' with 'MatchNormal' as the match sort.
eMatch ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eMatch scrutinees branches fallback =
  eMatch' scrutinees branches fallback MatchNormal
-- | 'eIf'' with 'MatchNormal' as the match sort.
eIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eIf cond_m then_m else_m = eIf' cond_m then_m else_m MatchNormal
-- | Build a two-way conditional as a 'Match'.  The condition is bound to
-- a name derived from @"cond"@.  The first body is the single case,
-- matching @true@; the second body is the default.
eIf' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eIf' cond_m true_m false_m if_sort = do
  cond_e <- cond_m
  cond <- letSubExp "cond" cond_e
  let true_case = Case [Just $ BoolValue True] true_m
  eMatch' [cond] [true_case] false_m if_sort
-- | Compute the result type of a body as 'ExtType's.
--
-- Result types are looked up in a scope extended with the body's own
-- statements.  The names bound by those statements are passed to
-- 'existentialiseExtTypes'.
bodyExtType :: (HasScope rep m, Monad m) => Body rep -> m [ExtType]
bodyExtType (Body _ stms res) = do
  res_ts <- extendedScope (traverse subExpResType res) stms_scope
  pure $ existentialiseExtTypes (M.keys stms_scope) $ staticShapes res_ts
  where
    stms_scope = scopeOf stms
-- | Apply a binary operator.  Each operand is run and bound to a name
-- (@"x"@, then @"y"@), left operand first.
eBinOp ::
  (MonadBuilder m) =>
  BinOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eBinOp op xm ym = do
  xe <- xm
  xv <- letSubExp "x" xe
  ye <- ym
  yv <- letSubExp "y" ye
  pure . BasicOp $ BinOp op xv yv
-- | Apply a unary operator to an operand, which is bound to a name
-- derived from @"x"@ first.
eUnOp ::
  (MonadBuilder m) =>
  UnOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eUnOp op xm = do
  xv <- letSubExp "x" =<< xm
  pure $ BasicOp $ UnOp op xv
-- | Apply a comparison operator.  Each operand is run and bound to a
-- name (@"x"@, then @"y"@), left operand first.
eCmpOp ::
  (MonadBuilder m) =>
  CmpOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCmpOp op x y =
  BasicOp
    <$> ( CmpOp op
            <$> (letSubExp "x" =<< x)
            <*> (letSubExp "y" =<< y)
        )
-- | Apply a conversion operator to an operand, which is bound to a name
-- derived from @"x"@ first.
eConvOp ::
  (MonadBuilder m) =>
  ConvOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eConvOp op xm = BasicOp . ConvOp op <$> (letSubExp "x" =<< xm)
-- | Signed signum of an integer expression.  The operand is bound to a
-- name derived from @"signum_arg"@.  Calls 'error' if the operand's type
-- is not a primitive integer type.
eSignum ::
  (MonadBuilder m) =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eSignum em = do
  arg_e <- em
  arg <- letSubExp "signum_arg" arg_e
  arg_t <- subExpType arg
  case arg_t of
    Prim (IntType int_t) -> pure . BasicOp $ UnOp (SSignum int_t) arg
    _ ->
      error $
        "eSignum: operand " ++ prettyString arg_e ++ " has invalid type."
-- | Copy a value by replicating it with an empty shape.  The argument is
-- bound to a name derived from @"copy_arg"@ first.
eCopy ::
  (MonadBuilder m) =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCopy em = do
  arg <- letSubExp "copy_arg" =<< em
  pure $ BasicOp $ Replicate mempty arg
-- | Build a body whose result is the concatenation of the values of the
-- given expressions.
--
-- All expression actions are run first.  Only then is each expression
-- bound, with 'letTupExp' and the description @"x"@.
eBody ::
  (MonadBuilder m) =>
  [m (Exp (Rep m))] ->
  m (Body (Rep m))
eBody es = buildBody_ $ do
  exps <- sequence es
  varsRes . concat <$> traverse (letTupExp "x") exps
-- | Inline a lambda: bind each parameter name to the corresponding
-- argument expression, then bind the lambda body and return its result.
-- Parameters and arguments are paired positionally; surplus entries on
-- either side are ignored.
eLambda ::
  (MonadBuilder m) =>
  Lambda (Rep m) ->
  [m (Exp (Rep m))] ->
  m [SubExpRes]
eLambda lam args = do
  sequence_ $ zipWith bindArg (lambdaParams lam) args
  bodyBind (lambdaBody lam)
  where
    bindArg p m_e = m_e >>= letBindNames [paramName p]
-- | A boolean expression that is true when @0 <= i < w@, using signed
-- 64-bit comparisons.
--
-- The index action @i@ is run and bound once, and the resulting
-- 'SubExp' is shared by both comparisons.  Passing the action to each
-- comparison separately would run it twice and emit (and bind) its
-- statements twice.
eDimInBounds :: (MonadBuilder m) => m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eDimInBounds w i = do
  i' <- letSubExp "i" =<< i
  eBinOp
    LogAnd
    (eCmpOp (CmpSle Int64) (eSubExp (intConst Int64 0)) (eSubExp i'))
    (eCmpOp (CmpSlt Int64) (eSubExp i') w)
-- | A boolean expression that is true if any index is out of bounds for
-- the corresponding dimension of @arr@.  An index is out of bounds if it
-- is negative or not less than the dimension size.
--
-- Every index action is run and bound (as @"write_i"@).  Indices are
-- then paired positionally with the array's dimensions; surplus entries
-- on either side are not checked.
eOutOfBounds ::
  (MonadBuilder m) =>
  VName ->
  [m (Exp (Rep m))] ->
  m (Exp (Rep m))
eOutOfBounds arr is = do
  dims <- arrayDims <$> lookupType arr
  idxs <- mapM (letSubExp "write_i") =<< sequence is
  dim_checks <- zipWithM outOfBoundsDim dims idxs
  foldBinOp LogOr (constant False) dim_checks
  where
    -- True when index i is outside [0, w).
    outOfBoundsDim w i = do
      below <-
        letSubExp "less_than_zero" . BasicOp $
          CmpOp (CmpSlt Int64) i (constant (0 :: Int64))
      above <-
        letSubExp "greater_than_size" . BasicOp $
          CmpOp (CmpSle Int64) w i
      letSubExp "outside_bounds_dim" . BasicOp $
        BinOp LogOr below above
-- | Index an array with the given indices, each bound to a name derived
-- from @"i"@.  The slice is completed with 'fullSlice' for any remaining
-- dimensions.  With no indices, the array variable itself is returned.
eIndex :: (MonadBuilder m) => VName -> [m (Exp (Rep m))] -> m (Exp (Rep m))
eIndex arr idx_ms =
  case idx_ms of
    [] -> eSubExp (Var arr)
    _ -> do
      idxs <- traverse (\m -> m >>= letSubExp "i") idx_ms
      arr_t <- lookupType arr
      let slice = fullSlice arr_t [DimFix i | i <- idxs]
      pure $ BasicOp $ Index arr slice
-- | Index the last element along the outermost dimension of @arr@,
-- computing @n - 1@ with 'OverflowUndef' subtraction.  There is no check
-- here that the array is non-empty.
eLast :: (MonadBuilder m) => VName -> m (Exp (Rep m))
eLast arr = do
  arr_t <- lookupType arr
  let outer = arraySize 0 arr_t
      minus_one = BinOp (Sub Int64 OverflowUndef) outer (intConst Int64 1)
  last_i <- letSubExp "nm1" $ BasicOp minus_one
  eIndex arr [eSubExp last_i]
-- | A placeholder value of the given type.
--
-- * Primitive types give 'blankPrimValue'.
-- * Array types give a 'Scratch' array of that shape.
-- * Accumulator and memory types are rejected with 'error'.
eBlank :: (MonadBuilder m) => Type -> m (Exp (Rep m))
eBlank ty =
  case ty of
    Prim pt -> pure . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array pt shape _ -> pure . BasicOp . Scratch pt $ shapeDims shape
    Acc {} -> error "eBlank: cannot create blank accumulator"
    Mem {} -> error "eBlank: cannot create blank memory"
-- | Convert an integer 'SubExp' to the given integer type, using sign
-- extension ('SExt') when the types differ.
asIntS :: (MonadBuilder m) => IntType -> SubExp -> m SubExp
asIntS to_it se = asInt SExt to_it se
-- | Convert an integer 'SubExp' to the given integer type, using zero
-- extension ('ZExt') when the types differ.
asIntZ :: (MonadBuilder m) => IntType -> SubExp -> m SubExp
asIntZ to_it se = asInt ZExt to_it se
-- | Convert an integer 'SubExp' to @to_it@ with the given conversion
-- constructor.
--
-- * If the value already has type @to_it@, it is returned unchanged.
-- * Otherwise the conversion is bound to a new name.  The name is the
--   variable's base name, or @"to_" ++ prettyString to_it@ for
--   constants.
-- * Calls 'error' if the value is not of a primitive integer type.
asInt ::
  (MonadBuilder m) =>
  (IntType -> IntType -> ConvOp) ->
  IntType ->
  SubExp ->
  m SubExp
asInt ext to_it se = do
  se_t <- subExpType se
  case se_t of
    Prim (IntType from_it) ->
      if from_it == to_it
        then pure se
        else letSubExp desc . BasicOp $ ConvOp (ext from_it to_it) se
    _ -> error "asInt: wrong type"
  where
    desc = case se of
      Var v -> baseString v
      _ -> "to_" ++ prettyString to_it
-- | Right-fold a binary operator over a list of operands, with @ne@ as
-- the innermost right operand:
-- @foldBinOp op ne [e1, e2] = e1 `op` (e2 `op` ne)@.
-- With an empty list, the result is just @ne@.
foldBinOp ::
  (MonadBuilder m) =>
  BinOp ->
  SubExp ->
  [SubExp] ->
  m (Exp (Rep m))
foldBinOp bop ne = foldr step (eSubExp ne)
  where
    step e acc = eBinOp bop (eSubExp e) acc
-- | Logical conjunction of boolean 'SubExp's.  An empty list gives
-- @true@, and a single element is returned as-is.
eAll :: (MonadBuilder m) => [SubExp] -> m (Exp (Rep m))
eAll ses =
  case ses of
    [] -> eSubExp $ constant True
    [x] -> eSubExp x
    x : xs -> foldBinOp LogAnd x xs
-- | Logical disjunction of boolean 'SubExp's.  An empty list gives
-- @false@, and a single element is returned as-is.
eAny :: (MonadBuilder m) => [SubExp] -> m (Exp (Rep m))
eAny ses =
  case ses of
    [] -> eSubExp $ constant False
    [x] -> eSubExp x
    x : xs -> foldBinOp LogOr x xs
-- | A two-parameter lambda applying the given 'BinOp'.  Both parameters
-- and the result have the given primitive type.
binOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  BinOp ->
  PrimType ->
  m (Lambda (Rep m))
binOpLambda bop elem_t = binLambda (BinOp bop) elem_t elem_t
-- | A two-parameter lambda applying the given 'CmpOp'.  The parameters
-- have the operator's operand type ('cmpOpType'); the result is 'Bool'.
cmpOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  CmpOp ->
  m (Lambda (Rep m))
cmpOpLambda cop = binLambda (CmpOp cop) operand_t Bool
  where
    operand_t = cmpOpType cop
-- | Construct a lambda with two parameters and one result.
--
-- Fresh names are generated from @"x"@ and then @"y"@ (in that order).
-- Both parameters get type @Prim arg_t@.  The body binds the operator
-- applied to the two parameters under the description @"binlam_res"@
-- and returns that value.  The declared return type is @Prim ret_t@.
binLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  (SubExp -> SubExp -> BasicOp) ->
  PrimType ->
  PrimType ->
  m (Lambda (Rep m))
binLambda op argT retT = do
  lhsName <- newVName "x"
  rhsName <- newVName "y"
  let mkParam v = Param mempty v (Prim argT)
      combined = BasicOp $ op (Var lhsName) (Var rhsName)
  lamBody <- buildBody_ $ do
    res <- letSubExp "binlam_res" combined
    pure [subExpRes res]
  pure $ Lambda [mkParam lhsName, mkParam rhsName] [Prim retT] lamBody
-- | Build a lambda from explicit parameters and an action that produces
-- its result.
--
-- The action runs with the parameters added to the local scope.  The
-- statements it emits become the lambda body.  The lambda's return types
-- are the types of the produced results, as given by 'subExpResType'.
mkLambda ::
  (MonadBuilder m) =>
  [LParam (Rep m)] ->
  m Result ->
  m (Lambda (Rep m))
mkLambda params m = do
  (lamBody, retTypes) <-
    buildBody $ localScope (scopeOfLParams params) resultsWithTypes
  pure
    Lambda
      { lambdaParams = params,
        lambdaReturnType = retTypes,
        lambdaBody = lamBody
      }
  where
    resultsWithTypes = do
      res <- m
      types <- traverse subExpResType res
      pure (res, types)
-- | A 'DimSlice' covering a whole dimension of size @d@: @DimSlice 0 d 1@.
-- This assumes DimSlice's fields are start, count and stride.
sliceDim :: SubExp -> DimIndex SubExp
sliceDim d = DimSlice zero d one
  where
    zero = constant (0 :: Int64)
    one = constant (1 :: Int64)
-- | Extend a partial slice of an array of type @t@ to cover every
-- dimension.  The given indices are used for the leading dimensions.
-- Each remaining dimension of @t@ is taken whole via 'sliceDim'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice = Slice (slice ++ trailing)
  where
    trailing = map sliceDim . drop (length slice) $ arrayDims t
-- | Build a slice of an array of type @t@ that takes the first @n@
-- dimensions whole, then applies the given indices.  Any dimensions
-- still uncovered are filled in whole by 'fullSlice'.
sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt t n slice = fullSlice t (leading ++ slice)
  where
    leading = [sliceDim d | d <- take n (arrayDims t)]
-- | Numeric analogue of 'fullSlice' that works from a list of dimension
-- sizes.  After the given indices, each remaining dimension @d@ gets
-- @DimSlice 0 d 1@.
fullSliceNum :: (Num d) => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  Slice $ slice ++ [DimSlice 0 d 1 | d <- drop (length slice) dims]
-- | Check whether a slice spans the full size of an array of the given
-- shape.  Dimensions and indices are paired positionally:
--
-- * A 'DimFix' is accepted only on a constant dimension for which
--   'oneIsh' holds.
-- * A 'DimSlice' is accepted when its second field equals the
--   dimension.  Its first and third fields are not inspected.
-- * Anything else is rejected.
--
-- Because pairing stops at the shorter list, dimensions without a
-- matching index are not checked.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape (Slice idxs) =
  all (uncurry covers) (zip (shapeDims shape) idxs)
  where
    covers :: SubExp -> DimIndex SubExp -> Bool
    covers (Constant v) DimFix {} = oneIsh v
    covers d (DimSlice _ n _) = d == n
    covers _ _ = False
-- | A body with no statements whose result is the given 'SubExp's.
resultBody :: (Buildable rep) => [SubExp] -> Body rep
resultBody ses = mkBody mempty (subExpsRes ses)
-- | Monadic 'resultBody': builds, via 'mkBodyM', a body with no statements
-- whose result is the given 'SubExp's.
resultBodyM :: (MonadBuilder m) => [SubExp] -> m (Body (Rep m))
resultBodyM ses = mkBodyM mempty (subExpsRes ses)
-- | Run an action that produces a body, capturing the statements the
-- action emits along the way.  The returned body lists those captured
-- statements first, followed by the produced body's own statements, and
-- keeps that body's result.  The body decoration is rebuilt by
-- 'mkBodyM'; the original one is discarded.
insertStmsM ::
  (MonadBuilder m) =>
  m (Body (Rep m)) ->
  m (Body (Rep m))
insertStmsM m =
  collectStms m >>= \(Body _ innerStms innerRes, emitted) ->
    mkBodyM (emitted <> innerStms) innerRes
-- | Run an action and collect the statements it emits into a body.  The
-- first component of the action's result becomes the body's result.
-- The second component is returned alongside the body unchanged.
buildBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildBody m = do
  ((res, extra), stms) <- collectStms m
  (,extra) <$> mkBodyM stms res
-- | 'buildBody' for actions that return only a 'Result'.  The emitted
-- statements form the body, and the returned 'Result' is its result.
buildBody_ ::
  (MonadBuilder m) =>
  m Result ->
  m (Body (Rep m))
buildBody_ m = do
  (body, _) <- buildBody $ do
    res <- m
    pure (res, ())
  pure body
-- | Replace a body's result by applying @f@ to it.  The statements of the
-- body returned by @f@ are appended after the original statements, and
-- its result becomes the new result.  Both body decorations are dropped;
-- 'mkBody' constructs a fresh one.
mapResult ::
  (Buildable rep) =>
  (Result -> Body rep) ->
  Body rep ->
  Body rep
mapResult f (Body _ stms res) = mkBody (stms <> extraStms) finalRes
  where
    Body _ extraStms finalRes = f res
-- | Replace every existential dimension ('Ext') in the given types with a
-- concrete 'SubExp' obtained from @f@.  'Free' dimensions are kept as
-- they are.
--
-- The results of @f@ are memoised by existential index across the whole
-- list.  @f@ is therefore called at most once per distinct index, and
-- every occurrence of that index gets the same 'SubExp'.
instantiateShapes ::
  (Monad m) =>
  (Int -> m SubExp) ->
  [TypeBase ExtShape u] ->
  m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (traverse instType ts) M.empty
  where
    instType t =
      setArrayShape t . Shape <$> traverse instDim (shapeDims (arrayShape t))
    instDim (Free se) = pure se
    instDim (Ext i) = gets (M.lookup i) >>= maybe (newDim i) pure
    -- First occurrence of an existential: ask f, then remember the answer.
    newDim i = do
      se <- lift (f i)
      modify (M.insert i se)
      pure se
-- | Pure variant of 'instantiateShapes': existential index @i@ becomes
-- @Var@ of the @i@-th name in @names@.  If @i@ has no corresponding
-- name, this calls 'error' with a message listing the names and the
-- index.
instantiateShapes' :: [VName] -> [TypeBase ExtShape u] -> [TypeBase Shape u]
instantiateShapes' names = runIdentity . instantiateShapes lookupName
  where
    lookupName i = maybe (missing i) (pure . Var) (maybeNth i names)
    missing i =
      error $ "instantiateShapes': " ++ prettyString names ++ ", " ++ show i
-- | Give @t1@ concrete dimensions.  Positionally, each existential
-- dimension of @t1@ is replaced by the corresponding dimension of @t2@,
-- and free dimensions of @t1@ are kept.  The dimension lists are paired
-- with 'zipWith', so the result has as many dimensions as the shorter
-- of the two.
removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 = setArrayDims t1 dims
  where
    dims = zipWith pick (shapeDims (arrayShape t1)) (arrayDims t2)
    pick extDim concrete =
      case extDim of
        Free d -> d
        Ext _ -> concrete
-- | Build a 'Let' statement binding @names@ to @e@.  The pattern types
-- are the extended types of @e@ (from 'expExtType').  Their existential
-- dimensions are instantiated positionally with the same @names@ via
-- 'instantiateShapes''.  The statement carries a unit decoration
-- wrapped by 'defAux'.
simpleMkLetNames ::
  ( ExpDec rep ~ (),
    LetDec rep ~ Type,
    MonadFreshNames m,
    TypedOp (Op rep),
    HasScope rep m
  ) =>
  [VName] ->
  Exp rep ->
  m (Stm rep)
simpleMkLetNames names e = mkStm <$> expExtType e
  where
    mkStm ets =
      let pat = Pat (zipWith PatElem names (instantiateShapes' names ets))
       in Let pat (defAux ()) e
-- | Values that can be converted into an expression inside any
-- 'MonadBuilder'.
class ToExp a where
  -- | Produce an expression for the given value.
  toExp :: forall m. (MonadBuilder m) => a -> m (Exp (Rep m))
-- | A 'SubExp' becomes a trivial 'SubExp' basic operation.
instance ToExp SubExp where
  toExp se = pure (BasicOp (SubExp se))
-- | A variable name is wrapped as a 'Var' and converted using the
-- 'SubExp' instance.
instance ToExp VName where
  toExp v = toExp (Var v)
-- | Convert a value to an expression with 'toExp', then bind it with
-- 'letSubExp' under the given description.
toSubExp :: (MonadBuilder m, ToExp a) => String -> a -> m SubExp
toSubExp desc x = do
  ex <- toExp x
  letSubExp desc ex