{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLoc (..),
    sliceMemLoc,
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,
    askAttrs,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    lmadCopy,
    typeSize,
    inBounds,
    caseMatch,

    -- * Constructing code.
    dLParams,
    dFParams,
    addLoopVar,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    dIndexSpace',
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    sLoopSpace,
    (<--),
    (<~~),
    function,
    genConstants,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.Either
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.String
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    elements,
  )
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Futhark.Util.Pretty hiding (nest, space)
import Language.Futhark.Warnings
import Prelude hiding (mod, quot)

-- | How to compile an t'Op'.
type OpCompiler rep r op = Pat (LetDec rep) -> Op rep -> ImpM rep r op ()

-- | How to compile some 'Stms'.
type StmsCompiler rep r op = Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()

-- | How to compile an 'Exp'.
type ExpCompiler rep r op = Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()

-- | How to compile a copy.  Receives an element type and two memory
-- locations (presumably the destination followed by the source;
-- TODO confirm against 'lmadCopy').
type CopyCompiler rep r op =
  PrimType ->
  MemLoc ->
  MemLoc ->
  ImpM rep r op ()

-- | An alternate way of compiling an allocation.
type AllocCompiler rep r op = VName -> Count Bytes (Imp.TExp Int64) -> ImpM rep r op ()

-- | The set of pluggable compilers that together determine how a
-- program in representation @rep@ is turned into imperative code with
-- operations of type @op@.
data Operations rep r op = Operations
  { -- | How to compile expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | How to compile representation-specific operations.
    opsOpCompiler :: OpCompiler rep r op,
    -- | How to compile sequences of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | How to compile copies between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Alternate allocation compilers, keyed by memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | An operations set for which the expression compiler always
-- returns 'defCompileExp'.  Statements are compiled with
-- 'defCompileStms', copies with 'lmadCopy', and no alternate
-- allocation compilers are registered; only the operation compiler
-- is supplied by the caller.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = lmadCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLoc = MemLoc
  { -- | Name of the memory block holding the array.
    memLocName :: VName,
    -- | The dimension sizes recorded for the array.
    memLocShape :: [Imp.DimSize],
    -- | The LMAD describing how the array is laid out in the block.
    memLocLMAD :: LMAD.LMAD (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | Apply a slice to the LMAD of a memory location.  The memory
-- block name and the recorded shape are passed through unchanged;
-- only the LMAD is transformed (via 'LMAD.slice').
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.slice lmad slice

-- | Apply a flat slice to the LMAD of a memory location.  The memory
-- block name and the recorded shape are passed through unchanged;
-- only the LMAD is transformed (via 'LMAD.flatSlice').
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc (MemLoc mem shape lmad) slice =
  MemLoc mem shape $ LMAD.flatSlice lmad slice

-- | Symbol table information about an array: where it is stored and
-- the primitive type of its elements.
data ArrayEntry = ArrayEntry
  { entryArrayLoc :: MemLoc,
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array, as recorded in its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocShape . entryArrayLoc

-- | Symbol table information about a memory block: the space in
-- which it resides.
newtype MemEntry = MemEntry {entryMemSpace :: Imp.Space}
  deriving (Show)

-- | Symbol table information about a scalar: its primitive type.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
  }
  deriving (Show)

-- | Every non-scalar variable must be associated with an entry.
--
-- Each constructor carries an optional expression (presumably the
-- expression that the variable was bound to, when known; confirm at
-- the sites where entries are inserted) followed by the
-- kind-specific information.
data VarEntry rep
  = ArrayVar (Maybe (Exp rep)) ArrayEntry
  | ScalarVar (Maybe (Exp rep)) ScalarEntry
  | MemVar (Maybe (Exp rep)) MemEntry
  | AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | Where the value of a compiled expression should be written.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | The 'MemLoc' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)

-- | The read-only environment of the 'ImpM' monad: the currently
-- active pluggable compilers plus contextual information about the
-- code being compiled.
data Env rep r op = Env
  { envExpCompiler :: ExpCompiler rep r op,
    envStmsCompiler :: StmsCompiler rep r op,
    envOpCompiler :: OpCompiler rep r op,
    envCopyCompiler :: CopyCompiler rep r op,
    -- | Alternate allocation compilers, keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | The volatility currently in effect.
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Construct an initial environment from a user environment, a set
-- of operations, and a default memory space.  The environment starts
-- out non-volatile, with no enclosing function and no active
-- attributes.
--
-- NOTE(review): 'envAllocCompilers' is initialised to 'mempty' here
-- rather than to @opsAllocCompilers ops@ (which is what 'subImpM'
-- uses).  This is preserved as-is; confirm that ignoring the
-- allocation compilers of @ops@ at the top level is intentional.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envAllocCompilers = mempty,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation.  Maps each variable
-- name to its 'VarEntry'.
type VTable rep = M.Map VName (VarEntry rep)

-- | The mutable state of the 'ImpM' monad.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | The functions emitted so far.
    stateFunctions :: Imp.Functions op,
    -- | The code emitted so far (see 'emit' and 'collect').
    stateCode :: Imp.Code op,
    -- | The constants emitted so far.
    stateConstants :: Imp.Constants op,
    -- | The warnings produced so far.
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator are not in this mapping, the
    -- accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | An initial state in which everything except the name source is
-- empty: no symbol table entries, functions, code, constants,
-- warnings, or accumulators.
newState :: VNameSource -> ImpState rep r op
newState = ImpState mempty mempty mempty mempty mempty mempty

-- | The code generation monad: a reader over 'Env' layered on top of
-- a state monad over 'ImpState'.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )

-- | Fresh names are drawn from the name source stored in the
-- 'ImpState'.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- Cannot be a KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).
--
-- The scope is derived from the symbol table: every entry is
-- presented as a let-bound name whose type is reconstructed from
-- the information in its 'VarEntry'.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      -- Reconstruct the type of a variable from its symbol table entry.
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry
      entryType (AccVar _ (acc, ispace, ts)) =
        Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action in a fresh environment (built with 'newEnv'
-- from the given user environment, operations and default space),
-- starting from the given state.  Returns the result of the action
-- together with the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM m) r ops space = runState (runReaderT m $ newEnv r ops space)

-- | Like 'subImpM', but discards the result of the sub-computation
-- and returns only the code it emitted.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run a sub-computation with a different user environment and set
-- of operations (and possibly a different operation type), returning
-- its result together with the code it emitted.
--
-- The sub-computation inherits the remaining fields of the current
-- environment, as well as the current symbol table, accumulator
-- table and name source.  It starts with empty code, functions,
-- constants and warnings.  Afterwards only the name source and the
-- warnings of the sub-computation are propagated back to the
-- enclosing computation; its emitted code is returned rather than
-- emitted, and its other state changes are dropped.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  let env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateConstants = mempty,
            stateWarnings = mempty,
            stateAccs = stateAccs s
          }
      (x, s'') = runState (runReaderT m env') s'

  -- Keep names unique across the sub-computation boundary.
  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  pure (x, stateCode s'')

-- | Execute a code generation action, returning the code that was
-- emitted.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = fmap snd . collect'

-- | Execute a code generation action, returning both its result and
-- the code that it emitted.  The code emitted so far by the
-- enclosing computation is set aside while the action runs and
-- restored afterwards, so the action's code is not appended to it.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  pure (x, new_code)

-- | Execute a code generation action, wrapping the generated code
-- within a 'Imp.Comment' with the given description.
comment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
comment desc m = do
  code <- collect m
  emit $ Imp.Comment desc code

-- | Emit some generated imperative code.
--
-- The code is appended after whatever is already stored in the
-- 'stateCode' field of the compiler state.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}

-- | Record a collection of warnings in the compiler state.
--
-- The new warnings are combined with the ones already in
-- 'stateWarnings', with the new ones as the left operand of '<>'.
warnings :: Warnings -> ImpM rep r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.
--
-- The first location and the list of further locations are converted
-- with 'srclocOf', the message text is rendered with 'pretty', and
-- the resulting warning (built by 'singleWarning'') is recorded in
-- the compiler state.
warn :: (Located loc) => loc -> [loc] -> T.Text -> ImpM rep r op ()
warn loc locs problem =
  warnings $ singleWarning' (srclocOf loc) (map srclocOf locs) (pretty problem)

-- | Emit a function in the generated code.
--
-- The function is added at the front of the list of functions kept in
-- 'stateFunctions'.  No check is made for an existing function of the
-- same name; use 'hasFunction' first if that matters.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun = do
  Imp.Functions fs <- gets stateFunctions
  modify $ \s -> s {stateFunctions = Imp.Functions $ (fname, fun) : fs}

-- | Check if a function of a given name exists.
--
-- Only the functions currently recorded in 'stateFunctions' are
-- consulted.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs

-- | Build a variable table from a sequence of statements.
--
-- Every name bound by the pattern of every statement gets an entry.
-- The entry is derived from the memory annotation of the pattern
-- element ('letDecMem') and remembers the expression of the statement
-- that bound the name.
constsVTable :: (Mem rep inner) => Stms rep -> VTable rep
constsVTable = foldMap stmVtable
  where
    -- All pattern elements of one statement share its expression.
    stmVtable (Let pat _ e) =
      foldMap (peVtable e) $ patElems pat
    peVtable e (PatElem name dec) =
      M.singleton name $ memBoundToVarEntry (Just e) $ letDecMem dec

-- | Compile a whole program to imperative definitions.
--
-- The arguments are the user environment @r@, the pluggable
-- 'Operations', the default memory space, and the program itself.
-- The result is the collected warnings and the 'Imp.Definitions'
-- (opaque types, constants and functions).
--
-- Each function definition is compiled independently (sparked in
-- parallel via 'parMap' with 'rpar'), every one starting from the
-- same name source and with a variable table describing the program
-- constants.  The per-function states are then merged, and the
-- constants are compiled on top of the merged state, being told which
-- names occur free in the generated functions.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog types consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        -- Names free in any generated function; handed to the
        -- constant compilation below.
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        ((), s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions
              types
              -- 'combineStates' does not carry over constants, so the
              -- ones from the per-function states are added here.
              (stateConstants s' <> foldMap stateConstants ss)
              (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    -- Compile one function from a fresh state whose variable table
    -- already knows about the constants.
    compileFunDef' src fdef =
      runImpM
        (compileFunDef types fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the per-function states into a single fresh state that
    -- keeps only the functions, the warnings and the combined name
    -- source.  Going through a map leaves one entry per function
    -- name.
    combineStates ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src = mconcat (map stateNameSource ss)
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }

-- | Compile the constant statements of a program.
--
-- The statements are compiled with 'compileStms' (with an empty
-- continuation) inside 'genConstants'.  The given set of names is
-- passed both to 'compileStms' and, as the first component of the
-- result, to 'genConstants'.
compileConsts :: Names -> Stms rep -> ImpM rep r op ()
compileConsts used_consts stms = genConstants $ do
  compileStms used_consts stms $ pure ()
  pure (used_consts, ())

-- | Look up the definition of a named opaque type.
--
-- Calls 'error' if the name is not present in the given
-- 'OpaqueTypes'.
lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType
lookupOpaqueType v (OpaqueTypes types) =
  case lookup v types of
    Just t -> t
    Nothing -> error $ "Unknown opaque type: " ++ show v

-- | The signedness component of a 'ValueType'.
valueTypeSign :: ValueType -> Signedness
valueTypeSign (ValueType sign _ _) = sign

-- | The signedness of each value making up an entry point type, in
-- order.
--
-- A transparent type yields exactly one signedness.  An opaque type
-- is looked up in the given 'OpaqueTypes'; record types are
-- flattened recursively, field by field.
entryPointSignedness :: OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness _ (TypeTransparent vt) = [valueTypeSign vt]
entryPointSignedness types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> map valueTypeSign vts
    OpaqueRecord fs -> foldMap (entryPointSignedness types . snd) fs

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).
--
-- A transparent type counts as one.  An opaque type counts as the
-- number of values it consists of, with record fields counted
-- recursively.
entryPointSize :: OpaqueTypes -> EntryPointType -> Int
entryPointSize _ (TypeTransparent _) = 1
entryPointSize types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> length vts
    OpaqueRecord fs -> sum $ map (entryPointSize types . snd) fs

-- | Translate a single function parameter.
--
-- Scalar and memory parameters become a 'Left' holding the
-- corresponding 'Imp.Param'.  Array parameters do not become
-- parameters themselves; instead a 'Right' holding an 'ArrayDecl' is
-- returned, describing the element type and where the array resides.
-- Accumulator parameters are rejected with 'error'.
compileInParam ::
  (Mem rep inner) =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    pure $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    pure $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem lmad) ->
    pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) $ IxFun.ixfunLMAD lmad
  MemAcc {} ->
    error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array-typed function parameter, as produced by
-- 'compileInParam': the parameter name, the primitive element type,
-- and the memory location ('MemLoc') of the array.
data ArrayDecl = ArrayDecl VName PrimType MemLoc

-- | Translate the parameters of a function.
--
-- The result has three components:
--
-- * the scalar and memory parameters of the generated function;
--
-- * the declarations of the array parameters;
--
-- * if entry point information is given, a description of the
--   external value corresponding to each entry point parameter
--   (paired with its name and uniqueness); 'Nothing' otherwise.
--
-- The value parameters of an entry point are taken to be the /last/
-- parameters of the function; how many there are is computed with
-- 'entryPointSize'.
compileInParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [FParam rep] ->
  Maybe [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], Maybe [((Name, Uniqueness), Imp.ExternalValue)])
compileInParams types params eparams = do
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      -- The memory space of every memory block parameter.
      summaries = M.fromList $ mapMaybe memSummary params
        where
          memSummary param
            | MemMem space <- paramDec param =
                Just (paramName param, space)
            | otherwise =
                Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      -- Describe one value parameter.  Yields 'Nothing' for an array
      -- whose memory block is not itself a parameter, and for
      -- anything that is neither an array nor a primitive.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLoc mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- Walk the entry point parameters and the value parameters in
      -- lockstep.  An opaque entry parameter consumes as many value
      -- parameters as its size; a transparent one consumes exactly
      -- one.  Value parameters without a description are dropped.
      mkExts (EntryParam v u et@(TypeOpaque desc) : epts) fparams =
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (fparams', rest) = splitAt n fparams
         in ( (v, u),
              Imp.OpaqueValue
                desc
                (catMaybes $ zipWith mkValueDesc fparams' signs)
            )
              : mkExts epts rest
      mkExts (EntryParam v u (TypeTransparent (ValueType s _ _)) : epts) (fparam : fparams) =
        maybeToList (((v, u),) . Imp.TransparentValue <$> mkValueDesc fparam s)
          ++ mkExts epts fparams
      mkExts _ _ = []

  pure
    ( inparams,
      arrayds,
      case eparams of
        Just eparams' ->
          let num_val_params = sum (map (entryPointSize types . entryParamType) eparams')
              -- Everything before the value parameters is context.
              (_ctx_params, val_params) = splitAt (length params - num_val_params) params
           in Just $ mkExts eparams' val_params
        Nothing -> Nothing
    )
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

-- | Translate a single function return type.
--
-- Scalar and memory results each get a freshly named output
-- parameter and a destination referring to it.  Array results get no
-- parameter of their own and an 'ArrayDestination' without a memory
-- location.  Accumulator results are rejected with 'error'.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam (MemPrim t) = do
  name <- newVName "prim_out"
  pure (Just $ Imp.ScalarParam name t, ScalarDestination name)
compileOutParam (MemMem space) = do
  name <- newVName "mem_out"
  pure (Just $ Imp.MemParam name space, MemoryDestination name)
compileOutParam MemArray {} =
  pure (Nothing, ArrayDestination Nothing)
compileOutParam MemAcc {} =
  error "Functions may not return accumulators."

-- | Describe the results of an entry point as external values.
--
-- The return types are split into a leading context part and a
-- trailing value part; the value part has as many elements as the sum
-- of 'entryPointSize' over the entry results.  Each entry result then
-- consumes its share of the value part, and output positions (indexes
-- into @maybe_params@) are numbered starting just past the context
-- part.
compileExternalValues ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  [EntryResult] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [(Uniqueness, Imp.ExternalValue)]
compileExternalValues types orig_rts orig_epts maybe_params = do
  let (ctx_rts, val_rts) =
        splitAt
          (length orig_rts - sum (map (entryPointSize types . entryResultType) orig_epts))
          orig_rts

  let -- Name of the i'th output parameter.  It is an error if the
      -- position is out of range or has no parameter.
      nthOut i = case maybeNth i maybe_params of
        Just (Just p) -> Imp.paramName p
        Just Nothing -> error $ "Output " ++ show i ++ " not a param."
        Nothing -> error $ "Param " ++ show i ++ " does not exist."

      -- Value description for the return at output position i.
      mkValueDesc _ signedness (MemArray t shape _ ret) = do
        (mem, space) <-
          case ret of
            -- A new block is referred to by its output position.
            ReturnsNewBlock space j _lmad ->
              pure (nthOut j, space)
            -- An existing block has its space looked up by name.
            ReturnsInBlock mem _lmad -> do
              space <- entryMemSpace <$> lookupMemory mem
              pure (mem, space)
        pure $ Imp.ArrayValue mem space t signedness $ map extDim $ shapeDims shape
        where
          -- Existential dimensions refer to output parameters.
          extDim (Free v) = v
          extDim (Ext i) = Var $ nthOut i
      mkValueDesc i signedness (MemPrim bt) =
        pure $ Imp.ScalarValue bt signedness $ nthOut i
      mkValueDesc _ _ MemAcc {} =
        error "mkValueDesc: unexpected MemAcc output."
      mkValueDesc _ _ MemMem {} =
        error "mkValueDesc: unexpected MemMem output."

      -- Walk the entry results alongside the value return types.  An
      -- opaque result consumes 'entryPointSize' return types; a
      -- transparent result consumes exactly one.  Any other
      -- combination of arguments ends the list.
      mkExts i (EntryResult u et@(TypeOpaque desc) : epts) rets = do
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (rets', rest) = splitAt n rets
        vds <- forM (zip3 [i ..] signs rets') $ \(j, s, r) -> mkValueDesc j s r
        ((u, Imp.OpaqueValue desc vds) :) <$> mkExts (i + n) epts rest
      mkExts i (EntryResult u (TypeTransparent (ValueType s _ _)) : epts) (ret : rets) = do
        vd <- mkValueDesc i s ret
        ((u, Imp.TransparentValue vd) :) <$> mkExts (i + 1) epts rets
      mkExts _ _ _ = pure []

  mkExts (length ctx_rts) orig_epts val_rts

-- | Construct output parameters and destinations for all the return
-- types of a function.
--
-- Returns, in order: the external value descriptions of the results
-- (only when entry point results are given), the output parameters
-- that actually exist (returns without a parameter are dropped), and
-- one destination per return type.
compileOutParams ::
  (Mem rep inner) =>
  OpaqueTypes ->
  [RetType rep] ->
  Maybe [EntryResult] ->
  ImpM rep r op (Maybe [(Uniqueness, Imp.ExternalValue)], [Imp.Param], [ValueDestination])
compileOutParams types orig_rts maybe_orig_epts = do
  (maybe_params, dests) <- mapAndUnzipM compileOutParam orig_rts
  -- External values are only computed for entry points.
  evs <- forM maybe_orig_epts $ \orig_epts ->
    compileExternalValues types orig_rts orig_epts maybe_params
  pure (evs, catMaybes maybe_params, dests)

-- | Compile a function definition and emit the result with
-- 'emitFunction' under the name of the function.
--
-- While compiling, 'envFunction' is set to the entry point name if
-- the function has one, and otherwise to the function name.  The
-- emitted function carries entry point information only when the
-- entry point name, the result descriptions and the argument
-- descriptions are all available.
compileFunDef ::
  (Mem rep inner) =>
  OpaqueTypes ->
  FunDef rep ->
  ImpM rep r op ()
compileFunDef types (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = name_entry `mplus` Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    let entry' = Imp.EntryPoint <$> name_entry <*> results <*> args
    emitFunction fname $ Imp.Function entry' outparams inparams body'
  where
    (name_entry, params_entry, ret_entry) = case entry of
      Nothing -> (Nothing, Nothing, Nothing)
      Just (x, y, z) -> (Just x, Just y, Just z)

    compile = do
      (inparams, arrayds, args) <- compileInParams types params params_entry
      (results, outparams, dests) <- compileOutParams types (map fst rettype) ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
      -- Compile the statements, then copy each result to its
      -- destination.
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $
          \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

      pure (outparams, inparams, results, args)

-- | Compile the statements of a body, then copy each result of the
-- body to the destination derived from the corresponding element of
-- the pattern.
compileBody :: Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) = do
  dests <- destinationFromPat pat
  compileStms (freeIn ses) stms $
    forM_ (zip dests ses) $
      \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

-- | Like 'compileBody', but the results of the body are copied to the
-- given parameters (paired up by position) instead of to a pattern.
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms $
    forM_ (zip params ses) $
      \(param, SubExpRes _ se) -> copyDWIM (paramName param) [] se []

-- | Compile a loop body, writing its results to the given merge
-- parameters.  Only merge parameters of primitive type, and of memory
-- type whose result is a variable, are written; for any other
-- parameter nothing is emitted here.
compileLoopBody :: (Typed dec) => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy the results to
  -- fresh temporaries mirroring the merge parameters, and only
  -- afterwards copy the temporaries to the merge parameters.  This is
  -- efficient, because the operations are all scalar operations.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) stms $ do
    -- Each iteration emits the write to the temporary right away and
    -- returns the (deferred) action that writes the merge parameter.
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, SubExpRes _ se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure $ pure ()
    -- All temporaries are written by now; update the merge parameters.
    sequence_ copy_to_merge_params

-- | Compile statements by delegating to the statement compiler found
-- in the environment ('envStmsCompiler'), passing the arguments
-- through unchanged.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | Compile the statements one at a time, then run the action @m@.
--
-- For every statement the pattern variables are declared, the
-- expression is compiled, and an 'Imp.Free' is emitted for each memory
-- block that was bound by an earlier statement of this sequence, is
-- mentioned in the code of the current statement, and is not mentioned
-- by anything that follows (including @alive_after_stms@).
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- @allocs@ holds the memory blocks bound by the statements already
    -- processed.  The result is the set of names that are free in the
    -- code generated from this point onwards.
    compileStms' allocs (Let pat aux e : bs) = do
      dVars (Just e) (patElems pat)

      e_code <-
        localAttrs (stmAuxAttrs aux) $
          collect $
            compileExp pat e
      -- The remaining statements are compiled into a separate buffer
      -- first, because we need to know what they use before deciding
      -- which blocks can be freed in between.
      (live_after, bs_code) <- collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let dies_here v =
            (v `notNameIn` live_after) && (v `nameIn` freeIn e_code)
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      pure $ freeIn e_code <> live_after
    compileStms' _ [] = do
      code <- collect m
      emit code
      pure $ freeIn code <> alive_after_stms

    -- The (name, space) pairs of all memory-typed pattern elements.
    patternAllocs = S.fromList . mapMaybe isMemPatElem . patElems
    isMemPatElem pe = case patElemType pe of
      Mem space -> Just (patElemName pe, space)
      _ -> Nothing

-- | Compile an expression whose results are bound by the given pattern.
--
-- This is only a dispatcher: it fetches the expression compiler stored
-- in the environment ('envExpCompiler') and applies it to the pattern
-- and the expression.
compileExp :: Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | Generate an expression that is true if the subexpressions match
-- the case pattern.
caseMatch :: [SubExp] -> [Maybe PrimValue] -> Imp.TExp Bool
caseMatch ses vs = foldl (.&&.) true (zipWith cmp ses vs)
  where
    -- A concrete value in the pattern turns into an equality test
    -- against the corresponding subexpression, at the type of that
    -- value.  A 'Nothing' entry places no constraint on its position.
    cmp se (Just v) = isBool $ toExp' (primValueType v) se ~==~ ValueExp v
    cmp _ Nothing = true

-- | Compile an expression into imperative code, writing its results to
-- the variables of the given pattern.  There is one equation per
-- expression constructor; 'BasicOp' is forwarded to
-- 'defCompileBasicOp' and 'Op' to the operation compiler stored in the
-- environment.
defCompileExp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()
-- A match becomes a chain of nested conditionals, one per case, in the
-- order the cases are listed.  The innermost else-branch is the default
-- body.
defCompileExp pat (Match ses cases defbody _) =
  foldr f (compileBody pat defbody) cases
  where
    f (Case vs body) = sIf (caseMatch ses vs) (compileBody pat body)
-- A function call.  Primitive arguments are passed as expressions and
-- memory-typed variables as memory arguments; arguments of any other
-- type are dropped from the argument list.
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPat pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> pure $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> pure $ Just $ Imp.MemArg v
        _ -> pure Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
-- A loop.  The merge parameters are declared, those of array rank zero
-- are initialised from their initial values, the loop itself is
-- emitted, and finally the merge parameters are copied to the pattern.
defCompileExp pat (Loop merge form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
  dFParams params
  forM_ merge $ \(p, se) ->
    when ((== 0) $ arrayRank $ paramType p) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody params body

  case form of
    ForLoop i _ bound -> do
      bound' <- toExp bound
      sFor' i bound' doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  pat_dests <- destinationFromPat pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    params = map fst merge
-- An accumulator scope.  The lambda parameters are declared and the
-- leading ones (one per input) are recorded in 'stateAccs' together
-- with the arrays and optional operator of that input.  After the
-- lambda body, the results following the first @num_accs@ are copied
-- to the last names of the pattern.
defCompileExp pat (WithAcc inputs lam) = do
  dLParams $ lambdaParams lam
  forM_ (zip inputs $ lambdaParams lam) $ \((_, arrs, op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, op) $ stateAccs s}
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let nonacc_res = drop num_accs (bodyResult (lambdaBody lam))
        nonacc_pat_names = takeLast (length nonacc_res) (patNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, SubExpRes _ se) ->
      copyDWIM v [] se []
  where
    num_accs = length inputs
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op

-- | Emit a single trace print for a primitive value: the label @s@
-- followed by @": "@, then the value of @se@ at type @t@, then a
-- newline.
tracePrim :: T.Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim s t se =
  emit . Imp.TracePrint $
    ErrorMsg [ErrorString (s <> ": "), ErrorVal t (toExp' t se), ErrorString "\n"]

-- | Emit trace prints for an array: first the label @s@ followed by
-- @": "@, then a loop nest over @shape@ that prints each element
-- followed by a space, then a final newline.
traceArray :: T.Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray s t shape se = do
  emit . Imp.TracePrint $ ErrorMsg [ErrorString (s <> ": ")]
  sLoopNest shape $ \is -> do
    -- Each element is copied into a fresh primitive variable, and it is
    -- that variable which is printed.
    arr_elem <- dPrim "arr_elem" t
    copyDWIMFix (tvVar arr_elem) [] se is
    emit . Imp.TracePrint $ ErrorMsg [ErrorVal t (untyped (tvExp arr_elem)), " "]
  emit . Imp.TracePrint $ ErrorMsg ["\n"]

defCompileBasicOp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp :: forall rep (inner :: * -> *) r op.
Mem rep inner =>
Pat (LetDec rep) -> BasicOp -> ImpM rep r op ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (SubExp SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Opaque OpaqueOp
op SubExp
se) = do
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
  case OpaqueOp
op of
    OpaqueOp
OpaqueNil -> () -> ImpM rep r op ()
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    OpaqueTrace Text
s -> Text -> ImpM rep r op () -> ImpM rep r op ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
comment (Text
"Trace: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
s) (ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ do
      Type
se_t <- SubExp -> ImpM rep r op Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
      case Type
se_t of
        Prim PrimType
t -> Text -> PrimType -> SubExp -> ImpM rep r op ()
forall rep r op. Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim Text
s PrimType
t SubExp
se
        Array PrimType
t Shape
shape NoUniqueness
_ -> Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
forall rep r op.
Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray Text
s PrimType
t Shape
shape SubExp
se
        Type
_ ->
          [SrcLoc] -> [[SrcLoc]] -> Text -> ImpM rep r op ()
forall loc rep r op.
Located loc =>
loc -> [loc] -> Text -> ImpM rep r op ()
warn [SrcLoc
forall a. Monoid a => a
mempty :: SrcLoc] [[SrcLoc]]
forall a. Monoid a => a
mempty (Text -> ImpM rep r op ()) -> Text -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            Text
s Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
": cannot trace value of this (core) type: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Type -> Text
forall a. Pretty a => a -> Text
prettyText Type
se_t
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (UnOp UnOp
op SubExp
e) = do
  Exp
e' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
e
  PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
  Exp
e' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
e
  PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
forall rep r op. SubExp -> ImpM rep r op Exp
toExp SubExp
y
  PatElem (LetDec rep) -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
-- Comparison: translate both operands (left first, then right) with
-- 'toExp' and pass the combined 'Imp.CmpOpExp' to '<~~' for the
-- single pattern element.
defCompileBasicOp (Pat [pe]) (CmpOp bop x y) = do
  x' <- toExp x
  y' <- toExp y
  patElemName pe <~~ Imp.CmpOpExp bop x' y'
-- Assertion: the pattern is ignored.  The condition and every
-- 'SubExp' embedded in the error message are translated with 'toExp',
-- and an 'Imp.Assert' carrying the original source locations is
-- emitted.
defCompileBasicOp _ (Assert e msg loc) = do
  e' <- toExp e
  msg' <- traverse toExp msg
  emit $ Imp.Assert e' msg' loc

  -- If the attributes currently in effect contain
  -- @warn(safety_checks)@, additionally issue a warning at the
  -- location of the assertion.
  attrs <- askAttrs
  when (AttrComp "warn" ["safety_checks"] `inAttrs` attrs) $
    uncurry warn loc "Safety check required at run-time."
-- Indexing.  When 'sliceIndices' reports that the slice consists only
-- of fixed indices, the addressed element is copied into the pattern
-- variable with 'copyDWIM'.
defCompileBasicOp (Pat [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName pe) [] (Var src) $ map (DimFix . pe64) idxs
-- Any other 'Index' (the guard above failed, or the pattern does not
-- have exactly one element) emits no code.
-- NOTE(review): presumably such slices are expressed purely through
-- the result's memory annotation -- confirm.
defCompileBasicOp _ Index {} =
  pure ()
-- In-place update.  The write targets the pattern variable itself;
-- the source array operand of 'Update' is ignored here.  An 'Unsafe'
-- update is written unconditionally, while a 'Safe' update is wrapped
-- in an 'sWhen' that checks the slice against the dimensions of the
-- pattern element's type using 'inBounds'.
defCompileBasicOp (Pat [pe]) (Update safety _ slice se) =
  case safety of
    Unsafe -> write
    Safe -> sWhen (inBounds slice' dims) write
  where
    -- The slice with every component converted by 'pe64'.
    slice' = fmap pe64 slice
    -- Dimensions of the destination, taken from the pattern type.
    dims = map pe64 $ arrayDims $ patElemType pe
    -- The actual write, shared by both safety modes.
    write = sUpdate (patElemName pe) slice' se
-- 'FlatIndex' emits no code, whatever the pattern.
-- NOTE(review): presumably the result is described entirely by its
-- memory annotation -- confirm.
defCompileBasicOp _ FlatIndex {} =
  pure ()
-- Flat-slice update.  The array locations of the pattern variable and
-- of the value array @v@ are looked up; the destination location is
-- narrowed with 'flatSliceMemLoc' to the given slice, and the value is
-- copied into it using the element type of the pattern.  The source
-- array operand of 'FlatUpdate' is ignored here.
defCompileBasicOp (Pat [pe]) (FlatUpdate _ slice v) = do
  pe_loc <- entryArrayLoc <$> lookupArray (patElemName pe)
  v_loc <- entryArrayLoc <$> lookupArray v
  let pe_loc' = flatSliceMemLoc pe_loc $ fmap pe64 slice
  copy (elemType (patElemType pe)) pe_loc' v_loc
-- Replication, with three cases tried in order:
--
--   * the pattern has accumulator type: nothing is emitted;
--
--   * the shape is empty: a single 'copyDWIM' of the value into the
--     pattern variable;
--
--   * otherwise: a loop nest over the shape ('sLoopNest') that copies
--     the value to every index of the pattern variable.
defCompileBasicOp (Pat [pe]) (Replicate shape se)
  | Acc {} <- patElemType pe = pure ()
  | shape == mempty =
      copyDWIM (patElemName pe) [] se []
  | otherwise =
      sLoopNest shape $ \is -> copyDWIMFix (patElemName pe) is se []
-- 'Scratch' emits no code, whatever the pattern.
-- NOTE(review): presumably the array contents are intentionally left
-- uninitialised -- confirm.
defCompileBasicOp _ Scratch {} =
  pure ()
-- Iota with @n@ elements, start @e@, stride @s@ and integer type
-- @it@.  A loop over @i@ bounded by @n@ computes @e + i * s@ in type
-- @it@ (the loop counter is first converted with 'sExt') and copies
-- the result to index @i@ of the pattern variable.
defCompileBasicOp (Pat [pe]) (Iota n e s it) = do
  e' <- toExp e
  s' <- toExp s
  sFor "i" (pe64 n) $ \i -> do
    -- The loop counter is a 64-bit expression; bring it to @it@.
    let i' = sExt it $ untyped i
    -- Element value, held in a fresh scalar so that it can be passed
    -- to 'copyDWIM' as a variable.
    x <-
      dPrimV "x" . TPrimExp $
        BinOpExp (Add it OverflowUndef) e' $
          BinOpExp (Mul it OverflowUndef) i' s'
    copyDWIM (patElemName pe) [DimFix i] (Var (tvVar x)) []
-- Manifest: copy the whole source array into the pattern variable
-- with 'copyDWIM'.  The permutation argument is not inspected here.
defCompileBasicOp (Pat [pe]) (Manifest _ src) =
  copyDWIM (patElemName pe) [] (Var src) []
-- Concatenation along dimension @i@.  A scalar offset (starting at
-- zero) tracks the position along dimension @i@ of the destination.
-- Each input array, in order, is copied into the destination slice
-- starting at that offset, and the offset is then advanced by the
-- input's extent in dimension @i@.  The final 'SubExp' operand of
-- 'Concat' is ignored.
defCompileBasicOp (Pat [pe]) (Concat i (x :| ys) _) = do
  offs_glb <- dPrimV "tmp_offs" 0

  forM_ (x : ys) $ \y -> do
    y_dims <- arrayDims <$> lookupType y
    let -- Extent of @y@ along the concatenation dimension.  It is an
        -- error for @y@ to have @i@ or fewer dimensions.
        rows = case drop i y_dims of
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ prettyString y
          r : _ -> pe64 r
        -- The dimensions before the concatenation dimension are
        -- covered in full.
        skip_dims = take i y_dims
        sliceAllDim d = DimSlice 0 d 1
        skip_slices = map (sliceAllDim . pe64) skip_dims
        destslice = skip_slices ++ [DimSlice (tvExp offs_glb) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    offs_glb <-- tvExp offs_glb + rows
-- Array literal.  The element-type operand of 'ArrayLit' is ignored.
--
--   * If the literal is non-empty and every element is a constant,
--     the values are emitted as one 'Imp.DeclareArray' under a fresh
--     name, that name is registered as a memory variable in
--     'DefaultSpace', and its contents are copied into the memory of
--     the pattern variable with a single 'copy'.
--
--   * Otherwise each element is copied to its index individually.
defCompileBasicOp (Pat [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM isLiteral es = do
      dest_mem <- entryArrayLoc <$> lookupArray (patElemName pe)
      let t = primValueType v
          num_elems = length es
      static_array <- newVNameForFun "static_array"
      emit $ Imp.DeclareArray static_array t $ Imp.ArrayValues vs
      -- Describe the static array as a one-dimensional location built
      -- with 'LMAD.iota' at offset zero.
      let static_src =
            MemLoc static_array [intConst Int64 $ fromIntegral num_elems] $
              LMAD.iota 0 [fromIntegral num_elems]
      addVar static_array $ MemVar Nothing $ MemEntry DefaultSpace
      copy t dest_mem static_src
  | otherwise =
      forM_ (zip [0 ..] es) $ \(i, e) ->
        copyDWIM (patElemName pe) [DimFix $ fromInteger i] e []
  where
    -- The value of a constant operand; 'Nothing' for anything else.
    isLiteral (Constant v) = Just v
    isLiteral _ = Nothing
-- 'Rearrange' emits no code, whatever the pattern.
-- NOTE(review): presumably the permutation is expressed purely
-- through the result's memory annotation -- confirm.
defCompileBasicOp _ Rearrange {} =
  pure ()
-- 'Reshape' emits no code, whatever the pattern.
-- NOTE(review): presumably the new shape is expressed purely through
-- the result's memory annotation -- confirm.
defCompileBasicOp _ Reshape {} =
  pure ()
-- Accumulator update at indices @is@ with values @vs@.  The pattern
-- is ignored.  The whole update is guarded by an 'inBounds' check of
-- the indices against the accumulator dimensions returned by
-- 'lookupAcc'.
defCompileBasicOp _ (UpdateAcc acc is vs) = sComment "UpdateAcc" $ do
  -- We are abusing the comment mechanism to wrap the operator in
  -- braces when we end up generating code.  This is necessary because
  -- we might otherwise end up declaring lambda parameters (if any)
  -- multiple times, as they are duplicated every time we do an
  -- UpdateAcc for the same accumulator.
  let is' = map pe64 is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (_, _, arrs, dims, op) <- lookupAcc acc is'

  sWhen (inBounds (Slice (map DimFix is')) dims) $
    case op of
      Nothing ->
        -- Scatter-like: write each value straight into its array.
        forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v []
      Just lam -> do
        -- Generalised reduction.
        dLParams $ lambdaParams lam
        -- The first @length vs@ parameters receive the current array
        -- elements; the remaining ones receive the new values.
        let (x_params, y_params) =
              splitAt (length vs) $ map paramName $ lambdaParams lam

        forM_ (zip x_params arrs) $ \(xp, arr) ->
          copyDWIMFix xp [] (Var arr) is'

        forM_ (zip y_params vs) $ \(yp, v) ->
          copyDWIM yp [] v []

        -- Run the operator body, then store its results back at the
        -- same indices.
        compileStms mempty (bodyStms $ lambdaBody lam) $
          forM_ (zip arrs (bodyResult (lambdaBody lam))) $ \(arr, SubExpRes _ se) ->
            copyDWIMFix arr is' se []
-- Catch-all: no earlier clause matched this combination of pattern
-- and expression.  Abort with an error that pretty-prints both.
defCompileBasicOp pat e =
  error $
    "ImpGen.defCompileBasicOp: Invalid pattern\n  "
      ++ prettyString pat
      ++ "\nfor expression\n  "
      ++ prettyString e

-- | Register every given array declaration in the variable table as
-- an 'ArrayVar' with no defining expression, recording its memory
-- location and element type.  Only 'addVar' is called for each
-- declaration.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays = mapM_ addArray
  where
    addArray (ArrayDecl name bt location) =
      addVar name $
        ArrayVar
          Nothing
          ArrayEntry
            { entryArrayLoc = location,
              entryArrayElemType = bt
            }

-- | Like 'dFParams', but does not create new declarations: each
-- parameter is only registered with 'addVar', using the entry that
-- 'memBoundToVarEntry' derives from the parameter's declaration
-- (with uniqueness stripped by 'noUniquenessReturns') and no defining
-- expression.
--
-- Note: a hack to be used only for functions.
addFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
addFParams = mapM_ addFParam
  where
    addFParam fparam =
      addVar (paramName fparam) $
        memBoundToVarEntry Nothing $
          noUniquenessReturns $
            paramDec fparam

-- | Another hack.  Register the given name with 'addVar' as a scalar
-- variable of the given integer type, with no defining expression.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it = addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Pass each pattern element, in order, to 'dScope' as the scope
-- produced by 'scopeOfPatElem'.  The optional expression is forwarded
-- unchanged to every 'dScope' call.
dVars ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  [PatElem (LetDec rep)] ->
  ImpM rep r op ()
dVars e = mapM_ dVar
  where
    dVar = dScope e . scopeOfPatElem

-- | Pass the scope formed by the given function parameters
-- ('scopeOfFParams') to 'dScope', with no defining expression.
dFParams :: (Mem rep inner) => [FParam rep] -> ImpM rep r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Pass the scope formed by the given lambda parameters
-- ('scopeOfLParams') to 'dScope', with no defining expression.
dLParams :: (Mem rep inner) => [LParam rep] -> ImpM rep r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Declare a fresh scalar variable with 'Imp.Volatile' volatility,
-- record it in the symbol table, and assign it the given expression.
--
-- The fresh name is derived from the given string via 'newVName'.
-- Returns the typed variable for the new name.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  name' <- newVName name
  emit $ Imp.DeclareScalar name' Imp.Volatile t
  addVar name' $ ScalarVar Nothing $ ScalarEntry t
  -- Untyped assignment: the declared 'PrimType' is taken on trust.
  name' <~~ untyped e
  pure $ TV name' t

-- | Declare an already-named scalar variable of the given type with
-- 'Imp.Nonvolatile' volatility and record it in the symbol table.
-- No value is assigned.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Declare a fresh scalar variable (named after the given string via
-- 'newVName') and return it as a typed variable.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM rep r op (TV t)
dPrim name t = do
  name' <- newVName name
  dPrim_ name' t
  pure $ TV name' t

-- | Declare the given name as a scalar whose type is the type of the
-- expression, then assign the expression to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  dPrim_ name t
  TV name t <-- e
  where
    -- The dynamic type is read off the untyped form of the expression.
    t = primExpType $ untyped e

-- | Declare a fresh scalar variable whose type is the type of the
-- expression, assign the expression to it, and return the typed
-- variable.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  name' <- dPrim name $ primExpType $ untyped e
  name' <-- e
  pure name'

-- | Like 'dPrimV', but returns the new variable as a typed expression
-- (via 'tvExp') rather than as a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Build the symbol-table entry corresponding to a memory binding.
--
-- The optional expression @e@ is stored unchanged in whichever
-- 'VarEntry' constructor is produced.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry {entryScalarType = bt}
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemAcc acc ispace ts _) =
  AccVar e (acc, ispace, ts)
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem lmad)) =
  -- The array's location is its memory block, its dimensions, and the
  -- LMAD extracted from the index function.
  let location = MemLoc mem (shapeDims shape) $ IxFun.ixfunLMAD lmad
   in ArrayVar
        e
        ArrayEntry
          { entryArrayLoc = location,
            entryArrayElemType = bt
          }

-- | Extract the memory information carried by a 'NameInfo'.
--
-- Function-parameter information has its uniqueness stripped with
-- 'noUniquenessReturns'; an index name is given the primitive integer
-- type it was declared with.
infoDec ::
  (Mem rep inner) =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec (LetName dec) = letDecMem dec
infoDec (FParamName dec) = noUniquenessReturns dec
infoDec (LParamName dec) = dec
infoDec (IndexName it) = MemPrim $ IntType it

-- | Declare a single name from its 'NameInfo'.
--
-- Memory blocks and scalars get an explicit declaration emitted into
-- the generated code; arrays and accumulators emit nothing.  In every
-- case the name is then recorded in the symbol table.
dInfo ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      pure ()
    AccVar {} ->
      pure ()
  addVar name entry

-- | Declare every name in a scope by running 'dInfo' on each
-- (name, info) pair, in the order produced by 'M.toList'.
dScope ::
  (Mem rep inner) =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

-- | Record an array in the symbol table, given its name, element
-- type, shape, memory block and LMAD.  No code is emitted and no
-- defining expression is associated with the entry.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op ()
dArray name pt shape mem lmad =
  addVar name $ ArrayVar Nothing $ ArrayEntry location pt
  where
    location = MemLoc mem (shapeDims shape) lmad

-- | Run an action in an environment whose 'envVolatility' is set to
-- 'Imp.Volatile'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local $ \env -> env {envVolatility = Imp.Volatile}

-- | Collect the variable names that a function call writes to
-- directly.  Scalar and memory destinations each contribute their
-- name; array destinations contribute nothing.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests = pure $ concatMap funcallTarget dests
  where
    -- The per-destination step is pure, so no monadic traversal is
    -- needed.
    funcallTarget (ScalarDestination name) = [name]
    funcallTarget (ArrayDestination _) = []
    funcallTarget (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The type parameter @t@ is phantom: it does not occur in the
-- constructor's fields, which hold only the variable name and its
-- dynamic 'PrimType'.
data TV t = TV VName PrimType

-- | Create a typed variable from a name and a dynamic type.  Note
-- that there is no guarantee that the dynamic type corresponds to the
-- inferred static type, but the latter will at least have to be used
-- consistently.
mkTV :: VName -> PrimType -> TV t
mkTV = TV

-- | Convert a typed variable to a size (a 'SubExp' that refers to the
-- variable by name).
tvSize :: TV t -> Imp.DimSize
tvSize = Var . tvVar

-- | Convert a typed variable to a similarly typed expression: a
-- variable reference carrying the stored dynamic type.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v t) = Imp.TPrimExp $ Imp.var v t

-- | Extract the underlying variable name from a typed variable.
tvVar :: TV t -> VName
tvVar (TV v _) = v

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance.  This variant is
  -- pure and takes the 'PrimType' as an explicit argument.
  toExp' :: PrimType -> a -> Imp.Exp

-- | Constants become value expressions.  For variables, 'toExp' looks
-- up the type in the symbol table and calls 'error' if the variable
-- is not a scalar, while 'toExp'' uses the type it is given.
instance ToExp SubExp where
  toExp (Constant v) =
    pure $ Imp.ValueExp v
  toExp (Var v) =
    lookupVar v >>= \case
      ScalarVar _ (ScalarEntry pt) ->
        pure $ Imp.var v pt
      _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ prettyString v

  -- The supplied type is ignored for constants.
  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

-- | A primitive expression over 'VName's is already an expression, so
-- both methods return it unchanged ('toExp'' ignores the type).
instance ToExp (PrimExp VName) where
  toExp = pure
  toExp' _ = id

-- | Insert (or overwrite) the entry for a name in the symbol table
-- held in the compiler state.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry =
  modify $ \s -> s {stateVTable = M.insert name entry $ stateVTable s}

-- | Run an action in an environment whose 'envDefaultSpace' is set to
-- the given space.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local (\env -> env {envDefaultSpace = space})

-- | Read the 'envFunction' field of the environment, which holds a
-- function name if one is set.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = asks envFunction

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.
-- When a function name is present, it is joined to the given string
-- with a @.@; otherwise the string is used on its own.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  fname <- fmap nameToString <$> askFunction
  newVName $ maybe "" (++ ".") fname ++ s

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.
-- When a function name is present, it is joined to the given string
-- with a @.@; otherwise the string is used on its own.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = do
  fname <- askFunction
  pure $ maybe "" (<> ".") fname <> nameFromString s

-- | Read the user-supplied environment value stored in 'envEnv'.
askEnv :: ImpM rep r op r
askEnv = asks envEnv

-- | Run an action with the user-supplied environment value
-- ('envEnv') transformed by the given function.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local $ \env -> env {envEnv = f $ envEnv env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM rep r op Attrs
askAttrs = asks envAttrs

-- | Add more attributes to what is returned by 'askAttrs'.  The new
-- attributes are combined with the existing ones using '<>', with the
-- new attributes on the left.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with the expression, statements, copy, operation
-- and allocation compilers in the environment replaced by those from
-- the given 'Operations'.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local $ \env ->
  env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envAllocCompilers = opsAllocCompilers ops
    }

-- | Get the current symbol table.
getVTable :: ImpM rep r op (VTable rep)
getVTable = gets stateVTable

-- | Replace the current symbol table wholesale.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify $ \s -> s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  old_vtable <- getVTable
  putVTable $ f old_vtable
  a <- m
  -- Restore the table saved before the action ran, discarding both
  -- the modification by @f@ and anything the action itself added.
  putVTable old_vtable
  pure a

-- | Look up a name in the symbol table.  Calls 'error' if the name is
-- not present.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name = do
  res <- gets $ M.lookup name . stateVTable
  case res of
    Just entry -> pure entry
    _ -> error $ "Unknown variable: " ++ prettyString name

-- | Look up a name that must refer to an array.  Calls 'error' if the
-- name is bound to anything other than an array ('lookupVar' handles
-- the unbound case).
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ prettyString name

-- | Look up a name that must refer to a memory block.  Calls 'error'
-- if the name is bound to anything other than a memory block
-- ('lookupVar' handles the unbound case).
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> pure entry
    _ -> error $ "Unknown memory block: " ++ prettyString name

-- | The memory space of the memory block that the array @name@ is
-- stored in.
--
-- First finds the array with 'lookupArray', then looks up the memory
-- block named by its location with 'lookupMemory'; failures in either
-- lookup propagate as calls to 'error'.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace name = do
  arr <- lookupArray name
  entryMemSpace <$> lookupMemory (memLocName (entryArrayLoc arr))

-- | Look up the accumulator bound to @name@.  In the case of a
-- histogram-like accumulator (one with a combining operator), also
-- sets the index parameters: the first @length is@ parameters of the
-- operator are declared and initialised to the indexes @is@.
--
-- The result is, in order: the underlying accumulator name, the
-- memory space of the first backing array, all backing arrays, the
-- dimensions of the accumulator's index space, and (if present) the
-- operator with the index parameters removed.
--
-- Calls 'error' if @name@ is not an accumulator, if the accumulator is
-- not registered in 'stateAccs', or if it has no backing arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) -> do
      let ispace_dims = map pe64 (shapeDims ispace)
      gets (M.lookup acc . stateAccs) >>= \case
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ prettyString name
        Just (arrs@(arr : _), Just (op, _)) -> do
          space <- lookupArraySpace arr
          -- The leading parameters of the operator are the indexes;
          -- bind them here so the caller only sees the value parameters.
          let (i_params, ps) = splitAt (length is) $ lambdaParams op
          zipWithM_ dPrimV_ (map paramName i_params) is
          pure (acc, space, arrs, ispace_dims, Just op {lambdaParams = ps})
        Just (arrs@(arr : _), Nothing) -> do
          space <- lookupArraySpace arr
          pure (acc, space, arrs, ispace_dims, Nothing)
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ prettyString name
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ prettyString name

-- | Compute one 'ValueDestination' per element of a pattern, based on
-- how each pattern element's name is recorded in the symbol table
-- (via 'lookupVar').
--
-- Memory blocks and scalars are destinations named after the pattern
-- element itself.  Arrays and accumulators both become
-- @'ArrayDestination' 'Nothing'@.
destinationFromPat :: Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat pat =
  forM (patElems pat) $ \pe -> do
    let name = patElemName pe
    lookupVar name >>= \case
      ArrayVar _ (ArrayEntry MemLoc {} _) ->
        pure $ ArrayDestination Nothing
      MemVar {} ->
        pure $ MemoryDestination name
      ScalarVar {} ->
        pure $ ScalarDestination name
      AccVar {} ->
        pure $ ArrayDestination Nothing

-- | Index the array @name@ with @indices@, producing the memory block
-- it lives in, the space of that block, and the element offset of the
-- indexed position.
--
-- This is 'fullyIndexArray'' applied to the location recorded for the
-- array; see 'lookupArray' for how an unknown or non-array name is
-- handled.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  arr <- lookupArray name
  fullyIndexArray' (entryArrayLoc arr) indices

-- | Like 'fullyIndexArray', but takes the array location directly
-- instead of an array name.
--
-- Returns the memory block of the location, the space of that block
-- (found with 'lookupMemory'), and the offset obtained by applying the
-- location's LMAD to @indices@, counted in elements.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ lmad) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let offset = elements $ LMAD.index lmad indices
  pure (mem, space, offset)

-- More complicated read/write operations that use index functions.

-- | Copy the array at the source location to the destination
-- location, skipping the copy whenever it can be shown to be a no-op.
--
-- Two checks guard the actual copy:
--
-- 1. At code generation time: if both locations name the same memory
--    block and their LMADs are 'LMAD.equivalent', nothing is emitted.
--
-- 2. In the generated code: the copy is wrapped in an 'sUnless' that
--    tests whether the two LMADs are dynamically equal (and the memory
--    blocks are the same).
--
-- Only when neither check applies is the copy compiler from the
-- environment ('envCopyCompiler') invoked.
copy :: CopyCompiler rep r op
copy bt dst@(MemLoc dst_name _ dst_lmad) src@(MemLoc src_name _ src_lmad) = do
  let same_mem = dst_name == src_name
  -- If we can statically determine that the two index functions are
  -- equivalent, don't do anything.
  unless (same_mem && dst_lmad `LMAD.equivalent` src_lmad) $
    -- It's also possible that we can dynamically determine that the
    -- two index functions are equivalent.
    sUnless (fromBool same_mem .&&. LMAD.dynamicEqualsLMAD dst_lmad src_lmad) $ do
      -- If none of the above is true, actually do the copy.
      cc <- asks envCopyCompiler
      cc bt dst src

-- | A 'CopyCompiler' that emits a single 'Imp.LMADCopy' statement.
--
-- The statement is given the element type, the shape of the
-- destination LMAD, and for each of destination and source: the
-- memory block with its space, and the offset and per-dimension
-- strides of its LMAD (all counted in elements).
lmadCopy :: CopyCompiler rep r op
lmadCopy t dstloc srcloc = do
  let dstmem = memLocName dstloc
      srcmem = memLocName srcloc
      dstlmad = memLocLMAD dstloc
      srclmad = memLocLMAD srcloc
  srcspace <- entryMemSpace <$> lookupMemory srcmem
  dstspace <- entryMemSpace <$> lookupMemory dstmem
  emit $
    Imp.LMADCopy
      t
      (elements <$> LMAD.shape dstlmad)
      (dstmem, dstspace)
      (offsetAndStrides dstlmad)
      (srcmem, srcspace)
      (offsetAndStrides srclmad)
  where
    -- The offset of an LMAD together with the stride of each of its
    -- dimensions, with every quantity tagged as an element count.
    offsetAndStrides lmad =
      let lmad' = elements <$> lmad
       in (LMAD.offset lmad', map LMAD.ldStride (LMAD.dims lmad'))

-- | Copy from here to there; both destination and source may be
-- indexed.
--
-- The generated code is returned rather than emitted.  There are two
-- cases:
--
-- * If both slices consist solely of fixed indexes and each covers
--   every dimension of its array, a single element is copied: it is
--   read into a fresh scalar and then written to the destination.
--
-- * Otherwise both slices are padded to full slices and applied to
--   their locations.  If the resulting ranks differ this calls
--   'error'; if the resulting locations are equal no code is produced;
--   otherwise the copy is delegated to 'copy'.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLoc _ destshape _)
  destslice
  srclocation@(MemLoc _ srcshape _)
  srcslice
    | Just destis <- traverse dimFix destslice,
      Just srcis <- traverse dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
        (targetmem, destspace, targetoffset) <-
          fullyIndexArray' destlocation destis
        (srcmem, srcspace, srcoffset) <-
          fullyIndexArray' srclocation srcis
        vol <- asks envVolatility
        collect $ do
          -- Go through a temporary scalar: read the source element,
          -- then write it to the destination.
          tmp <- tvVar <$> dPrim "tmp" bt
          emit $ Imp.Read tmp srcmem srcoffset bt srcspace vol
          emit $ Imp.Write targetmem targetoffset bt destspace vol $ Imp.var tmp bt
    | otherwise = do
        let destslice' = fullSliceNum (map pe64 destshape) destslice
            srcslice' = fullSliceNum (map pe64 srcshape) srcslice
            destrank = length $ sliceDims destslice'
            srcrank = length $ sliceDims srcslice'
            destlocation' = sliceMemLoc destlocation destslice'
            srclocation' = sliceMemLoc srclocation srcslice'
        if destrank /= srcrank
          then
            error $
              "copyArrayDWIM: cannot copy to "
                ++ prettyString (memLocName destlocation)
                ++ " from "
                ++ prettyString (memLocName srclocation)
                ++ " because ranks do not match ("
                ++ prettyString destrank
                ++ " vs "
                ++ prettyString srcrank
                ++ ")"
          else
            if destlocation' == srclocation'
              then pure mempty -- Copy would be no-op.
              else collect $ copy bt destlocation' srclocation'

-- Like 'copyDWIM', but the target is a 'ValueDestination' instead of
-- a variable name.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest :: forall rep r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIMDest ValueDestination
_ [DimIndex (TExp Int64)]
_ (Constant PrimValue
v) (DimIndex (TExp Int64)
_ : [DimIndex (TExp Int64)]
_) =
  [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
    [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be indexed."]
copyDWIMDest ValueDestination
pat [DimIndex (TExp Int64)]
dest_slice (Constant PrimValue
v) [] =
  case (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice of
    Maybe [TExp Int64]
Nothing ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"with slice destination."]
    Just [TExp Int64]
dest_is ->
      case ValueDestination
pat of
        ScalarDestination VName
name ->
          Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        MemoryDestination {} ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be written to memory destination."]
        ArrayDestination (Just MemLoc
dest_loc) -> do
          (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <-
            MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          Volatility
vol <- (Env rep r op -> Volatility) -> ImpM rep r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> Volatility
forall rep r op. Env rep r op -> Volatility
envVolatility
          Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        ArrayDestination Maybe MemLoc
Nothing ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error [Char]
"copyDWIMDest: ArrayDestination Nothing"
  where
    bt :: PrimType
bt = PrimValue -> PrimType
primValueType PrimValue
v
copyDWIMDest ValueDestination
dest [DimIndex (TExp Int64)]
dest_slice (Var VName
src) [DimIndex (TExp Int64)]
src_slice = do
  VarEntry rep
src_entry <- VName -> ImpM rep r op (VarEntry rep)
forall rep r op. VName -> ImpM rep r op (VarEntry rep)
lookupVar VName
src
  case (ValueDestination
dest, VarEntry rep
src_entry) of
    (MemoryDestination VName
mem, MemVar Maybe (Exp rep)
_ (MemEntry Space
space)) ->
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
mem VName
src Space
space
    (MemoryDestination {}, VarEntry rep
_) ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: cannot write", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"to memory destination."]
    (ValueDestination
_, MemVar {}) ->
      [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: source", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"is a memory block."]
    (ValueDestination
_, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
_))
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
src_slice ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed source", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice]
    (ScalarDestination VName
name, VarEntry rep
_)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
dest_slice ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed target", VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
name, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice]
    (ScalarDestination VName
name, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
pt)) ->
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
src PrimType
pt
    (ScalarDestination VName
name, ArrayVar Maybe (Exp rep)
_ ArrayEntry
arr)
      | Just [TExp Int64]
src_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
src_slice,
        [DimIndex (TExp Int64)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex (TExp Int64)]
src_slice Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr) -> do
          let bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arr
          (VName
mem, Space
space, Count Elements (TExp Int64)
i) <-
            MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
arr) [TExp Int64]
src_is
          Volatility
vol <- (Env rep r op -> Volatility) -> ImpM rep r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> Volatility
forall rep r op. Env rep r op -> Volatility
envVolatility
          Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code op
forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Imp.Read VName
name VName
mem Count Elements (TExp Int64)
i PrimType
bt Space
space Volatility
vol
      | Bool
otherwise ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: prim-typed target",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
name,
                [Char]
"and array-typed source",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"of shape",
                [SubExp] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr),
                [Char]
"sliced with",
                [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice
              ]
    (ArrayDestination (Just MemLoc
dest_loc), ArrayVar Maybe (Exp rep)
_ ArrayEntry
src_arr) -> do
      let src_loc :: MemLoc
src_loc = ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
src_arr
          bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
src_arr
      Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ())
-> ImpM rep r op (Code op) -> ImpM rep r op ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
forall rep r op.
PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
copyArrayDWIM PrimType
bt MemLoc
dest_loc [DimIndex (TExp Int64)]
dest_slice MemLoc
src_loc [DimIndex (TExp Int64)]
src_slice
    (ArrayDestination (Just MemLoc
dest_loc), ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
bt))
      | Just [TExp Int64]
dest_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe [TExp Int64]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice,
        [TExp Int64] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dest_is Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (MemLoc -> [SubExp]
memLocShape MemLoc
dest_loc) -> do
          (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <- MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          Volatility
vol <- (Env rep r op -> Volatility) -> ImpM rep r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> Volatility
forall rep r op. Env rep r op -> Volatility
envVolatility
          Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (VName -> PrimType -> Exp
Imp.var VName
src PrimType
bt)
      | Bool
otherwise ->
          [Char] -> ImpM rep r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM rep r op ()) -> [Char] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: array-typed target and prim-typed source",
                VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"with slice",
                [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice
              ]
    (ArrayDestination Maybe MemLoc
Nothing, VarEntry rep
_) ->
      () -> ImpM rep r op ()
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; something else set some memory
      -- somewhere.
    (ValueDestination
_, AccVar {}) ->
      () -> ImpM rep r op ()
forall a. a -> ImpM rep r op a
forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; accumulators are phantoms.

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination name is looked up with 'lookupVar' and turned
-- into a 'ValueDestination' according to the kind of its entry;
-- the actual copy is then performed by 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  dest_entry <- lookupVar dest
  let dest_target =
        case dest_entry of
          ScalarVar _ _ ->
            ScalarDestination dest
          ArrayVar _ (ArrayEntry dest_loc _) ->
            -- The array's memory location is used as-is.
            ArrayDestination $ Just dest_loc
          MemVar _ _ ->
            MemoryDestination dest
          AccVar {} ->
            -- Does not matter; accumulators are phantoms.
            ArrayDestination Nothing
  copyDWIMDest dest_target dest_slice src src_slice

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes: every
-- element of both index lists is wrapped in 'DimFix' before being
-- passed on.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (map DimFix dest_is) src (map DimFix src_is)

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@, writing the result to @pat@, which must contain a single
-- memory-typed element.
--
-- The allocation itself is delegated to 'sAlloc_', so an allocator
-- registered for @space@ in the environment takes precedence over a
-- plain 'Imp.Allocate'.  Any pattern that does not consist of exactly
-- one element is a compiler bug and results in a call to 'error'.
compileAlloc ::
  (Mem rep inner) => Pat (LetDec rep) -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) e space =
  sAlloc_ (patElemName mem) (Imp.bytes $ pe64 e) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ prettyString pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression.
--
-- Computed as the byte size of the element type multiplied by the
-- product of all dimensions of the type.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t =
  Imp.bytes $ primByteSize (elemType t) * product (map pe64 (arrayDims t))

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- Each slice component is paired with the dimension at the same
-- position (surplus entries of the longer list are ignored), and the
-- per-dimension conditions are conjoined.  When there is nothing to
-- check (an empty slice or an empty shape) the result is 'true'.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  let -- A fixed index must lie in [0, d).
      condInBounds (DimFix i) d =
        0 .<=. i .&&. i .<. d
      -- For a slice we check that the start is non-negative and that
      -- the last element, at i + (n - 1) * s, is below d.
      condInBounds (DimSlice i n s) d =
        0 .<=. i .&&. i + (n - 1) * s .<. d
   in case zipWith condInBounds slice dims of
        -- No conditions at all: vacuously in bounds (and avoids the
        -- partial fold on an empty list).
        [] -> true
        c : cs -> foldl (.&&.) c cs

--- Building blocks for constructing code.

-- | Emit a for-loop with the given (already named) loop variable,
-- upper bound, and body.  The loop variable is registered via
-- 'addLoopVar' with the integer type of the bound before the body is
-- collected.  It is an error for the bound not to be of integer type.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  let it = case primExpType bound of
        IntType bound_t -> bound_t
        t ->
          error $
            "sFor': bound " ++ prettyString bound ++ " is of type " ++ prettyString t
  addLoopVar i it
  body' <- collect body
  emit $ Imp.For i bound body'

-- | Emit a for-loop.  A fresh loop variable is created from the given
-- name hint, and the body is passed an expression referring to that
-- variable, with the same primitive type as the bound.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor i bound body = do
  i' <- newVName i
  let bound' = untyped bound
      -- The loop variable, typed like the bound.
      i_exp = TPrimExp $ Imp.var i' $ primExpType bound'
  sFor' i' bound' $ body i_exp

-- | Emit a while-loop with the given condition; the code produced by
-- the body action becomes the loop body.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = do
  body' <- collect body
  emit $ Imp.While cond body'

-- | Wrap the code produced by the given action in an 'Imp.Comment'
-- carrying the given text.
sComment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = do
  code' <- collect code
  emit $ Imp.Comment s code'

-- | Emit a conditional.  Both branch actions are always run (their
-- code is collected) in order, true branch first; only the emitted
-- code depends on the condition.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  tbranch' <- collect tbranch
  fbranch' <- collect fbranch
  emit $ selectBranch tbranch' fbranch'
  where
    -- Avoid generating branch if the condition is known statically.
    selectBranch t f
      | cond == true = t
      | cond == false = f
      | otherwise = Imp.If cond t f

-- | Like 'sIf', but with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())

-- | Like 'sIf', but with an empty true branch: the given action
-- becomes the false branch.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond = sIf cond (pure ())

-- | Emit a backend-specific operation, wrapped in 'Imp.Op'.
sOp :: op -> ImpM rep r op ()
sOp = emit . Imp.Op

-- | Declare (but do not allocate) a memory block in the given space.
-- A fresh name is created from the name hint, an 'Imp.DeclareMem' is
-- emitted, and the name is registered as a memory variable.  Returns
-- the fresh name.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem name space = do
  name' <- newVName name
  emit $ Imp.DeclareMem name' space
  addVar name' $ MemVar Nothing $ MemEntry space
  pure name'

-- | Allocate the given number of bytes for an already declared memory
-- block.  If the environment has an allocator registered for the
-- space, it is used; otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ name' size' space = do
  allocator <- asks $ M.lookup space . envAllocCompilers
  case allocator of
    Nothing -> emit $ Imp.Allocate name' size' space
    Just allocator' -> allocator' name' size'

-- | Declare a fresh memory block (via 'sDeclareMem') and allocate the
-- given number of bytes for it (via 'sAlloc_').  Returns the name of
-- the new block.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc name size space = do
  name' <- sDeclareMem name space
  sAlloc_ name' size space
  pure name'

-- | Declare an array with a fresh name (derived from the name hint),
-- of the given element type and shape, residing in the given memory
-- block with the given index function.  Returns the fresh name.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> LMAD -> ImpM rep r op VName
sArray name bt shape mem ixfun = do
  name' <- newVName name
  dArray name' bt shape mem ixfun
  pure name'

-- | Declare an array in row-major order in the given memory block.
-- The index function is an 'LMAD.iota' with offset zero over the
-- dimensions of @shape@, interpreted as 64-bit integer expressions.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape mem $
    LMAD.iota 0 $
      map (isInt64 . primExpFromSubExp int64) $
        shapeDims shape

-- | Like 'sAllocArray', but permute the in-memory representation of
-- the indices as specified.  A memory block named @name ++ "_mem"@ is
-- allocated in @space@ with the size of the whole array, and the
-- array declared in it keeps the logical shape @shape@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm name pt shape space perm = do
  let permuted_dims = rearrangeShape perm $ shapeDims shape
  mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  -- Linear index function over the dimensions in their permuted
  -- (in-memory) order.
  let iota_ixfun = LMAD.iota 0 $ map (isInt64 . primExpFromSubExp int64) permuted_dims
  -- Apply the inverse permutation to the index function, so that it
  -- is indexed in the order of the original (unpermuted) @shape@.
  sArray name pt shape mem $
    LMAD.permute iota_ixfun $
      rearrangeInverse perm

-- | Allocate memory for an array and declare the array in it.  Uses
-- a linear/iota index function: this is 'sAllocArrayPerm' with the
-- identity permutation @[0 .. rank - 1]@.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space [0 .. shapeRank shape - 1]

-- | Declare a one-dimensional array with statically known contents.
-- Uses linear/iota index function.
--
-- An 'Imp.DeclareArray' statement is emitted for a memory block named
-- @name ++ "_mem"@ (made unique with 'newVNameForFun'), and that
-- block is registered as residing in 'DefaultSpace'.  Returns the
-- name of the array declared on top of it.
sStaticArray :: String -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name pt vs = do
  let -- The element count depends on how the contents are given:
      -- either an explicit list of values or a number of zeros.
      num_elems = case vs of
        Imp.ArrayValues vs' -> length vs'
        Imp.ArrayZeros n -> fromIntegral n
      shape = Shape [intConst Int64 $ toInteger num_elems]
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem pt vs
  addVar mem $ MemVar Nothing $ MemEntry DefaultSpace
  sArray name pt shape mem $ LMAD.iota 0 [fromIntegral num_elems]

-- | Emit a write of the scalar expression @v@ to the element of array
-- @arr@ at indices @is@.  The memory block, space and element offset
-- are obtained from 'fullyIndexArray'; the element type is the type
-- of @v@, and the volatility is taken from the environment.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  vol <- asks envVolatility
  emit $ Imp.Write mem offset (primExpType v) space vol v

-- | Copy the value @v@ into the given slice of the array @arr@.  This
-- is 'copyDWIM' with the slice as the destination indexing and no
-- indexing of the source.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr slice v = copyDWIM arr (unSlice slice) v []

-- | Create a sequential 'Imp.For' loop nest covering a space of the
-- given shape, with one loop per dimension and the first dimension
-- outermost.  The function is called with the indexes for a given
-- iteration, in the same order as the dimensions.  If the list of
-- dimensions is empty, the function is called once with no indexes.
sLoopSpace ::
  [Imp.TExp t] ->
  ([Imp.TExp t] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopSpace = nest []
  where
    -- Indexes are accumulated innermost-first, so they are reversed
    -- before being handed to the body.
    nest is [] f = f $ reverse is
    nest is (d : ds) f = sFor "nest_i" d $ \i -> nest (i : is) ds f

-- | Like 'sLoopSpace', but the iteration space is given as a 'Shape',
-- whose dimensions are converted to 64-bit expressions with 'pe64'.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest = sLoopSpace . map pe64 . shapeDims

-- | Untyped assignment.  Emits an 'Imp.SetScalar' statement that sets
-- the variable on the left to the expression on the right.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
x <~~ e = emit $ Imp.SetScalar x e

infixl 3 <~~

-- | Typed assignment.  Emits an 'Imp.SetScalar' statement that sets
-- the variable underlying the 'TV' to the (untyped form of the)
-- expression on the right.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
TV x _ <-- e = emit $ Imp.SetScalar x $ untyped e

infixl 3 <--

-- | Construct an ad-hoc function that does not correspond to any of
-- the IR functions in the input program.
--
-- The action @m@ is run with the current function name in the
-- environment set to @fname@ and with all output and input parameters
-- added as variables.  The code it produces becomes the body of an
-- 'Imp.Function' without an entry point, which is registered under
-- @fname@ with 'emitFunction'.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m = local newFunction $ do
  body <- collect $ do
    mapM_ addParam $ outputs ++ inputs
    m
  emitFunction fname $ Imp.Function Nothing outputs inputs body
  where
    -- Register a parameter as a memory block or a scalar variable,
    -- according to its kind.
    addParam (Imp.MemParam name space) =
      addVar name $ MemVar Nothing $ MemEntry space
    addParam (Imp.ScalarParam name bt) =
      addVar name $ ScalarVar Nothing $ ScalarEntry bt
    newFunction env = env {envFunction = Just fname}

-- Fish out those top-level declarations in the constant
-- initialisation code that are free in the functions.
--
-- Returns the parameters corresponding to the declarations of names
-- in @used@, together with the remaining code.  Only statements
-- joined directly by 'Imp.:>>:' are inspected.  Memory and scalar
-- declarations of used names are removed from the code, while array
-- declarations of used names are kept in the code and additionally
-- turned into a memory parameter in 'DefaultSpace'.  Every other
-- statement (including declarations of names not in @used@) is
-- returned unchanged.
constParams :: Names -> Imp.Code a -> (DL.DList Imp.Param, Imp.Code a)
constParams used (x Imp.:>>: y) =
  constParams used x <> constParams used y
constParams used (Imp.DeclareMem name space)
  | name `nameIn` used =
      ( DL.singleton $ Imp.MemParam name space,
        mempty
      )
constParams used (Imp.DeclareScalar name _ t)
  | name `nameIn` used =
      ( DL.singleton $ Imp.ScalarParam name t,
        mempty
      )
constParams used s@(Imp.DeclareArray name _ _)
  | name `nameIn` used =
      ( DL.singleton $ Imp.MemParam name DefaultSpace,
        s
      )
constParams _ s =
  (mempty, s)

-- | Generate constants that get put outside of all functions.  Will
-- be executed at program startup.  Action must return the names that
-- should be made available.  This one has real sharp edges. Do not
-- use inside 'subImpM'.  Do not use any variable from the context.
--
-- The code produced by the action is collected rather than emitted in
-- place.  It is split by 'constParams' into parameters and
-- initialisation code, and the result is appended to the constants
-- already recorded in the state.
genConstants :: ImpM rep r op (Names, a) -> ImpM rep r op a
genConstants m = do
  ((avail, a), code) <- collect' m
  let consts = uncurry Imp.Constants $ first DL.toList $ constParams avail code
  modify $ \s -> s {stateConstants = stateConstants s <> consts}
  pure a

-- | Given the sizes of the dimensions of an index space, compute for
-- each dimension the product of the sizes of all the dimensions that
-- follow it (so the entry for the last dimension is @1@).  The result
-- has the same length as the input.
--
-- Each running product is bound to a scalar with 'dPrimVE'.  The
-- product over all of the dimensions is computed as well, but it is
-- not part of the result.
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices = fmap (drop 1 . snd) . dSlices'
  where
    -- Returns the product of all the given sizes, along with the list
    -- of products of every suffix of the sizes (ending with the empty
    -- product, @1@).
    dSlices' [] = pure (1, [1])
    dSlices' (n : ns) = do
      (prod, ns') <- dSlices' ns
      n' <- dPrimVE "slice" $ n * prod
      pure (n', n' : ns')

-- | @dIndexSpace vs_ds j@ unflattens the flat index @j@ into one
-- index per dimension.  Each element of @vs_ds@ pairs the name of the
-- variable that will hold an index with the size of the corresponding
-- dimension.  The variables are declared and set by this function,
-- going from the first dimension to the last; nothing is returned.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  slices <- dSlices (map snd vs_ds)
  loop (zip (map fst vs_ds) slices) j
  where
    -- @size@ is the number of flat elements spanned by one step along
    -- this dimension, and @i@ is what is left of the flat index after
    -- the contributions of the preceding dimensions were subtracted.
    loop ((v, size) : rest) i = do
      dPrimV_ v (i `quot` size)
      i' <- dPrimVE "remnant" $ i - Imp.le64 v * size
      loop rest i'
    loop _ _ = pure ()

-- | Like 'dIndexSpace', but invent some new names for the indexes
-- based on the given template.  Returns the indexes, one for each
-- dimension in @ds@, as expressions reading the new variables.
dIndexSpace' ::
  String ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  ImpM rep r op [Imp.TExp Int64]
dIndexSpace' desc ds j = do
  ivs <- replicateM (length ds) (newVName desc)
  dIndexSpace (zip ivs ds) j
  pure $ map Imp.le64 ivs