{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLoc (..),
    sliceMemLoc,
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,
    askAttrs,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    inBounds,
    isMapTransposeCopy,
    caseMatch,

    -- * Constructing code.
    dLParams,
    dFParams,
    addLoopVar,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    dIndexSpace',
    rotateIndex,
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    sCopy,
    sLoopSpace,
    (<--),
    (<~~),
    function,
    genConstants,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.Either
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.String
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Futhark.Util.Pretty hiding (nest, space)
import Language.Futhark.Warnings
import Prelude hiding (mod, quot)

-- | How to compile an t'Op': given the pattern the operation is bound
-- to and the operation itself, generate imperative code.
type OpCompiler rep r op =
  Pat (LetDec rep) ->
  Op rep ->
  ImpM rep r op ()

-- | How to compile some 'Stms'.  Besides the statements themselves, an
-- implementation receives a set of 'Names' and a trailing 'ImpM'
-- action; how these are interpreted is up to the implementation (see
-- 'defCompileStms' for the default one).
type StmsCompiler rep r op =
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()

-- | How to compile an 'Exp' that is bound to the given pattern.
type ExpCompiler rep r op =
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()

-- | How to compile a copy of elements of the given 'PrimType' between
-- two memory locations.
--
-- NOTE(review): the two 'MemLoc's are believed to be destination then
-- source (as consumed by 'defaultCopy'); confirm before relying on it.
type CopyCompiler rep r op =
  PrimType ->
  MemLoc ->
  MemLoc ->
  ImpM rep r op ()

-- | An alternate way of compiling an allocation.  The compiler is given
-- a 'VName' (presumably the memory block being allocated) and the
-- requested size in bytes.
type AllocCompiler rep r op =
  VName ->
  Count Bytes (Imp.TExp Int64) ->
  ImpM rep r op ()

-- | A bundle of compilers that together determine how 'ImpM' translates
-- a program.  'defaultOperations' provides a starting point whose
-- fields can then be overridden with record update syntax.
data Operations rep r op = Operations
  { -- | How to compile an 'Exp'.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | How to compile an t'Op'.
    opsOpCompiler :: OpCompiler rep r op,
    -- | How to compile a sequence of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | How to compile a copy between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Allocation compilers for specific memory spaces.
    -- NOTE(review): spaces without an entry presumably use a plain
    -- allocation; see 'compileAlloc'.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | The default set of 'Operations', parameterised only by how to
-- compile an t'Op'.  Expressions are compiled with 'defCompileExp',
-- statements with 'defCompileStms' and copies with 'defaultCopy'.  No
-- space-specific allocation compilers are registered.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations compileOp =
  Operations
    { opsOpCompiler = compileOp,
      opsExpCompiler = defCompileExp,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = M.empty
    }

-- | When an array is declared, this is where it is stored.
data MemLoc = MemLoc
  { -- | Name of the memory block holding the array.
    memLocName :: VName,
    -- | The array's shape.
    memLocShape :: [Imp.DimSize],
    -- | The index function describing where the array's elements live
    -- within the memory block.
    memLocIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Show, Eq)

-- | Apply a 'Slice' to the index function of a 'MemLoc'.  Only
-- 'memLocIxFun' is changed; the memory block and 'memLocShape' are
-- carried over unchanged.
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.slice (memLocIxFun loc) slice}

-- | Apply a 'FlatSlice' to the index function of a 'MemLoc'.  As with
-- 'sliceMemLoc', only 'memLocIxFun' is changed; the memory block and
-- 'memLocShape' are kept as they are.
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.flatSlice (memLocIxFun loc) slice}

-- | Symbol table information for an array variable: where it is
-- stored and the primitive type of its elements.
data ArrayEntry = ArrayEntry
  { -- | Storage location of the array.
    entryArrayLoc :: MemLoc,
    -- | Element type of the array.
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array entry, as recorded in its 'MemLoc'.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocShape (entryArrayLoc entry)

-- | Symbol table information for a memory block variable.
newtype MemEntry = MemEntry
  { -- | The memory space the block resides in.
    entryMemSpace :: Imp.Space
  }
  deriving (Show)

-- | Symbol table information for a scalar variable.
newtype ScalarEntry = ScalarEntry
  { -- | The primitive type of the scalar.
    entryScalarType :: PrimType
  }
  deriving (Show)

-- | An entry in the symbol table ('VTable'), describing what kind of
-- variable a name refers to.  Every constructor optionally carries the
-- expression that defined the variable; 'constsVTable', for example,
-- always records it as 'Just'.
data VarEntry rep
  = -- | An array, with its location and element type.
    ArrayVar (Maybe (Exp rep)) ArrayEntry
  | -- | A scalar of some primitive type.
    ScalarVar (Maybe (Exp rep)) ScalarEntry
  | -- | A memory block.
    MemVar (Maybe (Exp rep)) MemEntry
  | -- | An accumulator, described by the same components as its 'Acc'
    -- type.
    AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | Where the value of some expression should be put.
data ValueDestination
  = -- | A scalar variable.
    ScalarDestination VName
  | -- | A memory block variable.
    MemoryDestination VName
  | -- | An array.  The 'MemLoc' is 'Just' if a copy is required.  If it
    -- is 'Nothing', then a copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)

-- | The read-only environment of 'ImpM': the active compilers together
-- with contextual settings.
data Env rep r op = Env
  { -- | Compiles expressions (taken from 'opsExpCompiler').
    envExpCompiler :: ExpCompiler rep r op,
    -- | Compiles statement sequences (taken from 'opsStmsCompiler').
    envStmsCompiler :: StmsCompiler rep r op,
    -- | Compiles t'Op's (taken from 'opsOpCompiler').
    envOpCompiler :: OpCompiler rep r op,
    -- | Compiles copies (taken from 'opsCopyCompiler').
    envCopyCompiler :: CopyCompiler rep r op,
    -- | Allocation compilers for specific memory spaces.
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space ('newEnv' sets it from its 'Imp.Space'
    -- argument).
    envDefaultSpace :: Imp.Space,
    -- | The current volatility setting ('newEnv' starts it at
    -- 'Imp.Nonvolatile').
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Build the initial environment for running an 'ImpM' action.
--
-- Every compiler, including the space-specific allocation compilers,
-- is taken from the given 'Operations'.  The given 'Imp.Space' becomes
-- the default memory space.  The environment starts out non-volatile,
-- outside any function, and with no active attributes.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Honour the caller's allocation compilers, exactly as 'subImpM'
      -- does.  Hard-wiring 'mempty' here would silently discard the
      -- 'opsAllocCompilers' of the operations passed to 'runImpM'.
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation: maps every variable name
-- in scope to its 'VarEntry'.
type VTable rep =
  M.Map VName (VarEntry rep)

-- | The mutable state threaded through 'ImpM'.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | Functions emitted so far (see 'emitFunction').
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far in the current collection scope (see
    -- 'emit' and 'collect').
    stateCode :: Imp.Code op,
    -- | Constants accumulated during compilation.
    stateConstants :: Imp.Constants op,
    -- | Warnings accumulated so far (see 'warn').
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their update
    -- function and neutral elements.  This works because an array name
    -- can only become part of a single accumulator throughout its
    -- lifetime.  If the arrays backing an accumulator are not in this
    -- mapping, the accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | Source of fresh names (see the 'MonadFreshNames' instance).
    stateNameSource :: VNameSource
  }

-- | An initial state: empty symbol table, no emitted code, functions,
-- constants, warnings or accumulators.  Fresh names are drawn from the
-- given source.
newState :: VNameSource -> ImpState rep r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateConstants = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }

-- | The code generation monad: a read-only 'Env' layered over a mutable
-- 'ImpState'.  All instances are derived from the underlying
-- 'ReaderT'/'State' stack.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )

-- | Fresh names are read from, and written back to, 'stateNameSource'.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify (\st -> st {stateNameSource = src})

-- Cannot be a KernelsMem scope because the index functions have the
-- wrong leaves (VName instead of Imp.Exp).  Instead every symbol table
-- entry is exposed as a SOACS 'LetName' carrying only its type.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets $ fmap (LetName . typeOfEntry) . stateVTable
    where
      typeOfEntry = \case
        ScalarVar _ (ScalarEntry pt) ->
          Prim pt
        MemVar _ (MemEntry space) ->
          Mem space
        ArrayVar _ entry ->
          Array
            (entryArrayElemType entry)
            (Shape (entryArrayShape entry))
            NoUniqueness
        AccVar _ (acc, ispace, ts) ->
          Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action from the given initial state.  The environment
-- is built by 'newEnv' from the user environment, operations and
-- default memory space.  Returns the result and the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM action) user_env ops space =
  runState $ runReaderT action (newEnv user_env ops space)

-- | Like 'subImpM', but keeps only the generated code and discards the
-- action's result.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = fmap snd (subImpM r ops m)

-- | Run an action with a different user environment and operations
-- (and possibly a different t'Op' type).  Returns the action's result
-- together with the code it generated.
--
-- The inner action inherits the caller's environment (with the
-- compilers and user environment replaced), symbol table, accumulators
-- and name source.  It starts with no code, functions, constants or
-- warnings.  Afterwards only the name source and the warnings are
-- propagated back.  Functions, constants and symbol table changes made
-- by the inner action are discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  outer_env <- ask
  outer_state <- get

  let inner_env =
        outer_env
          { envEnv = r,
            envOpCompiler = opsOpCompiler ops,
            envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envAllocCompilers = opsAllocCompilers ops
          }
      inner_state =
        ImpState
          { stateVTable = stateVTable outer_state,
            stateAccs = stateAccs outer_state,
            stateNameSource = stateNameSource outer_state,
            stateFunctions = mempty,
            stateCode = mempty,
            stateConstants = mempty,
            stateWarnings = mempty
          }
      (result, final_state) = runState (runReaderT m inner_env) inner_state

  putNameSource (stateNameSource final_state)
  warnings (stateWarnings final_state)
  pure (result, stateCode final_state)

-- | Run a code generation action and return the code it emitted,
-- instead of emitting it in the current context.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = (snd <$>) . collect'

-- | Like 'collect', but also returns the action's result.  The code
-- emitted before the call is saved and restored afterwards, so the
-- action's code is captured on its own and never reaches the
-- surrounding code.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  saved <- swapCode mempty
  x <- m
  produced <- swapCode saved
  pure (x, produced)
  where
    -- Install new code in the state and return the code that was there.
    swapCode new = state $ \st -> (stateCode st, st {stateCode = new})

-- | Run a code generation action and emit the code it produced wrapped
-- in an 'Imp.Comment' with the given description.
comment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
comment desc m = emit . Imp.Comment desc =<< collect m

-- | Emit some generated imperative code by appending it to the code
-- emitted so far.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify (\st -> st {stateCode = stateCode st <> code})

-- | Record warnings, combining them in front of the already-accumulated
-- ones with '<>'.
warnings :: Warnings -> ImpM rep r op ()
warnings ws = modify (\st -> st {stateWarnings = ws <> stateWarnings st})

-- | Emit a warning about something the user should be aware of.  The
-- first location is the primary one; the list gives related locations.
warn :: (Located loc) => loc -> [loc] -> T.Text -> ImpM rep r op ()
warn loc locs problem =
  warnings . singleWarning' (srclocOf loc) (srclocOf <$> locs) $ pretty problem

-- | Emit a function in the generated code.  The function is prepended
-- to those emitted so far.  Existing functions with the same name are
-- not checked for; use 'hasFunction' for that.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun = modify addFun
  where
    addFun st =
      let Imp.Functions fs = stateFunctions st
       in st {stateFunctions = Imp.Functions ((fname, fun) : fs)}

-- | Check whether a function of the given name has already been
-- emitted.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname =
  gets $ \ImpState {stateFunctions = Imp.Functions fs} ->
    any ((fname ==) . fst) fs

-- | Build symbol table entries for every name bound by the given
-- statements.  Each entry records the expression that bound it.  If a
-- name is bound more than once, the first binding wins, because the
-- map monoid is a left-biased union.
constsVTable :: (Mem rep inner) => Stms rep -> VTable rep
constsVTable = foldMap entriesOfStm
  where
    entriesOfStm (Let pat _ e) =
      mconcat [entryOf e pe | pe <- patElems pat]
    entryOf e (PatElem name dec) =
      M.singleton name . memBoundToVarEntry (Just e) $ letDecMem dec

-- | Compile a whole program to imperative code.
--
-- Every function definition is compiled independently (and in parallel,
-- via @parMap rpar@).  Each one starts from a fresh state seeded with the
-- same name source and with the vtable of the program's constants.  The
-- per-function states are then merged, and the constants are compiled last,
-- with the names free in the compiled functions passed along as the names
-- that are used after the constant statements.
--
-- Returns the accumulated warnings and the resulting definitions.  The
-- threaded name source is the one of the final (constant-compiling) state.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog types consts funs) =
  modifyNameSource $ \src ->
    -- The first component of each result is always '()', so only the
    -- states are kept.
    let ss = map snd $ parMap rpar (compileFunDef' src) funs
        free_in_funs = freeIn $ foldMap stateFunctions ss
        ((), s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions
              types
              (stateConstants s' <> foldMap stateConstants ss)
              (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    -- Compile a single function definition from a fresh state whose
    -- vtable already knows about the program constants.
    compileFunDef' src fdef =
      runImpM
        (compileFunDef types fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the per-function states.  Functions are deduplicated by name
    -- through a 'M.Map' (for duplicate names the later entry wins, and the
    -- result is ordered by name), name sources are combined with their
    -- 'Monoid' instance, and warnings are concatenated.
    combineStates ss =
      let Imp.Functions funs' = foldMap stateFunctions ss
          src = foldMap stateNameSource ss
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                foldMap stateWarnings ss
            }

-- | Compile the program-level constant statements.
--
-- @used_consts@ is passed to 'compileStms' as the set of names used after
-- the statements, and is also reported to 'genConstants' together with the
-- (trivial) result.
compileConsts :: Names -> Stms rep -> ImpM rep r op ()
compileConsts used_consts stms = genConstants $ do
  compileStms used_consts stms $ pure ()
  pure (used_consts, ())

-- | Find the definition of the opaque type with the given name.
--
-- Calls 'error' if no opaque type of that name exists.
lookupOpaqueType :: Name -> OpaqueTypes -> OpaqueType
lookupOpaqueType v (OpaqueTypes types) =
  case lookup v types of
    Just t -> t
    Nothing -> error $ "Unknown opaque type: " ++ show v

-- | The signedness component of a 'ValueType'.
valueTypeSign :: ValueType -> Signedness
valueTypeSign (ValueType sign _ _) = sign

-- | The signedness of every value making up an entry point type, in order.
--
-- A transparent type contributes exactly one value.  An opaque type
-- contributes one per constituent value type, and opaque records are
-- flattened field by field (recursively).
entryPointSignedness :: OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness _ (TypeTransparent vt) = [valueTypeSign vt]
entryPointSignedness types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> map valueTypeSign vts
    OpaqueRecord fs -> foldMap (entryPointSignedness types . snd) fs

-- | How many value parameters are accepted by this entry point?  This
-- is used to determine which of the function parameters correspond to
-- the parameters of the original function (they must all come at the
-- end).
--
-- The count follows the same flattening as 'entryPointSignedness': one for
-- a transparent type, one per value type of an opaque type, and the sum
-- over the fields of an opaque record.
entryPointSize :: OpaqueTypes -> EntryPointType -> Int
entryPointSize _ (TypeTransparent _) = 1
entryPointSize types (TypeOpaque desc) =
  case lookupOpaqueType desc types of
    OpaqueType vts -> length vts
    OpaqueRecord fs -> sum $ map (entryPointSize types . snd) fs

-- | Classify a single function parameter by its memory information.
--
-- Primitive and memory-block parameters become Imp-level parameters
-- ('Left').  Array parameters become an 'ArrayDecl' ('Right') that records
-- their element type, memory block, shape and index function.
-- Accumulator parameters are rejected with 'error'.
compileInParam ::
  Mem rep inner =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    pure $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    pure $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    pure $ Right $ ArrayDecl name bt $ MemLoc mem (shapeDims shape) ixfun
  MemAcc {} ->
    error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array-typed function parameter: the parameter name, its element
-- type, and the memory location ('MemLoc') in which it is stored.
data ArrayDecl = ArrayDecl VName PrimType MemLoc

-- | Compile the parameters of a function.
--
-- Returns the Imp-level (scalar and memory) parameters, the declarations of
-- the array parameters, and, only when entry-point parameter information
-- is supplied, the external values describing the entry point's arguments.
--
-- The entry point's value parameters are taken to be the trailing
-- parameters; their number is the sum of 'entryPointSize' over the entry
-- parameters, and everything before them is context.
--
-- An argument whose description cannot be built is silently left out.
-- This happens for an array whose memory block is not itself a parameter,
-- or for a parameter that is neither an array nor a primitive.
compileInParams ::
  Mem rep inner =>
  OpaqueTypes ->
  [FParam rep] ->
  Maybe [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], Maybe [((Name, Uniqueness), Imp.ExternalValue)])
compileInParams types params eparams = do
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      -- Memory space of every memory-block parameter, keyed by name.
      summaries = M.fromList $ mapMaybe memSummary params

      findMemInfo :: VName -> Maybe Space
      findMemInfo mem = M.lookup mem summaries

      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLoc mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- Walk the entry parameters and the value parameters in lockstep.
      -- An opaque entry parameter consumes 'entryPointSize' function
      -- parameters, and a transparent one consumes exactly one.
      mkExts (EntryParam v u et@(TypeOpaque desc) : epts) fparams =
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (fparams', rest) = splitAt n fparams
         in ( (v, u),
              Imp.OpaqueValue
                desc
                (catMaybes $ zipWith mkValueDesc fparams' signs)
            )
              : mkExts epts rest
      mkExts (EntryParam v u (TypeTransparent (ValueType s _ _)) : epts) (fparam : fparams) =
        maybeToList (((v, u),) . Imp.TransparentValue <$> mkValueDesc fparam s)
          ++ mkExts epts fparams
      mkExts _ _ = []

      mkEntryExts eparams' =
        let num_val_params = sum $ map (entryPointSize types . entryParamType) eparams'
            (_ctx_params, val_params) = splitAt (length params - num_val_params) params
         in mkExts eparams' val_params

  pure (inparams, arrayds, mkEntryExts <$> eparams)
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

    memSummary param
      | MemMem space <- paramDec param = Just (paramName param, space)
      | otherwise = Nothing

-- | Create the output parameter and destination for one return type.
--
-- Primitive and memory-block returns get a freshly named output parameter,
-- which is also used as the destination.  Array returns get no parameter of
-- their own and an 'ArrayDestination' with no known location.  Accumulator
-- returns are rejected with 'error'.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam (MemPrim t) = do
  name <- newVName "prim_out"
  pure (Just $ Imp.ScalarParam name t, ScalarDestination name)
compileOutParam (MemMem space) = do
  name <- newVName "mem_out"
  pure (Just $ Imp.MemParam name space, MemoryDestination name)
compileOutParam MemArray {} =
  pure (Nothing, ArrayDestination Nothing)
compileOutParam MemAcc {} =
  error "Functions may not return accumulators."

-- | Describe the results of an entry point as external values.
--
-- @maybe_params@ must be parallel to @orig_rts@: element @i@ is the output
-- parameter (if any) created for return @i@.  The trailing returns, as many
-- as the sum of 'entryPointSize' over the entry results, are the values.
-- The returns before them are context, for example memory blocks and sizes
-- referenced through existential indices.
--
-- An array result names its memory block in one of two ways.  Via
-- 'ReturnsNewBlock' the block is another output, referenced by index.  Via
-- 'ReturnsInBlock' it is an existing block, and its space is looked up.
-- 'error' is called if an index does not denote an output parameter.
compileExternalValues ::
  Mem rep inner =>
  OpaqueTypes ->
  [RetType rep] ->
  [EntryResult] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [(Uniqueness, Imp.ExternalValue)]
compileExternalValues types orig_rts orig_epts maybe_params = do
  let num_val_rts = sum $ map (entryPointSize types . entryResultType) orig_epts
      (ctx_rts, val_rts) = splitAt (length orig_rts - num_val_rts) orig_rts

  let -- Name of the output parameter for return number @i@.
      nthOut i = case maybeNth i maybe_params of
        Just (Just p) -> Imp.paramName p
        Just Nothing -> error $ "Output " ++ show i ++ " not a param."
        Nothing -> error $ "Param " ++ show i ++ " does not exist."

      mkValueDesc _ signedness (MemArray t shape _ ret) = do
        (mem, space) <-
          case ret of
            ReturnsNewBlock space j _ixfun ->
              pure (nthOut j, space)
            ReturnsInBlock mem _ixfun -> do
              space <- entryMemSpace <$> lookupMemory mem
              pure (mem, space)
        pure $ Imp.ArrayValue mem space t signedness $ map dimSubExp $ shapeDims shape
        where
          -- Existential dimensions refer to other outputs by index.
          dimSubExp (Free v) = v
          dimSubExp (Ext i) = Var $ nthOut i
      mkValueDesc i signedness (MemPrim bt) =
        pure $ Imp.ScalarValue bt signedness $ nthOut i
      mkValueDesc _ _ MemAcc {} =
        error "mkValueDesc: unexpected MemAcc output."
      mkValueDesc _ _ MemMem {} =
        error "mkValueDesc: unexpected MemMem output."

      -- Walk the entry results and value returns in lockstep.  @i@ is the
      -- index of the current return within @orig_rts@.
      mkExts i (EntryResult u et@(TypeOpaque desc) : epts) rets = do
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (rets', rest) = splitAt n rets
        vds <- forM (zip3 [i ..] signs rets') $ \(j, s, rt) -> mkValueDesc j s rt
        ((u, Imp.OpaqueValue desc vds) :) <$> mkExts (i + n) epts rest
      mkExts i (EntryResult u (TypeTransparent (ValueType s _ _)) : epts) (ret : rets) = do
        vd <- mkValueDesc i s ret
        ((u, Imp.TransparentValue vd) :) <$> mkExts (i + 1) epts rets
      mkExts _ _ _ = pure []

  mkExts (length ctx_rts) orig_epts val_rts

-- | Compile the return types of a function.
--
-- Returns three things, in order.  The first is the external values of the
-- entry point's results, computed only when entry-result information is
-- supplied.  The second is the output parameters, but only for returns that
-- have one.  The third is the destination for every return, in order.
compileOutParams ::
  Mem rep inner =>
  OpaqueTypes ->
  [RetType rep] ->
  Maybe [EntryResult] ->
  ImpM rep r op (Maybe [(Uniqueness, Imp.ExternalValue)], [Imp.Param], [ValueDestination])
compileOutParams types orig_rts maybe_orig_epts = do
  (maybe_params, dests) <- mapAndUnzipM compileOutParam orig_rts
  evs <- forM maybe_orig_epts $ \orig_epts ->
    compileExternalValues types orig_rts orig_epts maybe_params
  pure (evs, catMaybes maybe_params, dests)

-- | Compile a function definition and emit it under its own name.
--
-- While compiling, 'envFunction' is set to the entry point name if there is
-- one, and otherwise to the function name.  The emitted function records
-- an 'Imp.EntryPoint' only when the definition carries entry-point
-- information.  Only then are both the argument and result descriptions
-- computed.
compileFunDef ::
  Mem rep inner =>
  OpaqueTypes ->
  FunDef rep ->
  ImpM rep r op ()
compileFunDef types (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = name_entry `mplus` Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    let entry' = Imp.EntryPoint <$> name_entry <*> results <*> args
    emitFunction fname $ Imp.Function entry' outparams inparams body'
  where
    (name_entry, params_entry, ret_entry) = case entry of
      Nothing -> (Nothing, Nothing, Nothing)
      Just (x, y, z) -> (Just x, Just y, Just z)

    compile = do
      (inparams, arrayds, args) <- compileInParams types params params_entry
      (results, outparams, dests) <- compileOutParams types rettype ret_entry
      addFParams params
      addArrays arrayds

      -- Compile the body, then copy each result into its destination.
      let Body _ stms ses = body
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $
          \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

      pure (outparams, inparams, results, args)

-- | Compile a body whose results are bound by the given pattern.
--
-- After the body's statements, each result is copied into the
-- corresponding destination derived from the pattern.
compileBody :: Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) = do
  dests <- destinationFromPat pat
  compileStms (freeIn ses) stms $
    forM_ (zip dests ses) $
      \(d, SubExpRes _ se) -> copyDWIMDest d [] se []

-- | Compile a body whose results are written to the given parameters.
--
-- After the body's statements, each result is copied into the variable
-- named by the corresponding parameter.
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms $
    forM_ (zip params ses) $
      \(param, SubExpRes _ se) -> copyDWIM (paramName param) [] se []

-- | Compile the body of a loop, writing its results back into the loop's
-- merge parameters.
--
-- Only primitive and memory-block merge parameters are assigned here.
-- Merge parameters of any other type, and memory-block parameters whose
-- result is not a variable, get no copy from this function.
compileLoopBody :: Typed dec => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- copies are all scalar or memory-block assignments ('Imp.SetScalar'
  -- and 'Imp.SetMem'), not array copies.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) stms $ do
    -- First pass: emit the copies into the temporaries, and return the
    -- (not yet run) actions that copy the temporaries into the parameters.
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, SubExpRes _ se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure $ pure ()
    -- Second pass: only now overwrite the merge parameters.
    sequence_ copy_to_merge_params

-- | Compile statements using the statement compiler configured in the
-- environment ('envStmsCompiler').
--
-- @alive_after_stms@ are the names that are still needed after the
-- statements.  @m@ is the action compiled after the statements; the
-- callers in this module use it to copy out the body results.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | The default statement compiler.
--
-- Statements are compiled in order, followed by the final action.  We
-- keep track of the memory blocks bound by the statements' patterns.
-- After compiling a statement, a block bound by an /earlier/ statement
-- gets an 'Imp.Free' right after it when two things hold: the
-- statement's code refers to the block, and nothing compiled afterwards
-- (later statements, the final action, or the names in the 'Names'
-- argument) does.  This is very conservative, but can cut down on
-- lifetimes in some cases.
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  void $ go mempty $ stmsToList all_stms
  where
    -- Compile the remaining statements given the memory blocks bound so
    -- far, and return the names referenced from this point onwards.
    go _ [] = do
      final_code <- collect m
      emit final_code
      pure $ freeIn final_code <> alive_after_stms
    go known_mems (Let pat aux e : rest) = do
      dVars (Just e) (patElems pat)

      stm_code <-
        localAttrs (stmAuxAttrs aux) $
          collect $
            compileExp pat e
      -- The rest must be compiled first so that we know what is still
      -- used after this statement; its code is emitted afterwards.
      (used_later, rest_code) <-
        collect' $ go (memsBoundBy pat <> known_mems) rest
      let diesHere name =
            name `notNameIn` used_later && name `nameIn` freeIn stm_code
          frees = S.filter (diesHere . fst) known_mems

      emit stm_code
      mapM_ (\(mem, space) -> emit $ Imp.Free mem space) frees
      emit rest_code

      pure $ freeIn stm_code <> used_later

    -- The memory blocks (with their spaces) bound by a pattern.
    memsBoundBy = S.fromList . mapMaybe memOf . patElems

    memOf pe
      | Mem space <- patElemType pe = Just (patElemName pe, space)
      | otherwise = Nothing

-- | Compile an expression whose result is bound to the given pattern,
-- using whichever expression compiler is installed in the environment
-- ('envExpCompiler').
compileExp :: Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
compileExp pat e =
  asks envExpCompiler >>= \exp_compiler -> exp_compiler pat e

-- | Generate an expression that is true if the subexpressions match
-- the case pattern.
--
-- The subexpressions and the pattern are paired positionally.  A
-- 'Nothing' in the pattern matches anything.  A @'Just' v@ requires the
-- corresponding subexpression, read at the type of @v@, to equal @v@.
-- Pairing is done with 'zipWith', so surplus elements of the longer
-- list are ignored.
caseMatch :: [SubExp] -> [Maybe PrimValue] -> Imp.TExp Bool
caseMatch ses vs = foldl (.&&.) true (zipWith cmp ses vs)
  where
    cmp se (Just v) = isBool $ toExp' (primValueType v) se ~==~ ValueExp v
    cmp _ Nothing = true

-- | The default expression compiler.  'Match', loops, 'WithAcc' and
-- function calls ('Apply') are translated directly here.  A 'BasicOp'
-- is handed to 'defCompileBasicOp', and an 'Op' goes to the operation
-- compiler installed in the environment ('envOpCompiler').
defCompileExp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()
defCompileExp pat (Match ses cases defbody _) =
  -- A chain of conditionals: the cases are tested in list order, and
  -- the default body runs when none of them match.
  foldr branch (compileBody pat defbody) cases
  where
    branch (Case vs body) = sIf (caseMatch ses vs) (compileBody pat body)
defCompileExp pat (Apply fname args _ _) = do
  targets <- destinationFromPat pat >>= funcallTargets
  -- Scalars are passed as expressions and memory blocks by name; any
  -- other argument is dropped from the call.
  call_args <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname call_args
  where
    compileArg (se, _) =
      subExpType se >>= \case
        Prim pt -> pure $ Just $ Imp.ExpArg $ toExp' pt se
        Mem {} | Var v <- se -> pure $ Just $ Imp.MemArg v
        _ -> pure Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop merge form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.

  -- Declare the merge parameters, and copy in the initial values of
  -- those whose type has rank zero.
  dFParams loop_params
  forM_ merge $ \(p, initial) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] initial []

  let loop_body = compileLoopBody loop_params body

  case form of
    ForLoop i _ bound loopvars -> do
      -- A scalar loop-variable parameter is set to element @i@ of its
      -- array.  Nothing is copied here for non-scalar ones.
      let bindLoopVar (p, arr) = case paramType p of
            Prim _ -> copyDWIM (paramName p) [] (Var arr) [DimFix $ Imp.le64 i]
            _ -> pure ()

      bound' <- toExp bound

      dLParams $ map fst loopvars
      sFor' i bound' $ do
        mapM_ bindLoopVar loopvars
        loop_body
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) loop_body

  -- Copy the final values of the merge parameters to the pattern.
  pat_dests <- destinationFromPat pat
  forM_ (zip pat_dests merge) $ \(dest, (p, _)) ->
    copyDWIMDest dest [] (Var $ paramName p) []
  where
    loop_params = map fst merge
defCompileExp pat (WithAcc inputs lam) = do
  dLParams lam_params
  -- Register each accumulator parameter with its arrays and operator.
  forM_ (zip inputs lam_params) $ \((_, arrs, acc_op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, acc_op) (stateAccs s)}
  compileStms mempty (bodyStms lam_body) $ do
    -- Skip one result per accumulator input.  The remaining results are
    -- copied into the trailing names of the pattern.
    let nonacc_res = drop (length inputs) (bodyResult lam_body)
        nonacc_pat_names = takeLast (length nonacc_res) (patNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, SubExpRes _ se) ->
      copyDWIM v [] se []
  where
    lam_params = lambdaParams lam
    lam_body = lambdaBody lam
defCompileExp pat (Op op) = do
  op_compiler <- asks envOpCompiler
  op_compiler pat op

-- | Emit a trace print of the scalar @se@ of type @t@.  The output is
-- the label @s@ and @": "@, then the value, then a newline.
tracePrim :: T.Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim s t se =
  emit $ Imp.TracePrint $ ErrorMsg msg_parts
  where
    msg_parts =
      [ ErrorString (s <> ": "),
        ErrorVal t (toExp' t se),
        ErrorString "\n"
      ]

-- | Emit a trace print of the array @se@, which has element type @t@
-- and the given shape.  The output is the label @s@ and @": "@, then
-- every element followed by a space, then a newline.  Elements appear
-- in the order visited by 'sLoopNest'.
traceArray :: T.Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray s t shape se = do
  tracePrint [ErrorString (s <> ": ")]
  sLoopNest shape $ \is -> do
    -- Each element is read into a fresh scalar before printing.
    elem_var <- dPrim "arr_elem" t
    copyDWIMFix (tvVar elem_var) [] se is
    tracePrint [ErrorVal t (untyped (tvExp elem_var)), " "]
  tracePrint ["\n"]
  where
    tracePrint = emit . Imp.TracePrint . ErrorMsg

defCompileBasicOp ::
  Mem rep inner =>
  Pat (LetDec rep) ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp :: forall rep (inner :: * -> *) r op.
Mem rep inner =>
Pat (LetDec rep) -> BasicOp -> ImpM rep r op ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (SubExp SubExp
se) =
  forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Opaque OpaqueOp
op SubExp
se) = do
  forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
  case OpaqueOp
op of
    OpaqueOp
OpaqueNil -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
    OpaqueTrace Text
s -> forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
comment (Text
"Trace: " forall a. Semigroup a => a -> a -> a
<> Text
s) forall a b. (a -> b) -> a -> b
$ do
      Type
se_t <- forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
      case Type
se_t of
        Prim PrimType
t -> forall rep r op. Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim Text
s PrimType
t SubExp
se
        Array PrimType
t Shape
shape NoUniqueness
_ -> forall rep r op.
Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray Text
s PrimType
t Shape
shape SubExp
se
        Type
_ ->
          forall loc rep r op.
Located loc =>
loc -> [loc] -> Text -> ImpM rep r op ()
warn [forall a. Monoid a => a
mempty :: SrcLoc] forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$
            Text
s forall a. Semigroup a => a -> a -> a
<> Text
": cannot trace value of this (core) type: " forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
se_t
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (UnOp UnOp
op SubExp
e) = do
  Exp
e' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
  Exp
e' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
  Exp
y' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
  forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
  Exp
y' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
  forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp Pat (LetDec rep)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
  Exp
e' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  ErrorMsg Exp
msg' <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp ErrorMsg SubExp
msg
  forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc

  Attrs
attrs <- forall rep r op. ImpM rep r op Attrs
askAttrs
  forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) forall a b. (a -> b) -> a -> b
$
    forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall loc rep r op.
Located loc =>
loc -> [loc] -> Text -> ImpM rep r op ()
warn (SrcLoc, [SrcLoc])
loc Text
"Safety check required at run-time."
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Index VName
src Slice SubExp
slice)
  | Just [SubExp]
idxs <- forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
      forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64) [SubExp]
idxs
defCompileBasicOp Pat (LetDec rep)
_ Index {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Update Safety
safety VName
_ Slice SubExp
slice SubExp
se) =
  case Safety
safety of
    Safety
Unsafe -> ImpM rep r op ()
write
    Safety
Safe -> forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds Slice (TExp Int64)
slice' [TExp Int64]
dims) ImpM rep r op ()
write
  where
    slice' :: Slice (TExp Int64)
slice' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 Slice SubExp
slice
    dims :: [TExp Int64]
dims = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims forall a b. (a -> b) -> a -> b
$ forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe
    write :: ImpM rep r op ()
write = forall rep r op.
VName -> Slice (TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) Slice (TExp Int64)
slice' SubExp
se
defCompileBasicOp Pat (LetDec rep)
_ FlatIndex {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (FlatUpdate VName
_ FlatSlice SubExp
slice VName
v) = do
  MemLoc
pe_loc <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe)
  MemLoc
v_loc <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
v
  forall rep r op. CopyCompiler rep r op
copy (forall shape u. TypeBase shape u -> PrimType
elemType (forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe)) (MemLoc -> FlatSlice (TExp Int64) -> MemLoc
flatSliceMemLoc MemLoc
pe_loc FlatSlice (TExp Int64)
slice') MemLoc
v_loc
  where
    slice' :: FlatSlice (TExp Int64)
slice' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 FlatSlice SubExp
slice
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Replicate Shape
shape SubExp
se)
  | Acc {} <- forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
  | Bool
otherwise =
      forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [TExp Int64]
is SubExp
se []
defCompileBasicOp Pat (LetDec rep)
_ Scratch {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Iota SubExp
n SubExp
e SubExp
s IntType
it) = do
  Exp
e' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  Exp
s' <- forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
s
  forall {k} (t :: k) rep r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"i" (SubExp -> TExp Int64
pe64 SubExp
n) forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
    let i' :: Exp
i' = forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
it forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
i
    TV Any
x <-
      forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"x" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$
        forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' forall a b. (a -> b) -> a -> b
$
          forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) Exp
i' Exp
s'
    forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [forall d. d -> DimIndex d
DimFix TExp Int64
i] (VName -> SubExp
Var (forall {k} (t :: k). TV t -> VName
tvVar TV Any
x)) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Copy VName
src) =
  forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Manifest [Int]
_ VName
src) =
  forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Concat Int
i (VName
x :| [VName]
ys) SubExp
_) = do
  TV Int64
offs_glb <- forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"tmp_offs" TExp Int64
0

  forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (VName
x forall a. a -> [a] -> [a]
: [VName]
ys) forall a b. (a -> b) -> a -> b
$ \VName
y -> do
    [SubExp]
y_dims <- forall u. TypeBase Shape u -> [SubExp]
arrayDims forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
y
    let rows :: TExp Int64
rows = case forall a. Int -> [a] -> [a]
drop Int
i [SubExp]
y_dims of
          [] -> forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"defCompileBasicOp Concat: empty array shape for " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString VName
y
          SubExp
r : [SubExp]
_ -> SubExp -> TExp Int64
pe64 SubExp
r
        skip_dims :: [SubExp]
skip_dims = forall a. Int -> [a] -> [a]
take Int
i [SubExp]
y_dims
        sliceAllDim :: d -> DimIndex d
sliceAllDim d
d = forall d. d -> d -> d -> DimIndex d
DimSlice d
0 d
d d
1
        skip_slices :: [DimIndex (TExp Int64)]
skip_slices = forall a b. (a -> b) -> [a] -> [b]
map (forall {d}. Num d => d -> DimIndex d
sliceAllDim forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64) [SubExp]
skip_dims
        destslice :: [DimIndex (TExp Int64)]
destslice = [DimIndex (TExp Int64)]
skip_slices forall a. [a] -> [a] -> [a]
++ [forall d. d -> d -> d -> DimIndex d
DimSlice (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offs_glb) TExp Int64
rows TExp Int64
1]
    forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [DimIndex (TExp Int64)]
destslice (VName -> SubExp
Var VName
y) []
    TV Int64
offs_glb forall {k} (t :: k) rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offs_glb forall a. Num a => a -> a -> a
+ TExp Int64
rows
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (ArrayLit [SubExp]
es Type
_)
  | Just vs :: [PrimValue]
vs@(PrimValue
v : [PrimValue]
_) <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe PrimValue
isLiteral [SubExp]
es = do
      MemLoc
dest_mem <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep r op. VName -> ImpM rep r op ArrayEntry
lookupArray (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe)
      let t :: PrimType
t = PrimValue -> PrimType
primValueType PrimValue
v
      VName
static_array <- forall rep r op. [Char] -> ImpM rep r op VName
newVNameForFun [Char]
"static_array"
      forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. VName -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
static_array PrimType
t forall a b. (a -> b) -> a -> b
$ [PrimValue] -> ArrayContents
Imp.ArrayValues [PrimValue]
vs
      let static_src :: MemLoc
static_src =
            VName -> [SubExp] -> IxFun (TExp Int64) -> MemLoc
MemLoc VName
static_array [IntType -> Integer -> SubExp
intConst IntType
Int64 forall a b. (a -> b) -> a -> b
$ forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es] forall a b. (a -> b) -> a -> b
$
              forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es]
      forall rep r op. VName -> VarEntry rep -> ImpM rep r op ()
addVar VName
static_array forall a b. (a -> b) -> a -> b
$ forall rep. Maybe (Exp rep) -> MemEntry -> VarEntry rep
MemVar forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
DefaultSpace
      forall rep r op. CopyCompiler rep r op
copy PrimType
t MemLoc
dest_mem MemLoc
static_src
  | Bool
otherwise =
      forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Integer
0 ..] [SubExp]
es) forall a b. (a -> b) -> a -> b
$ \(Integer
i, SubExp
e) ->
        forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall a. Num a => Integer -> a
fromInteger Integer
i] SubExp
e []
  where
    isLiteral :: SubExp -> Maybe PrimValue
isLiteral (Constant PrimValue
v) = forall a. a -> Maybe a
Just PrimValue
v
    isLiteral SubExp
_ = forall a. Maybe a
Nothing
defCompileBasicOp Pat (LetDec rep)
_ Rearrange {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Rotate [SubExp]
rs VName
arr) = do
  Shape
shape <- forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
arr
  forall rep r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
    [TExp Int64]
is' <- forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence forall a b. (a -> b) -> a -> b
$ forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 forall {rep} {r} {op}.
SubExp -> SubExp -> TExp Int64 -> ImpM rep r op (TExp Int64)
rotate (forall d. ShapeBase d -> [d]
shapeDims Shape
shape) [SubExp]
rs [TExp Int64]
is
    forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [TExp Int64]
is (VName -> SubExp
Var VName
arr) [TExp Int64]
is'
  where
    rotate :: SubExp -> SubExp -> TExp Int64 -> ImpM rep r op (TExp Int64)
rotate SubExp
d SubExp
r TExp Int64
i = forall {k} (t :: k) rep r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"rot_i" forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64 -> TExp Int64
rotateIndex (SubExp -> TExp Int64
pe64 SubExp
d) (SubExp -> TExp Int64
pe64 SubExp
r) TExp Int64
i
defCompileBasicOp Pat (LetDec rep)
_ Reshape {} =
  forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp Pat (LetDec rep)
_ (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs) = forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"UpdateAcc" forall a b. (a -> b) -> a -> b
$ do
  -- We are abusing the comment mechanism to wrap the operator in
  -- braces when we end up generating code.  This is necessary because
  -- we might otherwise end up declaring lambda parameters (if any)
  -- multiple times, as they are duplicated every time we do an
  -- UpdateAcc for the same accumulator.
  let is' :: [TExp Int64]
is' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (VName
_, Space
_, [VName]
arrs, [TExp Int64]
dims, Maybe (Lambda rep)
op) <- forall rep r op.
VName
-> [TExp Int64]
-> ImpM
     rep r op (VName, Space, [VName], [TExp Int64], Maybe (Lambda rep))
lookupAcc VName
acc [TExp Int64]
is'

  forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map forall d. d -> DimIndex d
DimFix [TExp Int64]
is')) [TExp Int64]
dims) forall a b. (a -> b) -> a -> b
$
    case Maybe (Lambda rep)
op of
      Maybe (Lambda rep)
Nothing ->
        -- Scatter-like.
        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64]
is' SubExp
v []
      Just Lambda rep
lam -> do
        -- Generalised reduction.
        forall rep (inner :: * -> *) r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
        let ([VName]
x_params, [VName]
y_params) =
              forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam

        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
x_params [VName]
arrs) forall a b. (a -> b) -> a -> b
$ \(VName
xp, VName
arr) ->
          forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
xp [] (VName -> SubExp
Var VName
arr) [TExp Int64]
is'

        forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
y_params [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ \(VName
yp, SubExp
v) ->
          forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM VName
yp [] SubExp
v []

        forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall rep. Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam) forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs (forall rep. Body rep -> Result
bodyResult (forall rep. Lambda rep -> Body rep
lambdaBody Lambda rep
lam))) forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExpRes Certs
_ SubExp
se) ->
            forall rep r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64]
is' SubExp
se []
defCompileBasicOp Pat (LetDec rep)
pat BasicOp
e =
  forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
    [Char]
"ImpGen.defCompileBasicOp: Invalid pattern\n  "
      forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString Pat (LetDec rep)
pat
      forall a. [a] -> [a] -> [a]
++ [Char]
"\nfor expression\n  "
      forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString BasicOp
e

-- | Register each array declaration in the symbol table as an array
-- variable (with no associated expression).  No code is emitted.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays = mapM_ addArray
  where
    addArray (ArrayDecl name bt location) =
      addVar name $
        ArrayVar
          Nothing
          ArrayEntry
            { entryArrayLoc = location,
              entryArrayElemType = bt
            }

-- | Like 'dFParams', but does not create new declarations: each
-- parameter is only registered in the symbol table, with uniqueness
-- information stripped from its memory annotation.
--
-- Note: a hack to be used only for functions.
addFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
addFParams = mapM_ addFParam
  where
    addFParam fparam =
      addVar (paramName fparam) $
        memBoundToVarEntry Nothing $
          noUniquenessReturns $
            paramDec fparam

-- | Another hack: register @i@ in the symbol table as a scalar of
-- integer type @it@, without emitting a declaration for it.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it = addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Declare the names bound by the given pattern elements, one element
-- at a time via 'dScope'.  The optional expression is stored in each
-- resulting symbol-table entry.
dVars ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  [PatElem (LetDec rep)] ->
  ImpM rep r op ()
dVars e = mapM_ dVar
  where
    dVar = dScope e . scopeOfPatElem

-- | Declare function parameters via 'dScope', with no associated
-- expression.
dFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Declare lambda parameters via 'dScope', with no associated
-- expression.
dLParams :: Mem rep inner => [LParam rep] -> ImpM rep r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Declare a fresh volatile scalar variable of the given type, whose
-- name is derived from the given string, and assign it the value of
-- the expression.  Returns the new variable.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  name' <- newVName name
  emit $ Imp.DeclareScalar name' Imp.Volatile t
  addVar name' $ ScalarVar Nothing $ ScalarEntry t
  name' <~~ untyped e
  pure $ TV name' t

-- | Emit a non-volatile scalar declaration for the given name and
-- register it in the symbol table.  The name is used as-is; no fresh
-- name is generated, and the variable is not initialised.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Declare a fresh non-volatile scalar variable (named after the given
-- string) of the given type, without initialising it.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM rep r op (TV t)
dPrim name t = do
  name' <- newVName name
  dPrim_ name' t
  pure $ TV name' t

-- | Declare the given name (used as-is) as a scalar whose type is that
-- of the expression, and assign the expression to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  dPrim_ name t
  TV name t <-- e
  where
    t = primExpType $ untyped e

-- | Declare a fresh scalar variable named after the given string, with
-- its type taken from the expression, and initialise it to that
-- expression.  Returns the new variable.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  name' <- dPrim name $ primExpType $ untyped e
  name' <-- e
  pure name'

-- | Like 'dPrimV', but return the new variable as a typed expression
-- rather than as a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Turn the memory annotation of a name into a symbol-table entry,
-- attaching the optional expression.  Array entries record their
-- location: the memory block, the shape dimensions, and the index
-- function.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry {entryScalarType = bt}
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemAcc acc ispace ts _) =
  AccVar e (acc, ispace, ts)
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem ixfun)) =
  let location = MemLoc mem (shapeDims shape) ixfun
   in ArrayVar
        e
        ArrayEntry
          { entryArrayLoc = location,
            entryArrayElemType = bt
          }

-- | The memory annotation of a name, whichever kind of binding
-- introduced it.  Uniqueness is discarded for function parameters, and
-- index names are integer scalars of their index type.
infoDec ::
  Mem rep inner =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec (LetName dec) = letDecMem dec
infoDec (FParamName dec) = noUniquenessReturns dec
infoDec (LParamName dec) = dec
infoDec (IndexName it) = MemPrim $ IntType it

-- | Declare a single name and add it to the symbol table.  Memory
-- blocks get an 'Imp.DeclareMem' in their space, and scalars get a
-- non-volatile 'Imp.DeclareScalar'.  Arrays and accumulators emit no
-- code and are only recorded in the symbol table.
dInfo ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      pure ()
    AccVar {} ->
      pure ()
  addVar name entry

-- | Declare every name in the scope with 'dInfo', in the order given
-- by 'M.toList', attaching the optional expression to each entry.
dScope ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

-- | Register an array variable stored in memory block @mem@ with the
-- given element type, shape, and index function.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op ()
dArray name pt shape mem ixfun =
  addVar name $ ArrayVar Nothing $ ArrayEntry location pt
  where
    location =
      MemLoc mem (shapeDims shape) ixfun

-- | Run an action with the environment's volatility ('envVolatility')
-- set to 'Imp.Volatile'.  NOTE(review): presumably this makes
-- declarations inside the action volatile; confirm against the readers
-- of 'envVolatility'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local $ \env -> env {envVolatility = Imp.Volatile}

-- | Collect the variable names designated by the given destinations,
-- presumably the output targets of a function call.  Scalar and memory
-- destinations contribute their name; array destinations contribute
-- nothing.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests = pure $ concatMap funcallTarget dests
  where
    -- Purely structural: no monadic effects are needed per destination.
    funcallTarget (ScalarDestination name) = [name]
    funcallTarget (ArrayDestination _) = []
    funcallTarget (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The type parameter @t@ is phantom: the constructor stores only the
-- variable name and its dynamic 'PrimType', and nothing checks that
-- the static and dynamic types agree (see 'mkTV').
data TV t = TV VName PrimType

-- | Create a typed variable from a name and a dynamic type.  Note
-- that there is no guarantee that the dynamic type corresponds to the
-- inferred static type, but the latter will at least have to be used
-- consistently.
mkTV :: VName -> PrimType -> TV t
mkTV = TV

-- | Convert a typed variable to a size (a 'SubExp' referring to the
-- variable).
tvSize :: TV t -> Imp.DimSize
tvSize = Var . tvVar

-- | Convert a typed variable to a similarly typed expression that reads
-- the variable at its dynamic 'PrimType'.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v t) = Imp.TPrimExp $ Imp.var v t

-- | Extract the underlying variable name from a typed variable.
tvVar :: TV t -> VName
tvVar (TV v _) = v

-- | Compile things to 'Imp.Exp'.  This module provides instances for
-- 'SubExp' and for 'PrimExp' 'VName'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance, so no monadic lookup
  -- is needed.
  toExp' :: PrimType -> a -> Imp.Exp

-- | Constants become value expressions directly.  For variables,
-- @toExp@ looks up the scalar type in the symbol table and fails if
-- the variable is not a scalar.  @toExp'@ trusts the supplied type.
instance ToExp SubExp where
  toExp (Constant v) =
    pure $ Imp.ValueExp v
  toExp (Var v) =
    lookupVar v >>= \case
      ScalarVar _ (ScalarEntry pt) ->
        pure $ Imp.var v pt
      _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ prettyString v

  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

-- | An 'Imp.Exp' is already compiled, so both methods return it
-- unchanged.
instance ToExp (PrimExp VName) where
  toExp = pure
  toExp' _ = id

-- | Add a name to the symbol table, replacing any existing entry for
-- the same name.  No code is emitted.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry =
  modify $ \s -> s {stateVTable = M.insert name entry $ stateVTable s}

-- | Run an action with the given memory space as the environment's
-- default space ('envDefaultSpace').
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local (\env -> env {envDefaultSpace = space})

-- | The function name recorded in the environment ('envFunction'), if
-- any.  'newVNameForFun' and 'nameForFun' use it as a name prefix.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = asks envFunction

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.  The
-- prefix is separated from the given base name by a dot.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  fname <- fmap nameToString <$> askFunction
  newVName $ maybe "" (++ ".") fname ++ s

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.  The
-- prefix is separated from the given base name by a dot.  Unlike
-- 'newVNameForFun', the result is not made fresh.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = do
  fname <- askFunction
  pure $ maybe "" (<> ".") fname <> nameFromString s

-- | The @r@ component of the environment ('envEnv').
askEnv :: ImpM rep r op r
askEnv = asks envEnv

-- | Run an action with the @r@ component of the environment ('envEnv')
-- modified by the given function.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local $ \env -> env {envEnv = f $ envEnv env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM rep r op Attrs
askAttrs = asks envAttrs

-- | Run an action with extra attributes added to those returned by
-- 'askAttrs'.  The new attributes are combined with the existing ones
-- using '<>', with the new ones on the left.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with the expression, statement, copy, operation, and
-- allocation compilers taken from the given 'Operations'.  Other
-- environment fields are left unchanged.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local $ \env ->
  env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envAllocCompilers = opsAllocCompilers ops
    }

-- | Get the current symbol table.
getVTable :: ImpM rep r op (VTable rep)
getVTable = gets stateVTable

-- | Replace the current symbol table wholesale.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify $ \s -> s {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  old_vtable <- getVTable
  putVTable $ f old_vtable
  a <- m
  -- Restore the saved table, discarding anything the action added.
  putVTable old_vtable
  pure a

-- | Look up a name in the symbol table, failing with an internal error
-- if it is unknown.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name = do
  res <- gets $ M.lookup name . stateVTable
  case res of
    Just entry -> pure entry
    Nothing -> error $ "Unknown variable: " ++ prettyString name

-- | Look up a name that must be bound to an array and return its
-- 'ArrayEntry'.  Calls 'error' if the name is bound to something other
-- than an array.  An unbound name fails inside 'lookupVar'.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ prettyString name

-- | Look up a name that must be bound to a memory block and return its
-- 'MemEntry'.  Calls 'error' if the name is bound to something other than
-- memory.  An unbound name fails inside 'lookupVar'.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    _ -> error $ "Unknown memory block: " ++ prettyString name

-- | The memory space of the memory block that backs the named array.
-- Resolution goes array -> its 'MemLoc' -> the memory block -> its space.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace =
  lookupArray
    >=> lookupMemory . memLocName . entryArrayLoc
    >=> pure . entryMemSpace

-- | Look up an accumulator variable.  The result tuple contains:
--
-- * the accumulator's name;
-- * the memory space of its first underlying array;
-- * all of its underlying arrays;
-- * the dimensions of its index space;
-- * its operator lambda, if it has one.
--
-- In the case of a histogram-like accumulator (one with an operator),
-- this also sets the index parameters.  The first @length is@ parameters
-- of the operator lambda are bound to the given indices with 'dPrimV_'
-- and are dropped from the returned lambda.
--
-- Calls 'error' in three cases: the name is not an accumulator, the
-- accumulator is not listed in the compiler state, or it has no arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ prettyString name
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ prettyString name
        Just (arrs@(arr : _), op_info) -> do
          space <- lookupArraySpace arr
          op' <- traverse bindIndexParams op_info
          pure (acc, space, arrs, map pe64 (shapeDims ispace), op')
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ prettyString name
  where
    -- Bind the leading index parameters of the operator to the given
    -- indices, and return the operator with those parameters removed.
    bindIndexParams (op, _) = do
      let (i_params, ps) = splitAt (length is) $ lambdaParams op
      zipWithM_ dPrimV_ (map paramName i_params) is
      pure op {lambdaParams = ps}

-- | Compute one 'ValueDestination' per pattern element.  The choice depends
-- on what the element's name is bound to in the symbol table:
--
-- * arrays and accumulators become @ArrayDestination Nothing@;
-- * memory blocks become 'MemoryDestination';
-- * scalars become 'ScalarDestination'.
destinationFromPat :: Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat = traverse inspect . patElems
  where
    inspect pe = do
      let name = patElemName pe
      lookupVar name >>= \case
        MemVar {} -> pure $ MemoryDestination name
        ScalarVar {} -> pure $ ScalarDestination name
        ArrayVar _ (ArrayEntry MemLoc {} _) -> pure $ ArrayDestination Nothing
        AccVar {} -> pure $ ArrayDestination Nothing

-- | Locate a single element of the named array.  Returns the backing memory
-- block, its space and the element offset.  The indices are presumably one
-- per dimension (TODO confirm against 'IxFun.index').  This is
-- 'fullyIndexArray'' applied to the array's 'MemLoc'.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  loc <- entryArrayLoc <$> lookupArray name
  fullyIndexArray' loc indices

-- | Like 'fullyIndexArray', but starting from a 'MemLoc' instead of an
-- array name.  The offset comes from applying the location's index
-- function to the indices, and it is counted in elements.  The space is
-- looked up from the memory block's 'MemEntry'.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ ixfun) indices = do
  MemEntry space <- lookupMemory mem
  pure (mem, space, elements (IxFun.index ixfun indices))

-- More complicated read/write operations that use index functions.

-- | Copy between two memory locations using the 'CopyCompiler' from the
-- environment.
--
-- Nothing is emitted when both sides are the same memory block and their
-- index functions are statically 'IxFun.equivalent'.  Otherwise the copy
-- is wrapped in a runtime check ('sUnless').  The check skips the copy
-- when all of the following hold:
--
-- * both sides are the same block;
-- * each side has exactly one LMAD;
-- * those LMADs are dynamically equal.
copy :: CopyCompiler rep r op
copy
  bt
  dst@(MemLoc dst_name _ dst_ixfn@(IxFun.IxFun dst_lmads@(dst_lmad :| _) _ _))
  src@(MemLoc src_name _ src_ixfn@(IxFun.IxFun src_lmads@(src_lmad :| _) _ _))
    | same_mem && dst_ixfn `IxFun.equivalent` src_ixfn =
        -- Statically known to be a no-op.
        pure ()
    | otherwise =
        sUnless
          ( fromBool (same_mem && length dst_lmads == 1 && length src_lmads == 1)
              .&&. IxFun.dynamicEqualsLMAD dst_lmad src_lmad
          )
          $ do
            copier <- asks envCopyCompiler
            copier bt dst src
    where
      same_mem = dst_name == src_name

-- | Is this copy really a mapping with transpose?
--
-- Succeeds when two conditions hold:
--
-- * one side's index function is a rearrangement (per
--   'IxFun.rearrangeWithOffset') whose permutation 'isMapTranspose'
--   accepts;
-- * the other side's index function is linear (per
--   'IxFun.linearWithOffset').
--
-- Both checks use the byte size of the primitive type as the element
-- size.  The result is @(dest_offset, src_offset, num_arrays, size_x,
-- size_y)@, computed from the rearranged side's shape:
--
-- * @num_arrays@ is the product of the leading @r1@ dimensions;
-- * the remaining dimensions are split at @r2@, and the products of the
--   two groups give @size_x@ and @size_y@;
-- * the two groups are swapped when the destination is the rearranged
--   side.
isMapTransposeCopy ::
  PrimType ->
  MemLoc ->
  MemLoc ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy bt (MemLoc _ _ destIxFun) (MemLoc _ _ srcIxFun)
  -- Destination is rearranged, source is linear.
  | Just (dest_offset, dest_perm_shape) <- IxFun.rearrangeWithOffset destIxFun elem_bytes,
    (perm, dest_shape) <- unzip dest_perm_shape,
    Just src_offset <- IxFun.linearWithOffset srcIxFun elem_bytes,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ transposeInfo dest_shape (\(a, b) -> (b, a)) r1 r2 dest_offset src_offset
  -- Destination is linear, source is rearranged.
  | Just dest_offset <- IxFun.linearWithOffset destIxFun elem_bytes,
    Just (src_offset, src_perm_shape) <- IxFun.rearrangeWithOffset srcIxFun elem_bytes,
    (perm, src_shape) <- unzip src_perm_shape,
    Just (r1, r2, _) <- isMapTranspose perm =
      Just $ transposeInfo src_shape id r1 r2 dest_offset src_offset
  | otherwise = Nothing
  where
    elem_bytes = primByteSize bt

    -- Split the shape into mapped / pre-transpose / post-transpose groups
    -- and package their products together with the two offsets.
    transposeInfo shape reorder r1 r2 dest_offset src_offset =
      let (mapped, unmapped) = splitAt r1 shape
          (pretrans, posttrans) = reorder $ splitAt r2 unmapped
       in ( dest_offset,
            src_offset,
            product mapped,
            product pretrans,
            product posttrans
          )

-- | Base name of the map-transpose function for an element type, i.e.
-- @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = concat ["map_transpose_", prettyString bt]

-- | Name of the built-in map-transpose function for the given element type,
-- namely @builtin#@ followed by 'mapTransposeName'.  The function's
-- definition, built by 'mapTransposeFunction', is emitted only if a
-- function with that name does not exist yet.
mapTransposeForType :: PrimType -> ImpM rep r op Name
mapTransposeForType bt = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (mapTransposeFunction fname bt)
  pure fname
  where
    fname = nameFromString ("builtin#" ++ mapTransposeName bt)

-- | Use 'sCopy' if possible, otherwise 'copyElementWise'.  The strategies
-- are tried in this order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, emit a call to the
--    built-in map-transpose function for the element type.  The function
--    comes from 'mapTransposeForType', and the offsets are passed as
--    'bytes'.
--
-- 2. If both index functions are linear with an offset, use 'sCopy',
--    unless either memory block lives in a 'ScalarSpace'.  In that case
--    use 'copyElementWise' instead.
--
-- 3. Otherwise fall back to 'copyElementWise'.
defaultCopy :: CopyCompiler rep r op
defaultCopy pt dest src
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <- isMapTransposeCopy pt dest src = do
      transpose_fun <- mapTransposeForType pt
      emit . Imp.Call [] transpose_fun $
        transposeArgs
          pt
          destmem
          (bytes destoffset)
          srcmem
          (bytes srcoffset)
          num_arrays
          size_x
          size_y
  | Just destoffset <- IxFun.linearWithOffset dest_ixfun pt_size,
    Just srcoffset <- IxFun.linearWithOffset src_ixfun pt_size = do
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      if any isScalarSpace [srcspace, destspace]
        then copyElementWise pt dest src
        else sCopy destmem destoffset destspace srcmem srcoffset srcspace num_elems pt
  | otherwise = copyElementWise pt dest src
  where
    pt_size = primByteSize pt
    -- Number of elements copied, taken from the source's shape.
    num_elems = Imp.elements . product . IxFun.shape $ memLocIxFun src

    MemLoc destmem _ dest_ixfun = dest
    MemLoc srcmem _ src_ixfun = src

    isScalarSpace = \case
      ScalarSpace {} -> True
      _ -> False

-- | Copy an array one element at a time.
--
-- One 'Imp.For' loop is emitted per dimension of the source index
-- function's shape, with the first dimension outermost.  The innermost
-- body does three things:
--
-- * declares a fresh scalar;
-- * reads one source element into it;
-- * writes it to the destination at the same indices.
--
-- The memory spaces come from the symbol table (via 'fullyIndexArray'').
-- The volatility comes from the environment.
copyElementWise :: CopyCompiler rep r op
copyElementWise bt dest src = do
  let bounds = IxFun.shape $ memLocIxFun src
  is <- replicateM (length bounds) (newVName "i")
  let ivars = map Imp.le64 is
  (destmem, destspace, destidx) <- fullyIndexArray' dest ivars
  (srcmem, srcspace, srcidx) <- fullyIndexArray' src ivars
  vol <- asks envVolatility
  tmp <- newVName "tmp"
  let body =
        mconcat
          [ Imp.DeclareScalar tmp vol bt,
            Imp.Read tmp srcmem srcidx bt srcspace vol,
            Imp.Write destmem destidx bt destspace vol $ Imp.var tmp bt
          ]
      -- Wrap the body in one loop per index variable.  Using 'foldr'
      -- makes the first index the outermost loop.
      loopOver (i, bound) inner = Imp.For i (untyped bound) inner
  emit $ foldr loopOver body (zip is bounds)

-- | Copy from here to there; both destination and source may be
-- indexed.
--
-- Fast path: both slices consist only of fixed indices, and there is one
-- index for every dimension of the respective 'MemLoc' shape.  In that
-- case a single element is read into a fresh scalar and written to the
-- destination.
--
-- Otherwise both slices are completed to full slices and the two
-- locations are sliced accordingly.  The result is handed to 'copy',
-- unless the sliced locations are equal, in which case no code is
-- produced.  Calls 'error' if the sliced ranks differ.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLoc _ destshape _)
  destslice
  srclocation@(MemLoc _ srcshape _)
  srcslice
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
        (targetmem, destspace, targetoffset) <-
          fullyIndexArray' destlocation destis
        (srcmem, srcspace, srcoffset) <-
          fullyIndexArray' srclocation srcis
        vol <- asks envVolatility
        collect $ do
          tmp <- tvVar <$> dPrim "tmp" bt
          emit $ Imp.Read tmp srcmem srcoffset bt srcspace vol
          emit $ Imp.Write targetmem targetoffset bt destspace vol $ Imp.var tmp bt
    | otherwise = do
        let destslice' = fullSliceNum (map pe64 destshape) destslice
            srcslice' = fullSliceNum (map pe64 srcshape) srcslice
            destrank = length $ sliceDims destslice'
            srcrank = length $ sliceDims srcslice'
            destlocation' = sliceMemLoc destlocation destslice'
            srclocation' = sliceMemLoc srclocation srcslice'
        if destrank /= srcrank
          then
            error $
              "copyArrayDWIM: cannot copy to "
                ++ prettyString (memLocName destlocation)
                ++ " from "
                ++ prettyString (memLocName srclocation)
                ++ " because ranks do not match ("
                ++ prettyString destrank
                ++ " vs "
                ++ prettyString srcrank
                ++ ")"
          else
            if destlocation' == srclocation'
              then pure mempty -- Copy would be no-op.
              else collect $ copy bt destlocation' srclocation'

-- Like 'copyDWIM', but the target is a 'ValueDestination' instead of
-- a variable name.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest :: forall rep r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIMDest ValueDestination
_ [DimIndex (TExp Int64)]
_ (Constant PrimValue
v) (DimIndex (TExp Int64)
_ : [DimIndex (TExp Int64)]
_) =
  forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
    [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be indexed."]
copyDWIMDest ValueDestination
pat [DimIndex (TExp Int64)]
dest_slice (Constant PrimValue
v) [] =
  case forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice of
    Maybe [TExp Int64]
Nothing ->
      forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"with slice destination."]
    Just [TExp Int64]
dest_is ->
      case ValueDestination
pat of
        ScalarDestination VName
name ->
          forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name forall a b. (a -> b) -> a -> b
$ forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        MemoryDestination {} ->
          forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", forall a. Pretty a => a -> [Char]
prettyString PrimValue
v, [Char]
"cannot be written to memory destination."]
        ArrayDestination (Just MemLoc
dest_loc) -> do
          (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <-
            forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          Volatility
vol <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall rep r op. Env rep r op -> Volatility
envVolatility
          forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol forall a b. (a -> b) -> a -> b
$ forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        ArrayDestination Maybe MemLoc
Nothing ->
          forall a. HasCallStack => [Char] -> a
error [Char]
"copyDWIMDest: ArrayDestination Nothing"
  where
    bt :: PrimType
bt = PrimValue -> PrimType
primValueType PrimValue
v
copyDWIMDest ValueDestination
dest [DimIndex (TExp Int64)]
dest_slice (Var VName
src) [DimIndex (TExp Int64)]
src_slice = do
  VarEntry rep
src_entry <- forall rep r op. VName -> ImpM rep r op (VarEntry rep)
lookupVar VName
src
  case (ValueDestination
dest, VarEntry rep
src_entry) of
    (MemoryDestination VName
mem, MemVar Maybe (Exp rep)
_ (MemEntry Space
space)) ->
      forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
mem VName
src Space
space
    (MemoryDestination {}, VarEntry rep
_) ->
      forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: cannot write", forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"to memory destination."]
    (ValueDestination
_, MemVar {}) ->
      forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: source", forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"is a memory block."]
    (ValueDestination
_, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
_))
      | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
src_slice ->
          forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed source", forall a. Pretty a => a -> [Char]
prettyString VName
src, [Char]
"with slice", forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice]
    (ScalarDestination VName
name, VarEntry rep
_)
      | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
dest_slice ->
          forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed target", forall a. Pretty a => a -> [Char]
prettyString VName
name, [Char]
"with slice", forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice]
    (ScalarDestination VName
name, ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
pt)) ->
      forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
src PrimType
pt
    (ScalarDestination VName
name, ArrayVar Maybe (Exp rep)
_ ArrayEntry
arr)
      | Just [TExp Int64]
src_is <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
src_slice,
        forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex (TExp Int64)]
src_slice forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr) -> do
          let bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arr
          (VName
mem, Space
space, Count Elements (TExp Int64)
i) <-
            forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
arr) [TExp Int64]
src_is
          Volatility
vol <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall rep r op. Env rep r op -> Volatility
envVolatility
          forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a.
VName
-> VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Code a
Imp.Read VName
name VName
mem Count Elements (TExp Int64)
i PrimType
bt Space
space Volatility
vol
      | Bool
otherwise ->
          forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: prim-typed target",
                forall a. Pretty a => a -> [Char]
prettyString VName
name,
                [Char]
"and array-typed source",
                forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"of shape",
                forall a. Pretty a => a -> [Char]
prettyString (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr),
                [Char]
"sliced with",
                forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
src_slice
              ]
    (ArrayDestination (Just MemLoc
dest_loc), ArrayVar Maybe (Exp rep)
_ ArrayEntry
src_arr) -> do
      let src_loc :: MemLoc
src_loc = ArrayEntry -> MemLoc
entryArrayLoc ArrayEntry
src_arr
          bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
src_arr
      forall op rep r. Code op -> ImpM rep r op ()
emit forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall rep r op.
PrimType
-> MemLoc
-> [DimIndex (TExp Int64)]
-> MemLoc
-> [DimIndex (TExp Int64)]
-> ImpM rep r op (Code op)
copyArrayDWIM PrimType
bt MemLoc
dest_loc [DimIndex (TExp Int64)]
dest_slice MemLoc
src_loc [DimIndex (TExp Int64)]
src_slice
    (ArrayDestination (Just MemLoc
dest_loc), ScalarVar Maybe (Exp rep)
_ (ScalarEntry PrimType
bt))
      | Just [TExp Int64]
dest_is <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice,
        forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dest_is forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length (MemLoc -> [SubExp]
memLocShape MemLoc
dest_loc) -> do
          (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <- forall rep r op.
MemLoc
-> [TExp Int64]
-> ImpM rep r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLoc
dest_loc [TExp Int64]
dest_is
          Volatility
vol <- forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks forall rep r op. Env rep r op -> Volatility
envVolatility
          forall op rep r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (VName -> PrimType -> Exp
Imp.var VName
src PrimType
bt)
      | Bool
otherwise ->
          forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords
              [ [Char]
"copyDWIMDest: array-typed target and prim-typed source",
                forall a. Pretty a => a -> [Char]
prettyString VName
src,
                [Char]
"with slice",
                forall a. Pretty a => a -> [Char]
prettyString [DimIndex (TExp Int64)]
dest_slice
              ]
    (ArrayDestination Maybe MemLoc
Nothing, VarEntry rep
_) ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; something else set some memory
      -- somewhere.
    (ValueDestination
_, AccVar {}) ->
      forall (f :: * -> *) a. Applicative f => a -> f a
pure () -- Nothing to do; accumulators are phantoms.

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they had better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination variable is looked up and classified by its entry
-- kind, and the actual work is delegated to 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- destinationFor <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    -- Map the destination's variable entry to the kind of write target.
    destinationFor = \case
      ScalarVar {} -> ScalarDestination dest
      ArrayVar _ (ArrayEntry loc _) -> ArrayDestination $ Just loc
      MemVar {} -> MemoryDestination dest
      -- Accumulators are phantoms, so the choice does not matter.
      AccVar {} -> ArrayDestination Nothing

-- | As 'copyDWIM', but every index on both sides is a plain
-- 'DimFix' point index rather than a general slice component.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixAll dest_is) src (fixAll src_is)
  where
    fixAll :: [Imp.TExp Int64] -> [DimIndex (Imp.TExp Int64)]
    fixAll = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@, writing the result to @pat@, which must contain a single
-- memory-typed element.
--
-- The allocation itself is delegated to 'sAlloc_'.  That function uses
-- the allocator registered for @space@ in the environment if there is
-- one, and otherwise emits a plain 'Imp.Allocate'.  Any pattern that
-- does not have exactly one element is an internal error.
compileAlloc ::
  Mem rep inner => Pat (LetDec rep) -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) size_se space =
  sAlloc_ (patElemName mem) (Imp.bytes $ pe64 size_se) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ prettyString pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression: the
-- byte size of the element type multiplied by the product of all
-- array dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes (elemBytes * elemCount)
  where
    elemBytes = primByteSize (elemType t)
    elemCount = product $ map pe64 $ arrayDims t

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- Each slice component is checked against the corresponding
-- dimension, and the per-dimension conditions are combined with
-- '.&&.'.  An empty slice (e.g. indexing a zero-dimensional array) is
-- trivially in bounds and yields 'true' instead of crashing.
--
-- NOTE(review): the 'DimSlice' check assumes a non-negative stride.
-- With a negative stride the start index is the largest index
-- touched, so this check would be too weak.  Confirm whether callers
-- can pass negative strides here.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  case zipWith condInBounds slice dims of
    -- 'foldl1' is partial; the empty conjunction is simply true.
    [] -> true
    conds -> foldl1 (.&&.) conds
  where
    condInBounds (DimFix i) d =
      0 .<=. i .&&. i .<. d
    condInBounds (DimSlice i n s) d =
      0 .<=. i .&&. i + (n - 1) * s .<. d

--- Building blocks for constructing code.

-- | @rotateIndex d r i@ shifts the index @i@ by @r@ positions and
-- wraps the result with @`mod` d@, i.e. it computes @(i + r) `mod` d@.
rotateIndex ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64
rotateIndex size amount idx =
  let shifted = idx + amount
   in shifted `mod` size

-- | Emit an 'Imp.For' loop over an already-named loop variable and an
-- untyped bound.  The loop variable is registered with the integer
-- type of the bound before the body is compiled.  A bound that is not
-- integer-typed is an internal error.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  addLoopVar i boundIntType
  loopBody <- collect body
  emit $ Imp.For i bound loopBody
  where
    boundIntType =
      case primExpType bound of
        IntType it -> it
        t ->
          error $
            "sFor': bound "
              ++ prettyString bound
              ++ " is of type "
              ++ prettyString t

-- | @sFor desc bound body@ emits an 'Imp.For' loop bounded by @bound@.
-- A fresh loop variable is created from the name hint @desc@.  @body@
-- receives that variable as a typed expression whose primitive type
-- matches the bound's type.
sFor ::
  String ->
  Imp.TExp t ->
  (Imp.TExp t -> ImpM rep r op ()) ->
  ImpM rep r op ()
sFor desc bound body = do
  iv <- newVName desc
  sFor' iv (untyped bound)
    . body
    . TPrimExp
    . Imp.var iv
    . primExpType
    $ untyped bound

-- | Emit an 'Imp.While' loop guarded by @cond@ whose body is the code
-- generated by the given action.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Emit the code generated by the given action wrapped in an
-- 'Imp.Comment' carrying the given text.
sComment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = collect code >>= emit . Imp.Comment s

-- | Emit a conditional.  Both branches are always compiled, the true
-- branch first.  If the condition is syntactically the constant
-- 'true' or 'false', only the corresponding branch's code is emitted
-- and no 'Imp.If' is generated.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  thenCode <- collect tbranch
  elseCode <- collect fbranch
  emit $ pick thenCode elseCode
  where
    -- Avoid generating a branch if the condition is known statically.
    pick thenCode elseCode
      | cond == true = thenCode
      | cond == false = elseCode
      | otherwise = Imp.If cond thenCode elseCode

-- | 'sIf' with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond thenBody = sIf cond thenBody $ pure ()

-- | 'sIf' with an empty true branch: the given code is emitted only
-- for when the condition is false.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond elseBody = sIf cond (pure ()) elseBody

-- | Emit a single backend-specific operation as an 'Imp.Op' statement.
sOp :: op -> ImpM rep r op ()
sOp o = emit $ Imp.Op o

-- | Declare a fresh memory block in the given space and add it to the
-- variable table.  No allocation is performed.  Returns the new name.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem desc space = do
  mem <- newVName desc
  emit $ Imp.DeclareMem mem space
  mem <$ addVar mem (MemVar Nothing (MemEntry space))

-- | Allocate the given number of bytes into an already-declared memory
-- block.  If an allocator is registered for the space in
-- 'envAllocCompilers', that allocator is used.  Otherwise a plain
-- 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ mem size space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Nothing -> emit $ Imp.Allocate mem size space
    Just alloc -> alloc mem size

-- | Declare a fresh memory block (see 'sDeclareMem') and allocate the
-- given number of bytes into it (see 'sAlloc_').  Returns the block's
-- name.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc desc size space = do
  mem <- sDeclareMem desc space
  mem <$ sAlloc_ mem size space

-- | Declare a fresh array variable with the given element type, shape,
-- backing memory block and index function.  Returns the array's name.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op VName
sArray desc pt shape mem ixfun = do
  arr <- newVName desc
  arr <$ dArray arr pt shape mem ixfun

-- | Declare an array in row-major order in the given memory block,
-- using an 'IxFun.iota' index function over the array's dimensions.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem desc pt shape mem = sArray desc pt shape mem rowMajor
  where
    rowMajor =
      IxFun.iota [isInt64 (primExpFromSubExp int64 d) | d <- shapeDims shape]

-- | Like 'sAllocArray', but the in-memory layout of the dimensions is
-- permuted by @perm@.  Memory for the whole array (see 'typeSize') is
-- allocated in @space@.  The array's index function is an iota over
-- the permuted dimensions, permuted back by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm desc pt shape space perm = do
  let dimsInMemOrder = rearrangeShape perm (shapeDims shape)
      arrBytes = typeSize (Array pt shape NoUniqueness)
  mem <- sAlloc (desc ++ "_mem") arrBytes space
  let layout =
        IxFun.permute
          (IxFun.iota [isInt64 (primExpFromSubExp int64 d) | d <- dimsInMemOrder])
          (rearrangeInverse perm)
  sArray desc pt shape mem layout

-- | Allocate a fresh array with a linear/iota index function, i.e.
-- 'sAllocArrayPerm' with the identity permutation.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space $ take (shapeRank shape) [0 ..]

-- | Declare a statically initialised one-dimensional array holding the
-- given contents.  It uses a linear/iota index function.  The backing
-- memory is declared with 'Imp.DeclareArray' and recorded in
-- 'DefaultSpace'.  Returns the array's name.
sStaticArray :: String -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem pt vs
  addVar mem $ MemVar Nothing $ MemEntry DefaultSpace
  sArray name pt (Shape [intConst Int64 $ toInteger numElems]) mem $
    IxFun.iota [fromIntegral numElems]
  where
    numElems :: Int
    numElems = case vs of
      Imp.ArrayValues vals -> length vals
      Imp.ArrayZeros n -> fromIntegral n

-- | Write a scalar value to the element of @arr@ at the given full
-- index.  The write uses the value's own primitive type and the
-- current 'envVolatility'.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr idxs val = do
  (mem, space, elemOffset) <- fullyIndexArray arr idxs
  volatility <- asks envVolatility
  let valType = primExpType val
  emit $ Imp.Write mem elemOffset valType space volatility val

-- | Copy @v@ into the given slice of @arr@ via 'copyDWIM'.  The source
-- itself is not sliced.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr (Slice dimIdxs) v = copyDWIM arr dimIdxs v []

-- | Create a sequential 'Imp.For' loop nest covering a space of the
-- given shape, with one loop per dimension; the first dimension is the
-- outermost loop.  The function is called with the indexes for a given
-- iteration, listed outermost first.
sLoopSpace ::
  [Imp.TExp t] ->
  ([Imp.TExp t] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopSpace dims body = go dims []
  where
    -- 'outer' holds the indices of the enclosing loops, innermost first.
    go [] outer = body $ reverse outer
    go (d : ds) outer = sFor "nest_i" d $ \i -> go ds (i : outer)

-- | 'sLoopSpace' over the dimensions of a 'Shape', with each dimension
-- converted to an 'Int64' expression.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest shape body = sLoopSpace [pe64 d | d <- shapeDims shape] body

-- | Emit an 'Imp.Copy' of @num_elems@ elements of type @pt@ between two
-- memory blocks.
--
-- The offsets are wrapped with 'bytes' without any scaling.  The
-- element count is converted to bytes with 'withElemType'.  When the
-- source and destination blocks are the same, the copy is guarded so
-- that it is skipped if the two offsets are equal.
sCopy ::
  VName ->
  Imp.TExp Int64 ->
  Space ->
  VName ->
  Imp.TExp Int64 ->
  Space ->
  Count Elements (Imp.TExp Int64) ->
  PrimType ->
  ImpM rep r op ()
sCopy dst_mem dst_off dst_space src_mem src_off src_space num_elems pt
  | dst_mem == src_mem = sUnless (dst_off .==. src_off) copyIt
  | otherwise = copyIt
  where
    copyIt =
      emit
        . Imp.Copy pt dst_mem (bytes dst_off) dst_space src_mem (bytes src_off) src_space
        $ num_elems `withElemType` pt

-- | Untyped assignment: emit an 'Imp.SetScalar' of the expression to
-- the named variable.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) name val = emit $ Imp.SetScalar name val

infixl 3 <~~

-- | Typed assignment: emit an 'Imp.SetScalar' of the untyped form of
-- the expression to the variable wrapped by the 'TV'.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
(<--) (TV name _) val = emit . Imp.SetScalar name $ untyped val

infixl 3 <--

-- | Construct an ad-hoc function that does not correspond to any of
-- the IR functions in the input program, and register it with
-- 'emitFunction'.
--
-- The body action runs with 'envFunction' set to the new function's
-- name.  Before it runs, the output and input parameters (in that
-- order) are added to the variable table; 'Imp.MemParam' and
-- 'Imp.ScalarParam' are the parameter forms handled.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect $ mapM_ declareParam (outputs ++ inputs) >> m
    emitFunction fname $ Imp.Function Nothing outputs inputs body
  where
    declareParam = \case
      Imp.MemParam name space ->
        addVar name $ MemVar Nothing $ MemEntry space
      Imp.ScalarParam name pt ->
        addVar name $ ScalarVar Nothing $ ScalarEntry pt

-- Fish out those top-level declarations in the constant
-- initialisation code that are free in the functions.
--
-- Every declaration whose name is in @used@ is turned into a parameter:
-- memory and scalar declarations are removed from the returned code,
-- while array declarations are kept in the code and exposed as
-- 'DefaultSpace' memory parameters.  Everything else passes through
-- unchanged.
constParams :: Names -> Imp.Code a -> (DL.DList Imp.Param, Imp.Code a)
constParams used = go
  where
    go (lhs Imp.:>>: rhs) = go lhs <> go rhs
    go (Imp.DeclareMem name space)
      | name `nameIn` used =
          (DL.singleton (Imp.MemParam name space), mempty)
    go (Imp.DeclareScalar name _ t)
      | name `nameIn` used =
          (DL.singleton (Imp.ScalarParam name t), mempty)
    go code@(Imp.DeclareArray name _ _)
      | name `nameIn` used =
          (DL.singleton (Imp.MemParam name DefaultSpace), code)
    go code = (mempty, code)

-- | Generate constants that get put outside of all functions.  Will
-- be executed at program startup.  The action must return, alongside
-- its result, the names that should be made available to functions.
--
-- This one has real sharp edges: do not use it inside 'subImpM', and
-- do not use any variable from the context.
genConstants :: ImpM rep r op (Names, a) -> ImpM rep r op a
genConstants m = do
  ((avail, a), code) <- collect' m
  -- Declarations of the names in 'avail' become parameters of the
  -- constants; the remaining code is kept as their initialisation.
  let consts = uncurry Imp.Constants $ first DL.toList $ constParams avail code
  modify $ \s -> s {stateConstants = stateConstants s <> consts}
  pure a

-- | Given the dimensions of an array, compute for each dimension the
-- product of all dimensions after it (so the last entry is @1@).  Each
-- computed product, including the discarded product of all dimensions,
-- is bound with 'dPrimVE' under the name "slice".
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices dims = do
  (_, strides) <- go dims
  -- The head of 'strides' is the product of every dimension, which is
  -- not a stride of any dimension.
  pure $ drop 1 strides
  where
    -- Returns the product of the given dimensions, together with the
    -- running products from the right (ending in 1).
    go [] = pure (1, [1])
    go (d : ds) = do
      (inner, acc) <- go ds
      p <- dPrimVE "slice" $ d * inner
      pure (p, p : acc)

-- | @dIndexSpace vs_ds j@ splits the flat index @j@ into one index per
-- dimension of an array whose dimensions are the second components of
-- @vs_ds@, in row-major order (the last dimension varies fastest).
-- The index into each dimension is bound to the 'VName' paired with
-- it.  Strides and intermediate remainders are bound to fresh
-- variables (via 'dPrimVE') along the way; nothing is returned.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  -- One stride per dimension: the product of all later dimensions.
  strides <- dSlices (map snd vs_ds)
  loop (zip (map fst vs_ds) strides) j
  where
    -- Bind the outermost remaining index, then continue with what is
    -- left of the flat index.
    loop ((v, stride) : rest) i = do
      dPrimV_ v (i `quot` stride)
      i' <- dPrimVE "remnant" $ i - Imp.le64 v * stride
      loop rest i'
    loop [] _ = pure ()

-- | Like 'dIndexSpace', but invent some new names for the indexes,
-- based on the given template, and return those indexes as
-- expressions (one per dimension, in the same order).
dIndexSpace' ::
  String ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  ImpM rep r op [Imp.TExp Int64]
dIndexSpace' desc dims flat = do
  names <- traverse (const (newVName desc)) dims
  map Imp.le64 names <$ dIndexSpace (zip names dims) flat