{-# LANGUAGE TypeFamilies #-}

-- | = Constructing Futhark ASTs
--
-- This module re-exports and defines a bunch of building blocks for
-- constructing fragments of Futhark ASTs.  More importantly, it also
-- contains a basic introduction on how to use them.
--
-- The "Futhark.IR.Syntax" module contains the core
-- AST definition.  One important invariant is that all bound names in
-- a Futhark program must be /globally/ unique.  In principle, you
-- could use the facilities from "Futhark.MonadFreshNames" (or your
-- own bespoke source of unique names) to manually construct
-- expressions, statements, and entire ASTs.  In practice, this would
-- be very tedious.  Instead, we have defined a collection of building
-- blocks (centered around the 'MonadBuilder' type class) that permits
-- a more abstract way of generating code.
--
-- Constructing ASTs with these building blocks requires you to ensure
-- that all free variables are in scope.  See
-- "Futhark.IR.Prop.Scope".
--
-- == 'MonadBuilder'
--
-- A monad that implements 'MonadBuilder' tracks the statements added
-- so far, the current names in scope, and allows you to add
-- additional statements with 'addStm'.  Any monad that implements
-- 'MonadBuilder' also implements the t'Rep' type family, which
-- indicates which rep it works with.  Inside a 'MonadBuilder' we can
-- use 'collectStms' to gather up the 'Stms' added with 'addStm' in
-- some nested computation.
--
-- The 'BuilderT' monad (and its convenient 'Builder' version) provides
-- the simplest implementation of 'MonadBuilder'.
--
-- == Higher-level building blocks
--
-- On top of the raw facilities provided by 'MonadBuilder', we have
-- more convenient facilities.  For example, 'letExp' lets us
-- conveniently create a 'Stm' for an 'Exp' that produces a /single/
-- value, and returns the (fresh) name for the resulting variable
-- ('letSubExp' does the same, but returns a t'SubExp'):
--
-- @
-- z <- letExp "z" $ BasicOp $ BinOp (Add Int32) (Var x) (Var y)
-- @
--
-- == Monadic expression builders
--
-- This module also contains "monadic expression" functions that let
-- us build nested expressions in a "direct" style, rather than using
-- 'letExp' and friends to bind every sub-part first.  See functions
-- such as 'eIf' and 'eBody' for example.  See also
-- "Futhark.Analysis.PrimExp" and the 'ToExp' type class.
--
-- == Examples
--
-- The "Futhark.Transform.FirstOrderTransform" module is a
-- (relatively) simple example of how to use these components, as are
-- some of the high-level building blocks in this very module.
module Futhark.Construct
  ( -- * Basic building blocks
    module Futhark.Builder,
    letSubExp,
    letExp,
    letTupExp,
    letTupExp',
    letInPlace,

    -- * Monadic expression builders
    eSubExp,
    eParam,
    eMatch',
    eMatch,
    eIf,
    eIf',
    eBinOp,
    eUnOp,
    eCmpOp,
    eConvOp,
    eSignum,
    eCopy,
    eBody,
    eLambda,
    eBlank,
    eAll,
    eAny,
    eDimInBounds,
    eOutOfBounds,

    -- * Other building blocks
    asIntZ,
    asIntS,
    resultBody,
    resultBodyM,
    insertStmsM,
    buildBody,
    buildBody_,
    mapResult,
    foldBinOp,
    binOpLambda,
    cmpOpLambda,
    mkLambda,
    sliceDim,
    fullSlice,
    fullSliceNum,
    isFullSlice,
    sliceAt,

    -- * Result types
    instantiateShapes,
    instantiateShapes',
    removeExistentials,

    -- * Convenience
    simpleMkLetNames,
    ToExp (..),
    toSubExp,
  )
where

import Control.Monad.Identity
import Control.Monad.State
import Data.List (foldl', sortOn, transpose)
import Data.Map.Strict qualified as M
import Futhark.Builder
import Futhark.IR
import Futhark.Util (maybeNth)

-- | @letSubExp desc e@ binds the expression @e@, which must produce a
-- single value, and returns a t'SubExp' for that value.  An expression
-- that is already a plain 'SubExp' is returned as-is, without emitting
-- a binding.  For expressions that produce several values, see
-- 'letTupExp'.
letSubExp ::
  MonadBuilder m =>
  String ->
  Exp (Rep m) ->
  m SubExp
letSubExp desc e =
  case e of
    BasicOp (SubExp se) -> pure se
    _ -> Var <$> letExp desc e

-- | Like 'letSubExp', but returns a name rather than a t'SubExp'.
--
-- A plain variable expression is returned directly.  Otherwise the
-- expression is bound to fresh names derived from @desc@, one per value
-- in its type.  If that is not exactly one value, 'error' is raised
-- after the binding has been emitted.
letExp ::
  MonadBuilder m =>
  String ->
  Exp (Rep m) ->
  m VName
letExp desc e =
  case e of
    BasicOp (SubExp (Var v)) -> pure v
    _ -> do
      arity <- length <$> expExtType e
      names <- replicateM arity (newVName desc)
      letBindNames names e
      case names of
        [name] -> pure name
        _ ->
          error $
            "letExp: tuple-typed expression given:\n" ++ prettyString e

-- | Like 'letExp', but the 'VName' and 'Slice' denote an array that
-- is 'Update'd with the result of the expression.
--
-- The expression is first bound under the name hint @desc ++ "_tmp"@.
-- An 'Unsafe' 'Update' of @src@ at @slice@ is then bound under @desc@,
-- and the name of the updated array is returned.
letInPlace ::
  MonadBuilder m =>
  String ->
  VName ->
  Slice SubExp ->
  Exp (Rep m) ->
  m VName
letInPlace desc src slice e = do
  new_value <- letSubExp (desc ++ "_tmp") e
  let update = Update Unsafe src slice new_value
  letExp desc (BasicOp update)

-- | Like 'letExp', but the expression may return multiple values.
--
-- A plain variable expression yields a singleton list directly.
-- Otherwise, one fresh name (with hint @name@) is created for each
-- value in the expression's type, and all of them are bound together.
letTupExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [VName]
letTupExp name e =
  case e of
    BasicOp (SubExp (Var v)) -> pure [v]
    _ -> do
      result_ts <- expExtType e
      vs <- mapM (const (newVName name)) result_ts
      letBindNames vs e
      pure vs

-- | Like 'letTupExp', but returns t'SubExp's instead of 'VName's.  An
-- expression that is already a plain 'SubExp' yields a singleton list
-- without emitting a binding.
letTupExp' ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [SubExp]
letTupExp' name e =
  case e of
    BasicOp (SubExp se) -> pure [se]
    _ -> fmap Var <$> letTupExp name e

-- | Lift a subexpression into a monadic expression.  No code is
-- generated; this exists so that plain values can be passed to the
-- other monadic expression functions, such as 'eIf'.
eSubExp ::
  MonadBuilder m =>
  SubExp ->
  m (Exp (Rep m))
eSubExp se = pure (BasicOp (SubExp se))

-- | Treat a parameter as a monadic expression that refers to the
-- parameter's name.
eParam ::
  MonadBuilder m =>
  Param t ->
  m (Exp (Rep m))
eParam p = eSubExp $ Var $ paramName p

-- | Drop every scrutinee that no case inspects (its pattern is
-- 'Nothing' in all cases), together with the corresponding column of
-- every case pattern.
--
-- The cases themselves are always kept, in order, with their bodies.
-- In particular, when every scrutinee is redundant, each case is left
-- with an empty pattern, which matches unconditionally.  The first case
-- is therefore still selected, rather than falling through to the
-- default branch.
removeRedundantScrutinees :: [SubExp] -> [Case b] -> ([SubExp], [Case b])
removeRedundantScrutinees ses cases =
  (keep used ses, map prune cases)
  where
    -- One flag per scrutinee position: is it inspected by some case?
    used = map (any (/= Nothing)) $ transpose $ map casePat cases
    prune (Case pat body) = Case (keep used pat) body
    -- Keep the elements whose corresponding flag is True.  This is kept
    -- free of local variables so that it stays polymorphic and can be
    -- used at both [SubExp] and [Maybe PrimValue].
    keep flags xs = [x | (x, True) <- zip xs flags]

-- | Like 'eMatch', but the 'MatchSort' recorded in the resulting
-- 'Match' is supplied by the caller instead of being 'MatchNormal'.
eMatch' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eMatch' ses cases_m defbody_m sort = do
  -- Build the case bodies first, then the default body; each builder
  -- runs under 'insertStmsM'.
  cases <- mapM (traverse insertStmsM) cases_m
  defbody <- insertStmsM defbody_m
  -- Generalise the default body's type against every case body's type.
  def_ts <- bodyExtType defbody
  case_ts <- mapM (bodyExtType . caseBody) cases
  let ts = foldl' generaliseExtTypes def_ts case_ts
  -- Make every branch also return the sizes demanded by the shape
  -- context of the generalised type.
  cases' <- mapM (traverse (withShapeContext ts)) cases
  defbody' <- withShapeContext ts defbody
  let ctx_ts = replicate (length (shapeContext ts)) (Prim int64)
      (ses', cases'') = removeRedundantScrutinees ses cases'
  pure . Match ses' cases'' defbody' $ MatchDec (ctx_ts ++ ts) sort
  where
    -- Prefix the body's result with the concrete sizes that
    -- 'shapeExtMapping' associates with the given branch type, ordered
    -- by the map's 'Int' keys.
    withShapeContext ts (Body _ stms val_res) = do
      body_ts <- extendedScope (traverse subExpResType val_res) (scopeOf stms)
      let ctx_res =
            map snd . sortOn fst . M.toList $ shapeExtMapping ts body_ts
      mkBodyM stms $ subExpsRes ctx_res ++ val_res

-- | Construct a 'Match' expression of sort 'MatchNormal'.  The main
-- convenience is that the existential context of the return type is
-- deduced automatically, and the necessary elements are added to the
-- branches.  See 'eMatch'' for choosing a different 'MatchSort'.
eMatch ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  [SubExp] ->
  [Case (m (Body (Rep m)))] ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eMatch scrutinees branches fallback =
  eMatch' scrutinees branches fallback MatchNormal

-- | Construct a 'Match' modelling an if-expression from a monadic
-- condition and monadic branches, using 'MatchNormal' as the sort.
-- 'eBody' might be convenient for constructing the branches.
eIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eIf cond_m then_m else_m = eIf' cond_m then_m else_m MatchNormal

-- | As 'eIf', but a 'MatchSort' can be given.
--
-- The condition is bound under the name hint @"cond"@ and used as the
-- single scrutinee.  The true branch becomes the one case, matching
-- 'True', and the false branch becomes the default.
eIf' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  MatchSort ->
  m (Exp (Rep m))
eIf' cond_m true_m false_m if_sort = do
  cond_e <- cond_m
  cond <- letSubExp "cond" cond_e
  let true_case = Case [Just (BoolValue True)] true_m
  eMatch' [cond] [true_case] false_m if_sort

-- The type of a body: the types of its results, computed with the
-- body's own statements in scope.  Any dimension that refers to a name
-- bound by those statements is made existential.  Watch out: this only
-- works for the degenerate case where the body does not already return
-- its context.
bodyExtType :: (HasScope rep m, Monad m) => Body rep -> m [ExtType]
bodyExtType (Body _ stms res) = do
  res_ts <- extendedScope (traverse subExpResType res) stms_scope
  pure $ existentialiseExtTypes (M.keys stms_scope) (staticShapes res_ts)
  where
    stms_scope = scopeOf stms

-- | Construct a v'BinOp' expression with the given operator.  The
-- operands are built and bound, left before right, under the name
-- hints @"x"@ and @"y"@.
eBinOp ::
  MonadBuilder m =>
  BinOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eBinOp op mx my = do
  x_e <- mx
  x <- letSubExp "x" x_e
  y_e <- my
  y <- letSubExp "y" y_e
  pure . BasicOp $ BinOp op x y

-- | Construct a v'UnOp' expression with the given operator.  The
-- operand is bound under the name hint @"x"@.
eUnOp ::
  MonadBuilder m =>
  UnOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eUnOp op mx = do
  operand <- letSubExp "x" =<< mx
  pure $ BasicOp $ UnOp op operand

-- | Construct a v'CmpOp' expression with the given comparison.  The
-- operands are built and bound, left before right, under the name
-- hints @"x"@ and @"y"@.
eCmpOp ::
  MonadBuilder m =>
  CmpOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCmpOp op mx my = do
  x_e <- mx
  x <- letSubExp "x" x_e
  y_e <- my
  y <- letSubExp "y" y_e
  pure . BasicOp $ CmpOp op x y

-- | Construct a v'ConvOp' expression with the given conversion.  The
-- operand is bound under the name hint @"x"@.
eConvOp ::
  MonadBuilder m =>
  ConvOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eConvOp op mx = do
  operand_e <- mx
  operand <- letSubExp "x" operand_e
  pure . BasicOp $ ConvOp op operand

-- | Construct a 'SSignum' expression.  The operand is bound under the
-- name hint @"signum_arg"@.  Calls 'error' if it is not of integer
-- type.
eSignum ::
  MonadBuilder m =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eSignum em = do
  arg_e <- em
  arg <- letSubExp "signum_arg" arg_e
  arg_t <- subExpType arg
  case arg_t of
    Prim (IntType int_t) ->
      pure . BasicOp $ UnOp (SSignum int_t) arg
    _ ->
      error $
        "eSignum: operand " ++ prettyString arg_e ++ " has invalid type."

-- | Construct a 'Copy' expression.  The operand is bound to a name with
-- hint @"copy_arg"@, so it must produce a single value.
eCopy ::
  MonadBuilder m =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCopy e = do
  arg <- letExp "copy_arg" =<< e
  pure $ BasicOp $ Copy arg

-- | Construct a body from expressions.  All expression builders are
-- run first, in order, and only then is each expression bound (name
-- hint @"x"@).  The results of all expressions are concatenated in
-- order to form the body's result.
--
-- /Beware/: this will not produce correct code if the type of the
-- body would be existential.  That is, the type of the results being
-- returned should be invariant to the body.
eBody ::
  (MonadBuilder m) =>
  [m (Exp (Rep m))] ->
  m (Body (Rep m))
eBody es = buildBody_ $ do
  exps <- sequence es
  varsRes . concat <$> traverse (letTupExp "x") exps

-- | Bind each lambda parameter to the result of an expression, then
-- bind the body of the lambda and return its result.  The expressions
-- must produce only a single value each.  Parameters and arguments are
-- paired positionally, and any surplus on either side is ignored.
eLambda ::
  MonadBuilder m =>
  Lambda (Rep m) ->
  [m (Exp (Rep m))] ->
  m [SubExpRes]
eLambda lam args = do
  zipWithM_
    (\p arg -> letBindNames [paramName p] =<< arg)
    (lambdaParams lam)
    args
  bodyBind (lambdaBody lam)

-- | @eDimInBounds w i@ produces @0 <= i < w@, as the 'LogAnd' of two
-- signed 'Int64' comparisons.
--
-- The index builder @i@ is run and bound exactly once, so any
-- statements it emits are not duplicated between the two comparisons.
-- It runs before @w@.
eDimInBounds :: MonadBuilder m => m (Exp (Rep m)) -> m (Exp (Rep m)) -> m (Exp (Rep m))
eDimInBounds w i = do
  i' <- letSubExp "i" =<< i
  let idx = eSubExp i'
  eBinOp
    LogAnd
    (eCmpOp (CmpSle Int64) (eSubExp (intConst Int64 0)) idx)
    (eCmpOp (CmpSlt Int64) idx w)

-- | Are these indexes out-of-bounds for the array?
--
-- Each index is bound under the name hint @"write_i"@ and compared
-- (as 'Int64') against the matching dimension of @arr@.  The result is
-- a boolean expression that is true if any index is negative or not
-- less than its dimension.  Indices and dimensions are paired
-- positionally, so any surplus on either side is ignored.
eOutOfBounds ::
  MonadBuilder m =>
  VName ->
  [m (Exp (Rep m))] ->
  m (Exp (Rep m))
eOutOfBounds arr is = do
  dims <- arrayDims <$> lookupType arr
  idx_exps <- sequence is
  idxs <- mapM (letSubExp "write_i") idx_exps
  per_dim <- zipWithM outOfBoundsDim dims idxs
  foldBinOp LogOr (constant False) per_dim
  where
    -- Boolean: is index @i@ outside @[0, w)@?
    outOfBoundsDim w i = do
      below <-
        letSubExp "less_than_zero" . BasicOp $
          CmpOp (CmpSlt Int64) i (constant (0 :: Int64))
      above <-
        letSubExp "greater_than_size" . BasicOp $
          CmpOp (CmpSle Int64) w i
      letSubExp "outside_bounds_dim" . BasicOp $
        BinOp LogOr below above

-- | Construct an unspecified value of the given type.
--
-- * A primitive type gives its 'blankPrimValue' constant.
-- * An array type gives a 'Scratch' array with the same element type
--   and dimensions.
-- * Accumulator and memory types are rejected with 'error'.
eBlank :: MonadBuilder m => Type -> m (Exp (Rep m))
eBlank t =
  case t of
    Prim pt -> pure . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array pt shape _ -> pure . BasicOp . Scratch pt $ shapeDims shape
    Acc {} -> error "eBlank: cannot create blank accumulator"
    Mem {} -> error "eBlank: cannot create blank memory"

-- | Sign-extend (via 'SExt') an integer 'SubExp' to the given integer
-- type.
asIntS :: MonadBuilder m => IntType -> SubExp -> m SubExp
asIntS to_it se = asInt SExt to_it se

-- | Zero-extend (via 'ZExt') an integer 'SubExp' to the given integer
-- type.
asIntZ :: MonadBuilder m => IntType -> SubExp -> m SubExp
asIntZ to_it se = asInt ZExt to_it se

-- | Convert an integer-typed 'SubExp' to the integer type @to_it@,
-- using the supplied conversion constructor (e.g. 'SExt' or 'ZExt').
--
-- If the value already has type @to_it@, it is returned unchanged and
-- no statement is added.  Otherwise a 'ConvOp' statement is bound and
-- its result returned.  Calls 'error' if the value is not of an
-- integer type.
asInt ::
  MonadBuilder m =>
  (IntType -> IntType -> ConvOp) ->
  IntType ->
  SubExp ->
  m SubExp
asInt ext to_it e = do
  e_t <- subExpType e
  case e_t of
    Prim (IntType from_it)
      | from_it == to_it -> pure e
      | otherwise ->
          letSubExp desc . BasicOp . ConvOp (ext from_it to_it) $ e
    _ -> error "asInt: wrong type"
  where
    -- Name hint for the new binding: reuse a variable's base name, or
    -- describe the target type when converting a constant.
    desc :: String
    desc = case e of
      Var v -> baseString v
      _ -> "to_" ++ prettyString to_it

-- | Apply a binary operator to several subexpressions.
--
-- This is a /right/ fold with @ne@ as the innermost right operand:
-- @foldBinOp op ne [e1, e2, e3]@ builds @e1 op (e2 op (e3 op ne))@.
-- Each step combines its operands with 'eBinOp'.  With no
-- subexpressions, the result is simply @ne@.  The grouping matters
-- for operators that are not associative and commutative.
foldBinOp ::
  MonadBuilder m =>
  BinOp ->
  SubExp ->
  [SubExp] ->
  m (Exp (Rep m))
foldBinOp bop ne =
  foldr (eBinOp bop . eSubExp) (eSubExp ne)

-- | True if all operands are true.
--
-- An empty list yields the constant 'True', and a single operand is
-- used as-is.  Longer lists are combined with 'LogAnd' via
-- 'foldBinOp'.
eAll :: MonadBuilder m => [SubExp] -> m (Exp (Rep m))
eAll ses =
  case ses of
    [] -> pure . BasicOp . SubExp $ constant True
    [se] -> eSubExp se
    se : rest -> foldBinOp LogAnd se rest

-- | True if any operand is true.
--
-- An empty list yields the constant 'False', and a single operand is
-- used as-is.  Longer lists are combined with 'LogOr' via
-- 'foldBinOp'.
eAny :: MonadBuilder m => [SubExp] -> m (Exp (Rep m))
eAny ses =
  case ses of
    [] -> pure . BasicOp . SubExp $ constant False
    [se] -> eSubExp se
    se : rest -> foldBinOp LogOr se rest

-- | Create a two-parameter lambda whose body applies the given binary
-- operation to its arguments.
--
-- Both parameters and the single result are given the primitive type
-- @t@, so this is only correct for operators whose operand and result
-- types coincide.  (This assumption should be fixed at some point.)
binOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  BinOp ->
  PrimType ->
  m (Lambda (Rep m))
binOpLambda bop t = binLambda apply t t
  where
    apply = BinOp bop

-- | Like 'binOpLambda', but for a t'CmpOp'.
--
-- Both parameters have the comparison's operand type (as given by
-- 'cmpOpType'), and the lambda returns a 'Bool'.
cmpOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  CmpOp ->
  m (Lambda (Rep m))
cmpOpLambda cop = binLambda compare' operandType Bool
  where
    compare' = CmpOp cop
    operandType = cmpOpType cop

-- | Build a lambda that takes two parameters of primitive type @arg_t@
-- and returns a single value of primitive type @ret_t@.
--
-- The body binds (as @binlam_res@) the 'BasicOp' obtained by applying
-- @bop@ to the two parameters, and returns it.
--
-- The body is built with both parameters in scope, just as 'mkLambda'
-- does.  Without this, representations or operators whose type
-- computation looks up its operands would fail to find them.
binLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  (SubExp -> SubExp -> BasicOp) ->
  PrimType ->
  PrimType ->
  m (Lambda (Rep m))
binLambda bop arg_t ret_t = do
  x <- newVName "x"
  y <- newVName "y"
  let params :: [Param Type]
      params = [Param mempty x (Prim arg_t), Param mempty y (Prim arg_t)]
  body <-
    buildBody_ . localScope (scopeOfLParams params) . fmap (pure . subExpRes) $
      letSubExp "binlam_res" $
        BasicOp $
          bop (Var x) (Var y)
  pure
    Lambda
      { lambdaParams = params,
        lambdaReturnType = [Prim ret_t],
        lambdaBody = body
      }

-- | Easily construct a t'Lambda' within a 'MonadBuilder'.
--
-- The action runs with the given parameters in scope.  The statements
-- it adds become the lambda body, and the lambda's return types are
-- computed from the types of the results it returns.  See also
-- 'runLambdaBuilder'.
mkLambda ::
  MonadBuilder m =>
  [LParam (Rep m)] ->
  m Result ->
  m (Lambda (Rep m))
mkLambda params m = do
  (body, ret) <-
    buildBody . localScope (scopeOfLParams params) $ do
      res <- m
      (,) res <$> mapM subExpResType res
  pure
    Lambda
      { lambdaParams = params,
        lambdaBody = body,
        lambdaReturnType = ret
      }

-- | Slice a full dimension of the given size: start at index 0, take
-- @d@ elements, and use stride 1.
sliceDim :: SubExp -> DimIndex SubExp
sliceDim d = DimSlice start d stride
  where
    start = constant (0 :: Int64)
    stride = constant (1 :: Int64)

-- | @fullSlice t slice@ returns @slice@, extended with a 'sliceDim' for
-- every dimension of @t@ that @slice@ does not already cover.  This
-- turns incomplete indexing into complete indexing, as required by
-- 'Index'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice = Slice (slice ++ remaining)
  where
    remaining = map sliceDim . drop (length slice) $ arrayDims t

-- | @sliceAt t n slice@ returns @slice@ with 'DimSlice's of the outer
-- @n@ dimensions of @t@ prepended.  As many are appended as needed to
-- make it a full slice.  This is a generalisation of 'fullSlice'.
sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt t n slice = fullSlice t (outer ++ slice)
  where
    outer = map sliceDim . take n $ arrayDims t

-- | Like 'fullSlice', but the dimensions are given directly as numbers.
-- Every dimension not covered by @slice@ is sliced from 0, spanning the
-- whole dimension, with stride 1.
fullSliceNum :: Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  Slice $ slice ++ [DimSlice 0 d 1 | d <- drop (length slice) dims]

-- | Does the slice describe the full size of the array?
--
-- Dimensions and slice entries are paired positionally.  A pair
-- counts as full if it is a 'DimSlice' whose count equals the
-- dimension, or a 'DimFix' of a constant dimension that is one.
-- Any other pairing does not count as full.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape slice =
  all (uncurry coversDim) $ zip (shapeDims shape) (unSlice slice)
  where
    coversDim :: SubExp -> DimIndex SubExp -> Bool
    coversDim d idx =
      case idx of
        DimFix {}
          | Constant v <- d -> oneIsh v
        DimSlice _ n _ -> d == n
        _ -> False

-- | Build a body with no statements whose result is exactly the given
-- subexpressions.
resultBody :: Buildable rep => [SubExp] -> Body rep
resultBody ses = mkBody mempty (subExpsRes ses)

-- | Monadic counterpart of building a statement-free body: the result
-- is exactly the given subexpressions, constructed with 'mkBodyM'.
resultBodyM :: MonadBuilder m => [SubExp] -> m (Body (Rep m))
resultBodyM ses = mkBodyM mempty $ subExpsRes ses

-- | Evaluate the action, producing a body, then wrap it in all the
-- bindings it created using 'addStm'.
--
-- The statements added while running the action come first, followed
-- by the statements already in the returned body.  The body's result
-- is kept.
insertStmsM ::
  (MonadBuilder m) =>
  m (Body (Rep m)) ->
  m (Body (Rep m))
insertStmsM m = do
  (body, added) <- collectStms m
  case body of
    Body _ stms res -> mkBodyM (added <> stms) res

-- | Run an action that yields a 'Result' and an auxiliary value.
-- Return a body built from that 'Result' and every statement added
-- during the action, paired with the auxiliary value.
buildBody ::
  MonadBuilder m =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildBody m = do
  ((res, aux), stms) <- collectStms m
  (,aux) <$> mkBodyM stms res

-- | Like 'buildBody', for an action that yields only a 'Result'.
buildBody_ ::
  MonadBuilder m =>
  m Result ->
  m (Body (Rep m))
buildBody_ = fmap fst . buildBody . fmap (,())

-- | Replace the result of a body using a function that produces a new
-- body.
--
-- The function receives the original body's 'Result'.  The body it
-- returns has its statements appended after the original statements,
-- and its result becomes the overall result.  The annotations of both
-- bodies are discarded; the combined body is rebuilt with 'mkBody'.
mapResult ::
  Buildable rep =>
  (Result -> Body rep) ->
  Body rep ->
  Body rep
mapResult f (Body _ stms res) =
  case f res of
    Body _ stms2 newres -> mkBody (stms <> stms2) newres

-- | Instantiate all existential dimensions of the given types, using a
-- monadic action to create the necessary t'SubExp's.
--
-- The action is called at most once per distinct existential index;
-- later occurrences of the same index reuse the first result.  Free
-- dimensions are kept as they are.  You should call this function
-- within some monad that allows you to collect the actions performed
-- (say, 'State').
instantiateShapes ::
  Monad m =>
  (Int -> m SubExp) ->
  [TypeBase ExtShape u] ->
  m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (traverse instType ts) M.empty
  where
    instType t = do
      dims <- traverse instDim $ shapeDims $ arrayShape t
      pure $ setArrayShape t $ Shape dims

    instDim (Free se) = pure se
    instDim (Ext i) = do
      seen <- get
      maybe (fresh i seen) pure $ M.lookup i seen

    -- First time this existential index is seen: create its size and
    -- remember it.
    fresh i seen = do
      se <- lift $ f i
      put $ M.insert i se seen
      pure se

-- | Like 'instantiateShapes', but existential index @i@ is instantiated
-- as the variable at position @i@ of the provided list of names.  If an
-- 'Ext' is out of bounds of this list, the function fails with
-- 'error'.
instantiateShapes' :: [VName] -> [TypeBase ExtShape u] -> [TypeBase Shape u]
instantiateShapes' names ts =
  -- Carefully ensure that the order of idents we produce corresponds
  -- to their existential index.
  runIdentity $ instantiateShapes nameFor ts
  where
    nameFor x =
      maybe
        (error $ "instantiateShapes': " ++ prettyString names ++ ", " ++ show x)
        (pure . Var)
        (maybeNth x names)

-- | Remove existentials from @t1@ by taking sizes from @t2@ where
-- needed.
--
-- Dimensions are paired positionally.  A free dimension of @t1@ is
-- kept, and an existential one is replaced by the corresponding
-- dimension of @t2@.  As with 'zipWith', the result has as many
-- dimensions as the shorter of the two.
removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 =
  t1 `setArrayDims` zipWith pick (shapeDims (arrayShape t1)) (arrayDims t2)
  where
    pick dim fallback =
      case dim of
        Free d -> d
        Ext _ -> fallback

-- | Can be used as the definition of 'mkLetNames' for a 'Buildable'
-- instance for simple representations.
--
-- The expression's extended types are instantiated with
-- 'instantiateShapes'', so existential index @i@ refers to the @i@-th
-- of the given names.  The names are then bound in a pattern with
-- those types and a unit expression annotation.
simpleMkLetNames ::
  ( ExpDec rep ~ (),
    LetDec rep ~ Type,
    MonadFreshNames m,
    TypedOp (Op rep),
    HasScope rep m
  ) =>
  [VName] ->
  Exp rep ->
  m (Stm rep)
simpleMkLetNames names e = do
  ts <- instantiateShapes' names <$> expExtType e
  let pat = Pat $ zipWith PatElem names ts
  pure $ Let pat (defAux ()) e

-- | Types whose values can be turned into a Futhark expression inside
-- any 'MonadBuilder'.
class ToExp a where
  -- | Produce an expression, in the builder's representation, for the
  -- given value.
  toExp :: (MonadBuilder m) => a -> m (Exp (Rep m))

instance ToExp SubExp where
  -- A subexpression is wrapped directly as a 'SubExp' basic operation.
  toExp se = pure $ BasicOp $ SubExp se

instance ToExp VName where
  -- A name becomes the expression that just reads that variable.
  toExp v = pure $ BasicOp $ SubExp $ Var v

-- | Convert a value with 'toExp' and bind the resulting expression with
-- 'letSubExp', using the given name hint.
toSubExp :: (MonadBuilder m, ToExp a) => String -> a -> m SubExp
toSubExp s e = do
  e' <- toExp e
  letSubExp s e'