{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Safe #-}
-- | = Constructing Futhark ASTs
--
-- This module re-exports and defines a bunch of building blocks for
-- constructing fragments of Futhark ASTs.  More importantly, it also
-- contains a basic introduction on how to use them.
--
-- The "Futhark.IR.Syntax" module contains the core
-- AST definition.  One important invariant is that all bound names in
-- a Futhark program must be /globally/ unique.  In principle, you
-- could use the facilities from "Futhark.MonadFreshNames" (or your
-- own bespoke source of unique names) to manually construct
-- expressions, statements, and entire ASTs.  In practice, this would
-- be very tedious.  Instead, we have defined a collection of building
-- blocks (centered around the 'MonadBinder' type class) that permits
-- a more abstract way of generating code.
--
-- Constructing ASTs with these building blocks requires you to ensure
-- that all free variables are in scope.  See
-- "Futhark.IR.Prop.Scope".
--
-- == 'MonadBinder'
--
-- A monad that implements 'MonadBinder' tracks the statements added
-- so far, the current names in scope, and allows you to add
-- additional statements with 'addStm'.  Any monad that implements
-- 'MonadBinder' also has an instance of the t'Lore' type family,
-- which indicates the lore it works with.  Inside a 'MonadBinder' we
-- can use 'collectStms' to gather up the 'Stms' added with 'addStm'
-- in some nested computation.
--
-- The 'BinderT' monad (and its convenient 'Binder' version) provides
-- the simplest implementation of 'MonadBinder'.
--
-- == Higher-level building blocks
--
-- On top of the raw facilities provided by 'MonadBinder', we have
-- more convenient facilities.  For example, 'letExp' lets us
-- conveniently create a 'Stm' for an 'Exp' that produces a /single/
-- value, and returns the (fresh) name for the resulting variable:
--
-- @
-- z <- letExp "z" $ BasicOp $ BinOp (Add Int32 OverflowWrap) (Var x) (Var y)
-- @
--
-- == Examples
--
-- The "Futhark.Transform.FirstOrderTransform" module is a
-- (relatively) simple example of how to use these components.  So are
-- some of the higher-level building blocks in this very module.
module Futhark.Construct
  ( letSubExp
  , letSubExps
  , letExp
  , letTupExp
  , letTupExp'
  , letInPlace

  , eSubExp
  , eIf
  , eIf'
  , eBinOp
  , eCmpOp
  , eConvOp
  , eNot
  , eSignum
  , eCopy
  , eAssert
  , eBody
  , eLambda
  , eRoundToMultipleOf
  , eSliceArray
  , eBlank
  , eAll

  , eOutOfBounds
  , eWriteArray

  , asIntZ, asIntS

  , resultBody
  , resultBodyM
  , insertStmsM
  , mapResult

  , foldBinOp
  , binOpLambda
  , cmpOpLambda
  , sliceDim
  , fullSlice
  , fullSliceNum
  , isFullSlice
  , sliceAt
  , ifCommon

  , module Futhark.Binder

  -- * Result types
  , instantiateShapes
  , instantiateShapes'
  , removeExistentials

  -- * Convenience
  , simpleMkLetNames

  , ToExp(..)
  , toSubExp
  )
where

import qualified Data.Map.Strict as M
import Data.List (sortOn)
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer

import Futhark.IR
import Futhark.MonadFreshNames
import Futhark.Binder

-- | Bind a single-valued expression to a fresh variable named after
-- the given description, and return that variable as a 'SubExp'.  An
-- expression that is already a plain 'SubExp' (variable or constant)
-- is returned directly and no statement is added.
letSubExp :: MonadBinder m =>
             String -> Exp (Lore m) -> m SubExp
letSubExp desc e =
  case e of
    BasicOp (SubExp se) -> return se
    _                   -> Var <$> letExp desc e

-- | Bind a single-valued expression to a fresh variable named after
-- the given description, and return that variable.  An expression
-- that is already a variable is returned as-is without adding a
-- statement.  Calls 'error' if the expression does not produce
-- exactly one value.
letExp :: MonadBinder m =>
          String -> Exp (Lore m) -> m VName
letExp desc e = do
  -- 'letTupExp' takes the same shortcut for plain variables and
  -- otherwise binds one fresh name per produced value; all that is
  -- left here is to insist that there is exactly one.
  names <- letTupExp desc e
  case names of
    [v] -> return v
    _   -> error $ "letExp: tuple-typed expression given:\n" ++ pretty e

-- | Bind @e@ to a temporary (described as @desc ++ "_tmp"@).  Then bind
-- an 'Update' of @src@ at @slice@ with that temporary to a fresh
-- variable described by @desc@, and return that variable.
letInPlace :: MonadBinder m =>
              String -> VName -> Slice SubExp -> Exp (Lore m)
           -> m VName
letInPlace desc src slice e =
  letSubExp (desc ++ "_tmp") e >>= letExp desc . BasicOp . Update src slice

-- | Apply 'letSubExp', with the same description, to each expression
-- in turn.
letSubExps :: MonadBinder m =>
              String -> [Exp (Lore m)] -> m [SubExp]
letSubExps desc es = traverse (letSubExp desc) es

-- | Bind an expression to fresh variables and return them.  There is
-- one variable per type in its 'expExtType', and all are named after
-- the given description.  An expression that is already a variable is
-- returned as a singleton list without adding a statement.
letTupExp :: (MonadBinder m) =>
             String -> Exp (Lore m)
          -> m [VName]
letTupExp name e =
  case e of
    BasicOp (SubExp (Var v)) -> return [v]
    _ -> do
      ts <- expExtType e
      -- One fresh name per result value.
      names <- mapM (const $ newVName name) ts
      letBindNames names e
      return names

-- | As 'letTupExp', but returns 'SubExp's.  An expression that is
-- already a plain 'SubExp' (including a constant) is returned as a
-- singleton list without adding a statement.
letTupExp' :: (MonadBinder m) =>
              String -> Exp (Lore m)
           -> m [SubExp]
letTupExp' name e =
  case e of
    BasicOp (SubExp se) -> return [se]
    _                   -> fmap (map Var) (letTupExp name e)

-- | Wrap a 'SubExp' as an expression.  Adds no statements.
eSubExp :: MonadBinder m =>
           SubExp -> m (Exp (Lore m))
eSubExp se = return $ BasicOp $ SubExp se

-- | Construct an 'If' expression of sort 'IfNormal'.  This is 'eIf''
-- with the sort fixed.
eIf :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
       m (Exp (Lore m)) -> m (Body (Lore m)) -> m (Body (Lore m))
    -> m (Exp (Lore m))
eIf cond_m tbranch fbranch = eIf' cond_m tbranch fbranch IfNormal

-- | As 'eIf', but an 'IfSort' can be given.
--
-- The condition is bound to a variable described as @cond@.  Each
-- branch action is run under 'insertStmsM'.  The two branches then
-- share one return type, computed by 'generaliseExtTypes' from their
-- individual 'bodyExtType's.  Finally, each branch body has its
-- results prefixed with the context values that type needs.
eIf' :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
        m (Exp (Lore m)) -> m (Body (Lore m)) -> m (Body (Lore m))
     -> IfSort
     -> m (Exp (Lore m))
eIf' ce te fe if_sort = do
  cond_se <- letSubExp "cond" =<< ce
  tbody <- insertStmsM te
  fbody <- insertStmsM fe
  -- We need to construct the context.
  t_ts <- bodyExtType tbody
  f_ts <- bodyExtType fbody
  let ts = generaliseExtTypes t_ts f_ts
  tbody' <- withContext ts tbody
  fbody' <- withContext ts fbody
  return $ If cond_se tbody' fbody' $ IfDec ts if_sort
  where
    -- Prefix a branch's results with the sizes that 'shapeExtMapping'
    -- associates with @ts@ (presumably its existential dimensions),
    -- in ascending index order.  The result types are looked up with
    -- the branch's own statements in scope.
    withContext ts (Body _ stms val_res) = do
      let stmsscope = scopeOf stms
      body_ts <- extendedScope (traverse subExpType val_res) stmsscope
      let ctx_res = map snd $ sortOn fst $ M.toList $ shapeExtMapping ts body_ts
      mkBodyM stms $ ctx_res ++ val_res

-- The type of a body.  Watch out: this only works for the degenerate
-- case where the body does not already return its context.
--
-- The result types are looked up with the body's own statements in
-- scope.  The names those statements bind are then passed to
-- 'existentialiseExtTypes' (presumably so that sizes referring to them
-- become existential).
bodyExtType :: (HasScope lore m, Monad m) => Body lore -> m [ExtType]
bodyExtType (Body _ stms res) = do
  res_ts <- extendedScope (traverse subExpType res) stmsscope
  return $ existentialiseExtTypes (M.keys stmsscope) $ staticShapes res_ts
  where stmsscope = scopeOf stms

-- | Construct a 'BinOp' expression.  The left operand action is run
-- and bound first (description @x@), then the right one (description
-- @y@).
eBinOp :: MonadBinder m =>
          BinOp -> m (Exp (Lore m)) -> m (Exp (Lore m))
       -> m (Exp (Lore m))
eBinOp op mx my = do
  ex <- mx
  x' <- letSubExp "x" ex
  ey <- my
  y' <- letSubExp "y" ey
  return $ BasicOp $ BinOp op x' y'

-- | Construct a 'CmpOp' expression.  The left operand action is run
-- and bound first (description @x@), then the right one (description
-- @y@).
eCmpOp :: MonadBinder m =>
          CmpOp -> m (Exp (Lore m)) -> m (Exp (Lore m))
       -> m (Exp (Lore m))
eCmpOp op mx my = do
  ex <- mx
  x' <- letSubExp "x" ex
  ey <- my
  y' <- letSubExp "y" ey
  return $ BasicOp $ CmpOp op x' y'

-- | Construct a 'ConvOp' expression.  The operand is bound with
-- description @x@.
eConvOp :: MonadBinder m =>
           ConvOp -> m (Exp (Lore m))
        -> m (Exp (Lore m))
eConvOp op mx = do
  x' <- mx >>= letSubExp "x"
  pure $ BasicOp $ ConvOp op x'

-- | Construct a logical negation ('Not').  The operand is bound with
-- description @not_arg@.
eNot :: MonadBinder m =>
        m (Exp (Lore m)) -> m (Exp (Lore m))
eNot e = do
  arg <- letSubExp "not_arg" =<< e
  return $ BasicOp $ UnOp Not arg

-- | Construct a signed signum ('SSignum') expression.  The operand is
-- bound with description @signum_arg@.  Only operands of integer type
-- are supported; any other type results in a call to 'error'.
eSignum :: MonadBinder m =>
        m (Exp (Lore m)) -> m (Exp (Lore m))
eSignum em = do
  e <- em
  arg <- letSubExp "signum_arg" e
  arg_t <- subExpType arg
  case arg_t of
    Prim (IntType int_t) ->
      return $ BasicOp $ UnOp (SSignum int_t) arg
    _ ->
      error $ "eSignum: operand " ++ pretty e ++ " has invalid type."

-- | Construct a 'Copy' of the operand.  The operand is bound to a
-- variable with 'letExp' (description @copy_arg@), so it must produce
-- exactly one value.
eCopy :: MonadBinder m =>
         m (Exp (Lore m)) -> m (Exp (Lore m))
eCopy e = do
  arr <- e >>= letExp "copy_arg"
  return $ BasicOp $ Copy arr

-- | Construct an 'Assert' of the operand (bound with description
-- @assert_arg@), using the given error message and source location.
-- The second component of the location pair is left empty.
eAssert :: MonadBinder m =>
         m (Exp (Lore m)) -> ErrorMsg SubExp -> SrcLoc -> m (Exp (Lore m))
eAssert e msg loc = do
  cond_se <- e >>= letSubExp "assert_arg"
  return $ BasicOp $ Assert cond_se msg (loc, mempty)

-- | Construct a body from a list of expression actions.  All actions
-- are run first.  Only then is each resulting expression bound with
-- 'letTupExp' (description @x@).  The body's results are all the bound
-- variables, in order.  The whole construction is wrapped in
-- 'insertStmsM' (presumably so the statements added here end up inside
-- the returned body).
eBody :: (MonadBinder m) =>
         [m (Exp (Lore m))]
      -> m (Body (Lore m))
eBody es = insertStmsM $ do
  -- Keep the two phases separate: running every action before binding
  -- any result preserves the statement order.
  exps <- sequence es
  names <- traverse (letTupExp "x") exps
  mkBodyM mempty $ concatMap (map Var) names

-- | Inline a lambda.  Each parameter is bound to the result of the
-- corresponding argument action; as with 'zip', surplus parameters or
-- arguments are ignored.  The lambda body is then bound with
-- 'bodyBind', and its result is returned.
eLambda :: MonadBinder m =>
           Lambda (Lore m) -> [m (Exp (Lore m))] -> m [SubExp]
eLambda lam args = do
  mapM_ bindArg $ zip (lambdaParams lam) args
  bodyBind $ lambdaBody lam
  where bindArg (param, arg) = arg >>= letBindNames [paramName param]

-- | @eRoundToMultipleOf t x d@ constructs @x + ((d - x mod d) mod d)@ at
-- integer type @t@.  It uses 'SMod' marked 'Unsafe', and 'Add'/'Sub'
-- with 'OverflowWrap'.  Assuming 'SMod' takes the sign of the divisor,
-- this rounds @x@ up to a multiple of a positive @d@.
--
-- Note that @x@ and @d@ are actions, and each occurrence in the formula
-- runs its action again: @x@ is run twice and @d@ three times.
eRoundToMultipleOf :: MonadBinder m =>
                      IntType -> m (Exp (Lore m)) -> m (Exp (Lore m)) -> m (Exp (Lore m))
eRoundToMultipleOf t x d =
  x `plus` ((d `minus` (x `modulo` d)) `modulo` d)
  where plus   = eBinOp $ Add t OverflowWrap
        minus  = eBinOp $ Sub t OverflowWrap
        modulo = eBinOp $ SMod t Unsafe

-- | Construct an 'Index' expression that slices an array with unit
-- stride.  Dimension @d@ of @arr@ is sliced from @i@ (bound as
-- @slice_i@) with @n@ elements (bound as @slice_n@).  The @d@ dimensions
-- before it are sliced in full, from 0 with unit stride.  'fullSlice'
-- then completes the slice for the array's type (presumably taking the
-- remaining dimensions whole).
eSliceArray :: MonadBinder m =>
               Int -> VName -> m (Exp (Lore m)) -> m (Exp (Lore m))
            -> m (Exp (Lore m))
eSliceArray d arr i n = do
  arr_t <- lookupType arr
  start_se <- letSubExp "slice_i" =<< i
  len_se <- letSubExp "slice_n" =<< n
  let zero_se = constant (0 :: Int32)
      one_se = constant (1 :: Int32)
      leading = [DimSlice zero_se w one_se | w <- take d $ arrayDims arr_t]
  return $ BasicOp $ Index arr $ fullSlice arr_t $
    leading ++ [DimSlice start_se len_se one_se]

-- | Are these indexes out-of-bounds for the array?
--
-- Produces a boolean expression that is true when any index is
-- negative or greater than or equal to the size of its corresponding
-- dimension.  Indexes are compared as signed 32-bit integers.  When
-- fewer indexes than dimensions are given, only the leading dimensions
-- are checked.
eOutOfBounds :: MonadBinder m =>
                VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
eOutOfBounds arr is = do
  dims <- arrayDims <$> lookupType arr
  idxs <- mapM (letSubExp "write_i") =<< sequence is
  per_dim <- zipWithM dimOutOfBounds dims idxs
  foldBinOp LogOr (constant False) per_dim
  where
    -- Bind @i < 0 || w <= i@ for a single dimension of size @w@.
    dimOutOfBounds w i = do
      neg <- letSubExp "less_than_zero" $
             BasicOp $ CmpOp (CmpSlt Int32) i (constant (0 :: Int32))
      too_big <- letSubExp "greater_than_size" $
                 BasicOp $ CmpOp (CmpSle Int32) w i
      letSubExp "outside_bounds_dim" $
        BasicOp $ BinOp LogOr neg too_big

-- | Write to an index of the array, if within bounds.  Otherwise,
-- nothing.  Produces the updated array.
--
-- The result is an 'If' expression.  When any index is out of bounds
-- (as decided by 'eOutOfBounds'), it yields @arr@ unchanged.
-- Otherwise it yields @arr@ updated in place at the given indexes
-- with the value of @v@.
--
-- Each index action in @is@ is run exactly once.  The bound indexes
-- are passed on to 'eOutOfBounds' as trivial expressions.  Passing
-- the original actions would run them a second time and bind every
-- index computation twice.
eWriteArray :: (MonadBinder m, BranchType (Lore m) ~ ExtType) =>
               VName -> [m (Exp (Lore m))] -> m (Exp (Lore m))
            -> m (Exp (Lore m))
eWriteArray arr is v = do
  arr_t <- lookupType arr
  is' <- mapM (letSubExp "write_i") =<< sequence is
  v' <- letSubExp "write_v" =<< v

  -- Reuse the already-bound indexes instead of re-running 'is'.
  outside_bounds <-
    letSubExp "outside_bounds" =<< eOutOfBounds arr (map toExp is')

  outside_bounds_branch <- insertStmsM $ resultBodyM [Var arr]

  in_bounds_branch <- insertStmsM $ do
    res <- letInPlace "write_out_inside_bounds" arr
           (fullSlice arr_t (map DimFix is')) $
           BasicOp $ SubExp v'
    resultBodyM [Var res]

  return $
    If outside_bounds outside_bounds_branch in_bounds_branch $
    ifCommon [arr_t]

-- | Construct an unspecified value of the given type.
--
-- * A scalar becomes a blank constant of that primitive type.
-- * An array becomes a 'Scratch' with the same element type and shape.
-- * Memory is not supported and raises an 'error'.
eBlank :: MonadBinder m => Type -> m (Exp (Lore m))
eBlank ty =
  case ty of
    Prim pt ->
      pure . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array pt shape _ ->
      pure . BasicOp . Scratch pt $ shapeDims shape
    Mem{} ->
      error "eBlank: cannot create blank memory"

-- | Sign-extend to the given integer type (using 'SExt').  A value that
-- already has the target type is returned unchanged.
asIntS :: MonadBinder m => IntType -> SubExp -> m SubExp
asIntS to_it se = asInt SExt to_it se

-- | Zero-extend to the given integer type (using 'ZExt').  A value that
-- already has the target type is returned unchanged.
asIntZ :: MonadBinder m => IntType -> SubExp -> m SubExp
asIntZ to_it se = asInt ZExt to_it se

-- | Convert an integer 'SubExp' to the integer type @to_it@, using the
-- given conversion constructor (e.g. 'SExt' or 'ZExt').
--
-- If the operand already has type @to_it@, it is returned as-is and no
-- statement is added.  Otherwise a 'ConvOp' statement is bound.  The
-- statement is named after the operand when it is a variable, and
-- @to_<type>@ otherwise.  Calls 'error' if the operand does not have
-- an integer type.
asInt :: MonadBinder m =>
         (IntType -> IntType -> ConvOp) -> IntType -> SubExp -> m SubExp
asInt ext to_it e = do
  e_t <- subExpType e
  case e_t of
    Prim (IntType from_it)
      | from_it /= to_it ->
          letSubExp desc $ BasicOp $ ConvOp (ext from_it to_it) e
      | otherwise ->
          pure e
    _ -> error "asInt: wrong type"
  where
    desc
      | Var v <- e = baseString v
      | otherwise  = "to_" ++ pretty to_it


-- | Apply a binary operator to several subexpressions.
--
-- This is a /right/ fold, not a left fold:
-- @foldBinOp op ne [a, b, c]@ builds @a op (b op (c op ne))@, with each
-- step combined through 'eBinOp'.  An empty list yields just @ne@.
-- Keep this association in mind for operators that are not
-- associative.
foldBinOp :: MonadBinder m =>
             BinOp -> SubExp -> [SubExp] -> m (Exp (Lore m))
foldBinOp bop ne = foldr combine (pure $ BasicOp $ SubExp ne)
  where
    -- The current operand goes on the left; the folded rest on the right.
    combine e rest = eBinOp bop (pure $ BasicOp $ SubExp e) rest

-- | True if all operands are true.
--
-- An empty list gives the constant 'True'.  Otherwise the operands are
-- combined with 'LogAnd' via 'foldBinOp'; the first operand is used as
-- the fold's seed.
eAll :: MonadBinder m => [SubExp] -> m (Exp (Lore m))
eAll operands =
  case operands of
    [] -> pure $ BasicOp $ SubExp $ constant True
    x : rest -> foldBinOp LogAnd x rest

-- | Create a two-parameter lambda whose body applies the given binary
-- operator to its parameters.
--
-- Both parameters and the single result are given the same primitive
-- type @t@.  (Assuming the result type equals the argument type is a
-- known limitation that should be lifted at some point.)
binOpLambda :: (MonadBinder m, Bindable (Lore m)) =>
               BinOp -> PrimType -> m (Lambda (Lore m))
binOpLambda bop t = binLambda (\x y -> BinOp bop x y) t t

-- | As 'binOpLambda', but for t'CmpOp's.  The parameters have the
-- operator's comparison type, and the result is a 'Bool'.
cmpOpLambda :: (MonadBinder m, Bindable (Lore m)) =>
               CmpOp -> m (Lambda (Lore m))
cmpOpLambda cop = binLambda cmp operand_t Bool
  where
    cmp = CmpOp cop
    operand_t = cmpOpType cop

-- | Build a two-parameter lambda whose body binds the result of
-- applying @bop@ to the two parameters and returns it.
--
-- Fresh names (based on @x@ and @y@) are generated for the parameters.
-- Both parameters have type @arg_t@; the single result has type
-- @ret_t@.
binLambda :: (MonadBinder m, Bindable (Lore m)) =>
             (SubExp -> SubExp -> BasicOp) -> PrimType -> PrimType
          -> m (Lambda (Lore m))
binLambda bop arg_t ret_t = do
  (x, y) <- (,) <$> newVName "x" <*> newVName "y"
  body <- insertStmsM $
    resultBody . pure <$>
      letSubExp "binlam_res" (BasicOp $ bop (Var x) (Var y))
  let param v = Param v (Prim arg_t)
  pure Lambda { lambdaParams     = map param [x, y]
              , lambdaReturnType = [Prim ret_t]
              , lambdaBody       = body
              }

-- | Slice a full dimension of the given size: start at offset 0, take
-- @d@ elements, with stride 1.
sliceDim :: SubExp -> DimIndex SubExp
sliceDim d = DimSlice zero d one
  where
    zero = constant (0 :: Int32)
    one = constant (1 :: Int32)

-- | @fullSlice t slice@ returns @slice@, extended with full-span
-- 'DimSlice's (see 'sliceDim') for every dimension of @t@ that
-- @slice@ does not already cover.  This turns incomplete indexing into
-- complete indexing, as required by 'Index'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice = slice ++ rest
  where
    rest = [ sliceDim d | d <- drop (length slice) (arrayDims t) ]

-- | @sliceAt t n slice@ returns @slice@ with full-span 'DimSlice's of
-- the outer @n@ dimensions of @t@ prepended.  As many are appended as
-- needed to make it a full slice.  This is a generalisation of
-- 'fullSlice'.
sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt t n slice = fullSlice t (outer ++ slice)
  where
    outer = map sliceDim $ take n $ arrayDims t

-- | Like 'fullSlice', but the dimensions are given directly as numbers
-- rather than taken from a type.  Uncovered dimensions @d@ are
-- completed with @DimSlice 0 d 1@.
fullSliceNum :: Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  slice ++ [ DimSlice 0 d 1 | d <- drop (length slice) dims ]

-- | Does the slice describe the full size of the array?  A dimension
-- counts as fully covered if it is either:
--
-- * a 'DimSlice' whose length equals the dimension size, or
-- * a 'DimFix' on a dimension whose size is the constant one.
--
-- Any other combination, such as a fixed index on a non-unit or
-- non-constant dimension, makes the answer 'False'.  Dimensions beyond
-- the shorter of the shape and the slice are not examined.
--
-- NOTE(review): for 'DimSlice', only the length is compared.  The
-- start offset and stride are not inspected, so a slice with a
-- non-zero start or non-unit stride but matching length is also
-- accepted.  Confirm that callers only need a size check.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape slice =
  all (uncurry coversDim) $ zip (shapeDims shape) slice
  where
    coversDim (Constant v) DimFix{} = oneIsh v
    coversDim d (DimSlice _ n _) = n == d
    coversDim _ _ = False

-- | The 'IfDec' for an ordinary ('IfNormal') branch whose results have
-- the given types.  The types are turned into existential-free
-- 'ExtType's with 'staticShapes'.
ifCommon :: [Type] -> IfDec ExtType
ifCommon = flip IfDec IfNormal . staticShapes

-- | Conveniently construct a body that contains no statements, only
-- the given result.
resultBody :: Bindable lore => [SubExp] -> Body lore
resultBody res = mkBody mempty res

-- | Monadic counterpart of 'resultBody'.  It constructs a body with no
-- statements and the given result, using 'mkBodyM'.
resultBodyM :: MonadBinder m =>
               [SubExp]
            -> m (Body (Lore m))
resultBodyM res = mkBodyM mempty res

-- | Run the action, which produces a body, and collect the statements
-- it added with 'addStm'.  Returns a new body whose statements are the
-- collected ones followed by those already inside the produced body.
-- The produced body's result is kept; its decoration is discarded.
insertStmsM :: (MonadBinder m) =>
               m (Body (Lore m)) -> m (Body (Lore m))
insertStmsM m =
  collectStms m >>= \(Body _ inner res, added) ->
    mkBodyM (added <> inner) res

-- | Change the result of a body.
--
-- The function is applied to the body's current result.  The
-- statements of the body it returns are appended after the original
-- statements, and its result becomes the new result.  The decoration of
-- the returned body is discarded, and the combined body is rebuilt with
-- 'mkBody'.
mapResult :: Bindable lore =>
             (Result -> Body lore) -> Body lore -> Body lore
mapResult f (Body _ stms res) =
  let Body _ extra_stms new_res = f res
  in mkBody (stms <> extra_stms) new_res

-- | Instantiate all existential dimensions of the given types, using a
-- monadic action to create the necessary t'SubExp's.
--
-- The action is called once per distinct existential index.  Later
-- occurrences of the same index reuse the first result.  Known
-- ('Free') dimensions are kept as they are.  You should call this
-- function within some monad that allows you to collect the actions
-- performed (say, 'Writer').
instantiateShapes :: Monad m =>
                     (Int -> m SubExp)
                  -> [TypeBase ExtShape u]
                  -> m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (mapM instType ts) M.empty
  where
    instType t =
      setArrayShape t . Shape <$> mapM instDim (shapeDims $ arrayShape t)

    instDim (Free se) = return se
    instDim (Ext x) = do
      seen <- get
      case M.lookup x seen of
        Just se -> return se
        Nothing -> do
          se <- lift $ f x
          put $ M.insert x se seen
          return se

-- | Instantiate every existential dimension of the given types with a
-- fresh variable named @size@ of type @Prim int32@.
--
-- Returns the instantiated types together with the identifiers of the
-- new size variables, in the order they were created.
instantiateShapes' :: MonadFreshNames m =>
                      [TypeBase ExtShape u]
                   -> m ([TypeBase Shape u], [Ident])
instantiateShapes' ts = runWriterT $ instantiateShapes freshSize ts
  where
    -- The existential index itself is irrelevant; each call mints a new
    -- identifier and records it in the writer.
    freshSize _ = do
      size <- lift $ newIdent "size" $ Prim int32
      Var (identName size) <$ tell [size]

-- | Replace the existential dimensions of @t1@ with the corresponding
-- dimensions of @t2@.  Known ('Free') dimensions of @t1@ are kept.
-- Dimensions are paired by position.
removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 =
  setArrayDims t1 $ zipWith resolve (shapeDims $ arrayShape t1) (arrayDims t2)
  where
    resolve ext known = case ext of
      Free d -> d
      Ext _ -> known

-- | Can be used as the definition of 'mkLetNames' for a 'Bindable'
-- instance for simple representations.
--
-- The expression's existential result dimensions are instantiated with
-- fresh size variables.  Those variables become the context part of the
-- pattern, and the given names (paired with the instantiated types)
-- become the value part.  The statement gets a default auxiliary
-- annotation of @()@.
simpleMkLetNames :: (ExpDec lore ~ (), LetDec lore ~ Type,
                     MonadFreshNames m, TypedOp (Op lore), HasScope lore m) =>
                    [VName] -> Exp lore -> m (Stm lore)
simpleMkLetNames names e = do
  (ts, shapes) <- instantiateShapes' =<< expExtType e
  let toPatElem (Ident v t) = PatElem v t
      pat = Pattern (map toPatElem shapes) (zipWith PatElem names ts)
  pure $ Let pat (defAux ()) e

-- | Instances of this class can be converted to Futhark expressions
-- within any 'MonadBinder'.  The conversion is monadic, so an instance
-- may add statements while producing the expression.
class ToExp a where
  toExp :: MonadBinder m => a -> m (Exp (Lore m))

-- | A subexpression becomes the trivial @SubExp@ expression; no
-- statements are added.
instance ToExp SubExp where
  toExp se = pure $ BasicOp $ SubExp se

-- | A variable name becomes the trivial expression referring to that
-- variable, i.e. the same as converting @'Var' v@.
instance ToExp VName where
  toExp v = toExp $ Var v

-- | A convenient composition of 'letSubExp' and 'toExp'.  It converts
-- the value to an expression and binds it under the given name.
toSubExp :: (MonadBinder m, ToExp a) => String -> a -> m SubExp
toSubExp desc x = do
  e <- toExp x
  letSubExp desc e