{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLocation (..),
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    inBounds,
    isMapTransposeCopy,

    -- * Constructing code.
    dLParams,
    dFParams,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    (<--),
    (<~~),
    function,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.String
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings
import Prelude hiding (quot)

-- | How to compile an t'Op', given the 'Pattern' that binds its
-- results.
type OpCompiler rep r op =
  Pattern rep ->
  Op rep ->
  ImpM rep r op ()

-- | How to compile some 'Stms'.  Besides the statements themselves, a
-- statement compiler receives a set of 'Names' and an 'ImpM' action.
-- NOTE(review): presumably the names are those needed after the
-- statements and the action is compiled in their scope; confirm
-- against 'defCompileStms'.
type StmsCompiler rep r op =
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()

-- | How to compile an 'Exp', given the 'Pattern' that binds its
-- results.
type ExpCompiler rep r op =
  Pattern rep ->
  Exp rep ->
  ImpM rep r op ()

-- | How to compile a copy of elements of the given 'PrimType' between
-- two memory locations, each restricted by a slice.
-- NOTE(review): presumably the first location/slice pair is the
-- destination and the second the source; confirm against 'defaultCopy'.
type CopyCompiler rep r op =
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  ImpM rep r op ()

-- | An alternate way of compiling an allocation.  It receives a name
-- (presumably of the memory block being allocated) and the allocation
-- size in bytes.
type AllocCompiler rep r op =
  VName ->
  Count Bytes (Imp.TExp Int64) ->
  ImpM rep r op ()

-- | The set of pluggable compilers that determine how 'ImpM' compiles
-- expressions, operations, statements, copies and allocations.
data Operations rep r op = Operations
  { -- | Compiler for expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | Compiler for t'Op' constructs.
    opsOpCompiler :: OpCompiler rep r op,
    -- | Compiler for sequences of statements.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | Compiler for copies between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Alternate allocation compilers, keyed by memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }

-- | An operations set that uses the given t'Op' compiler together with
-- the default compilers for everything else.  The defaults are
-- 'defCompileExp' for expressions, 'defCompileStms' for statements and
-- 'defaultCopy' for copies.  No alternate allocation compilers are
-- installed.
defaultOperations ::
  (Mem rep, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLocation = MemLocation
  { -- | The memory block holding the array.
    memLocationName :: VName,
    -- | The shape of the array.
    memLocationShape :: [Imp.DimSize],
    -- | The index function of the array within the memory block.
    memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | Symbol-table information for an array: where it is stored and the
-- type of its elements.
data ArrayEntry = ArrayEntry
  { -- | Where the array is stored.
    entryArrayLocation :: MemLocation,
    -- | The element type of the array.
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array, as recorded in its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocationShape (entryArrayLocation entry)

-- | Symbol-table information for a memory block.
newtype MemEntry = MemEntry
  { -- | The space in which the memory block lives.
    entryMemSpace :: Imp.Space
  }
  deriving (Show)

-- | Symbol-table information for a scalar variable.
newtype ScalarEntry = ScalarEntry
  { -- | The primitive type of the scalar.
    entryScalarType :: PrimType
  }
  deriving (Show)

-- | The symbol-table entry for a variable.  Each kind of entry carries
-- an optional 'Exp'.
-- NOTE(review): presumably this is the expression that bound the
-- variable, when known; confirm at the binding sites.
data VarEntry rep
  = -- | An array variable.
    ArrayVar (Maybe (Exp rep)) ArrayEntry
  | -- | A scalar variable.
    ScalarVar (Maybe (Exp rep)) ScalarEntry
  | -- | A memory block.
    MemVar (Maybe (Exp rep)) MemEntry
  | -- | An accumulator; the triple holds the name, shape and element
    -- types that make up its 'Acc' type.
    AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)

-- | When compiling an expression, this is a description of where the
-- result should end up.
data Destination = Destination
  { -- | A reference to the construct that gave rise to this destination
    -- (for patterns, the tag of the first name in the pattern).  This
    -- can be used to make the generated code easier to relate to the
    -- original code.
    destinationTag :: Maybe Int,
    -- | Where each individual result value should end up.
    valueDestinations :: [ValueDestination]
  }
  deriving (Show)

-- | Where a single value of a 'Destination' should end up.
data ValueDestination
  = -- | A scalar, identified by its variable name.
    ScalarDestination VName
  | -- | A memory block, identified by its name.
    MemoryDestination VName
  | -- | The 'MemLocation' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLocation)
  deriving (Show)

-- | The read-only environment of 'ImpM'.
data Env rep r op = Env
  { -- | The compiler used for expressions.
    envExpCompiler :: ExpCompiler rep r op,
    -- | The compiler used for sequences of statements.
    envStmsCompiler :: StmsCompiler rep r op,
    -- | The compiler used for t'Op' constructs.
    envOpCompiler :: OpCompiler rep r op,
    -- | The compiler used for copies.
    envCopyCompiler :: CopyCompiler rep r op,
    -- | Alternate allocation compilers, keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | The volatility currently in effect.
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | A fresh environment built from a user environment, an operation set
-- and a default memory space.  The environment starts out
-- 'Imp.Nonvolatile', outside any function, and with no active
-- attributes.
--
-- NOTE(review): unlike 'subImpM', this leaves 'envAllocCompilers'
-- empty instead of installing @opsAllocCompilers ops@; confirm that
-- this is intentional.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envEnv = r,
      envDefaultSpace = ds,
      envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envAllocCompilers = mempty,
      envVolatility = Imp.Nonvolatile,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation, mapping each variable
-- name to its entry.
type VTable rep =
  M.Map VName (VarEntry rep)

-- | The mutable state threaded through 'ImpM'.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | The functions generated so far.
    stateFunctions :: Imp.Functions op,
    -- | The code emitted so far; 'emit' appends to it.
    stateCode :: Imp.Code op,
    -- | The warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator is not in this mapping, the
    -- accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | The source of fresh names.
    stateNameSource :: VNameSource
  }

-- | An initial state with no symbols, functions, code, warnings or
-- accumulators, which draws fresh names from the given source.
newState :: VNameSource -> ImpState rep r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }

-- | The code generation monad: a reader over 'Env' stacked on top of a
-- state over 'ImpState'.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )

-- | Fresh names are read from, and written back to, the name source
-- stored in 'ImpState'.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\st -> st {stateNameSource = src})

-- This cannot be a KernelsMem scope, because the index functions would
-- have the wrong leaves (VName instead of Imp.Exp).  Instead, every
-- symbol-table entry is exposed as a 'LetName' of the corresponding
-- SOACS type.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets (M.map (LetName . typeOfEntry) . stateVTable)
    where
      typeOfEntry = \case
        ArrayVar _ entry ->
          Array
            (entryArrayElemType entry)
            (Shape (entryArrayShape entry))
            NoUniqueness
        ScalarVar _ entry ->
          Prim (entryScalarType entry)
        MemVar _ entry ->
          Mem (entryMemSpace entry)
        AccVar _ (acc, ispace, ts) ->
          Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' action from the given initial state.  The action runs
-- in a fresh environment built by 'newEnv' from the user environment,
-- operation set and default space.  Returns the action's result
-- together with the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM m) r ops space =
  runState $ runReaderT m $ newEnv r ops space

-- | Like 'subImpM', but discards the action's result and returns only
-- the code it generated.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ r ops m = fmap snd $ subImpM r ops m

-- | Run an action with a different user environment and operation set
-- (and thus possibly a different t'Op' type).  Returns the action's
-- result and the code it generated.
--
-- The action starts with the current symbol table, accumulator table
-- and name source, and with no code, functions or warnings.  Afterwards
-- only its name source and warnings are propagated back.  Any functions
-- it generated, and any changes it made to the symbol or accumulator
-- tables, are discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  parentEnv <- ask
  parentState <- get

  let childEnv =
        parentEnv
          { envEnv = r,
            envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envAllocCompilers = opsAllocCompilers ops
          }
      childState =
        ImpState
          { stateVTable = stateVTable parentState,
            stateAccs = stateAccs parentState,
            stateNameSource = stateNameSource parentState,
            stateFunctions = mempty,
            stateCode = mempty,
            stateWarnings = mempty
          }
      (result, finalState) = runState (runReaderT m childEnv) childState

  -- Keep name generation monotone and retain the child's warnings.
  putNameSource $ stateNameSource finalState
  warnings $ stateWarnings finalState
  return (result, stateCode finalState)

-- | Execute a code generation action and return the code it emitted.
-- That code is not added to the code emitted by the surrounding action.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect = fmap (\(_, code) -> code) . collect'

-- | Like 'collect', but also returns the result of the action.  The
-- code emitted by the action is captured and returned, not appended.
-- The previously emitted code is restored afterwards.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' m = do
  -- Stash the code emitted so far and start from an empty buffer.
  saved <- state $ \st -> (stateCode st, st {stateCode = mempty})
  result <- m
  -- Take whatever the action emitted and put the stashed code back.
  produced <- state $ \st -> (stateCode st, st {stateCode = saved})
  return (result, produced)

-- | Execute a code generation action, wrapping the generated code
-- within an 'Imp.Comment' with the given description.
comment :: String -> ImpM rep r op () -> ImpM rep r op ()
comment desc m = collect m >>= \code -> emit (Imp.Comment desc code)

-- | Emit some generated imperative code, appending it after the code
-- accumulated so far in the current state.
emit :: Imp.Code op -> ImpM rep r op ()
emit code = modify appendCode
  where
    appendCode s = s {stateCode = stateCode s <> code}

-- | Record a batch of warnings in the current state.  The new warnings
-- are combined (with '<>') in front of those already recorded.
warnings :: Warnings -> ImpM rep r op ()
warnings ws = modify addWarnings
  where
    addWarnings s = s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.  The
-- warning is attached to the source location of @loc@, with the
-- locations of @locs@ passed along as additional locations.
warn :: Located loc => loc -> [loc] -> String -> ImpM rep r op ()
warn loc locs problem =
  warnings . singleWarning' mainLoc otherLocs $ fromString problem
  where
    mainLoc = srclocOf loc
    otherLocs = map srclocOf locs

-- | Emit a function in the generated code.  The function is prepended
-- to the list of functions recorded in the current state.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun = modify $ \s ->
  let Imp.Functions existing = stateFunctions s
   in s {stateFunctions = Imp.Functions ((fname, fun) : existing)}

-- | Check if a function of a given name has already been recorded in
-- the current state (e.g. by 'emitFunction').
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = gets $ \s ->
  case stateFunctions s of
    Imp.Functions fs -> any ((== fname) . fst) fs

-- | Build a symbol table containing every variable bound by the given
-- (top-level constant) statements.  Each pattern element is entered
-- together with the expression that binds it.  When a name occurs more
-- than once, the first binding wins (left-biased map union).
constsVTable :: Mem rep => Stms rep -> VTable rep
constsVTable stms = foldMap entriesOf stms
  where
    entriesOf (Let pat _ e) =
      mconcat
        [ M.singleton name (memBoundToVarEntry (Just e) dec)
          | PatElem name dec <- patternElements pat
        ]

-- | Compile a whole program to imperative code.
--
-- Every function is compiled independently (sparked in parallel via
-- 'parMap'). Each compilation starts from a fresh state whose symbol
-- table already contains the top-level constants. The constants are
-- then compiled starting from a state that merges the per-function
-- results. Only those constant declarations whose names are free in the
-- compiled functions are lifted out as 'Imp.Constants' parameters.
compileProg ::
  (Mem rep, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) = modifyNameSource compileAll
  where
    compileAll src =
      let fun_states = map snd $ parMap rpar (compileFunDef' src) funs
          used_by_funs = freeIn $ mconcat $ map stateFunctions fun_states
          (consts', final_state) =
            runImpM (compileConsts used_by_funs consts) r ops space $
              combineStates fun_states
          defs = Imp.Definitions consts' $ stateFunctions final_state
       in ((stateWarnings final_state, defs), stateNameSource final_state)

    -- Compile a single function from the given name source, with the
    -- top-level constants already present in the symbol table.
    compileFunDef' src fdef =
      runImpM (compileFunDef fdef) r ops space $
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the states produced by the individual function
    -- compilations. Functions are deduplicated by name through
    -- 'M.fromList'. Name sources and warnings are combined with their
    -- monoids.
    combineStates states =
      let Imp.Functions all_funs = mconcat $ map stateFunctions states
          merged_src = mconcat $ map stateNameSource states
       in (newState merged_src)
            { stateFunctions = Imp.Functions $ M.toList $ M.fromList all_funs,
              stateWarnings = mconcat $ map stateWarnings states
            }

-- | Compile the top-level constant statements.
--
-- Scalar and memory declarations in the resulting code whose names are
-- in @used_consts@ become 'Imp.Constants' parameters. They are removed
-- from the initialisation code. Everything else stays in the
-- initialisation code.
compileConsts :: Names -> Stms rep -> ImpM rep r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  let (params, init_code) = extract code
  pure $ Imp.Constants (DL.toList params) init_code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions.  Only
    -- sequences ('Imp.:>>:') are descended into.
    extract (x Imp.:>>: y) = extract x <> extract y
    extract (Imp.DeclareMem name space)
      | name `nameIn` used_consts =
        (DL.singleton $ Imp.MemParam name space, mempty)
    extract (Imp.DeclareScalar name _ t)
      | name `nameIn` used_consts =
        (DL.singleton $ Imp.ScalarParam name t, mempty)
    extract other = (mempty, other)

-- | Compile a single function parameter.
--
-- Scalar and memory parameters become an 'Imp.Param' ('Left'). Array
-- parameters are described by an 'ArrayDecl' giving their element type
-- and memory location ('Right'). Accumulator parameters are rejected
-- with an error.
compileInParam ::
  Mem rep =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramDec fparam of
    MemPrim bt -> pure . Left $ Imp.ScalarParam name bt
    MemMem space -> pure . Left $ Imp.MemParam name space
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      pure . Right . ArrayDecl name bt $
        MemLocation mem (shapeDims shape) (fmap (fmap Imp.ScalarVar) ixfun)
    MemAcc {} ->
      error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array-typed function parameter, as produced by 'compileInParam'.
data ArrayDecl
  = ArrayDecl
      VName -- ^ Name of the array parameter.
      PrimType -- ^ Element type of the array.
      MemLocation -- ^ Memory block, shape, and index function of the array.

-- | Compile the parameters of a function.
--
-- Returns three things: the imperative parameters, the array
-- declarations, and the external (entry point) description of the value
-- parameters. The trailing parameters are the value parameters; their
-- count is given by summing 'entryPointSize' over the entry point types.
-- All preceding parameters are context parameters. Only the value
-- parameters get an external description.
compileInParams ::
  Mem rep =>
  [FParam rep] ->
  [EntryPointType] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  -- Context and value parameters together are exactly @params@, so
  -- compiling all of them covers both groups in their original order.
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let num_ctx = length params - sum (map entryPointSize orig_epts)
      val_params = drop num_ctx params

      findArray x = find (\(ArrayDecl y _ _) -> x == y) arrayds

      -- Memory space of every memory parameter, keyed by its name.
      summaries =
        M.fromList
          [ (paramName p, space)
            | p <- params,
              MemMem space <- [paramDec p]
          ]

      findMemInfo :: VName -> Maybe Space
      findMemInfo mem = M.lookup mem summaries

      -- Describe a parameter as an array (if it has an 'ArrayDecl'
      -- whose memory block is a known memory parameter) or as a scalar.
      mkValueDesc fparam signedness =
        case (findArray (paramName fparam), paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) ->
            (\memspace -> Imp.ArrayValue mem memspace bt signedness shape)
              <$> findMemInfo mem
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness (paramName fparam)
          _ ->
            Nothing

      transparent u signedness fparam =
        maybeToList $ Imp.TransparentValue u <$> mkValueDesc fparam signedness

      mkExts (TypeOpaque u desc n : epts) fparams =
        let (grouped, rest) = splitAt n fparams
         in Imp.OpaqueValue u desc (mapMaybe (`mkValueDesc` Imp.TypeDirect) grouped)
              : mkExts epts rest
      mkExts (TypeUnsigned u : epts) (fparam : fparams) =
        transparent u Imp.TypeUnsigned fparam ++ mkExts epts fparams
      mkExts (TypeDirect u : epts) (fparam : fparams) =
        transparent u Imp.TypeDirect fparam ++ mkExts epts fparams
      mkExts _ _ = []

  pure (inparams, arrayds, mkExts orig_epts val_params)

compileOutParams ::
  Mem rep =>
  [RetType rep] ->
  [EntryPointType] ->
  ImpM rep r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: forall rep r op.
Mem rep =>
[RetType rep]
-> [EntryPointType]
-> ImpM rep r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType rep]
orig_rts [EntryPointType]
orig_epts = do
  (([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams, Map Int ValueDestination
ctx_dests)) <-
    WriterT
  ([Param], Map Int ValueDestination)
  (ImpM rep r op)
  ([ExternalValue], [ValueDestination])
-> ImpM
     rep
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param], Map Int ValueDestination)
   (ImpM rep r op)
   ([ExternalValue], [ValueDestination])
 -> ImpM
      rep
      r
      op
      (([ExternalValue], [ValueDestination]),
       ([Param], Map Int ValueDestination)))
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM rep r op)
     ([ExternalValue], [ValueDestination])
-> ImpM
     rep
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
  ([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM rep r op)
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType rep]
[RetTypeMem]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
  let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
  ([ExternalValue], [Param], Destination)
-> ImpM rep r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
  where
    imp :: ImpM rep r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM rep r op) a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op) a
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      a)
-> (ImpM rep r op a
    -> WriterT ([Param], Map Int ValueDestination) (ImpM rep r op) a)
-> ImpM rep r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM rep r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM rep r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift

    mkExts :: [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
mkExts (TypeOpaque Uniqueness
u [Char]
desc Int
n : [EntryPointType]
epts) [RetTypeMem]
rts = do
      let ([RetTypeMem]
rts', [RetTypeMem]
rest) = Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [RetTypeMem]
rts
      ([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
 -> ([ValueDesc], [ValueDestination]))
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     [(ValueDesc, ValueDestination)]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (RetTypeMem
 -> Signedness
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      (ValueDesc, ValueDestination))
-> [RetTypeMem]
-> [Signedness]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     [(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
mkParam [RetTypeMem]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rest
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( Uniqueness -> [Char] -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue Uniqueness
u [Char]
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          [ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests
        )
    mkExts (TypeUnsigned Uniqueness
u : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeUnsigned
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( Uniqueness -> ValueDesc -> ExternalValue
Imp.TransparentValue Uniqueness
u ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts (TypeDirect Uniqueness
u : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeDirect
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( Uniqueness -> ValueDesc -> ExternalValue
Imp.TransparentValue Uniqueness
u ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts [EntryPointType]
_ [RetTypeMem]
_ = ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])

    mkParam :: RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
mkParam MemMem {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not explicitly return memory blocks."
    mkParam MemAcc {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not return accumulators."
    mkParam (MemPrim PrimType
t) Signedness
ept = do
      VName
out <- ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall {a}.
ImpM rep r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
imp (ImpM rep r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      VName)
-> ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM rep r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"scalar_out"
      ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
    mkParam (MemArray PrimType
t ShapeBase (Ext SubExp)
shape Uniqueness
_ MemReturn
dec) Signedness
ept = do
      Space
space <- (Env rep r op -> Space)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env rep r op -> Space
forall rep r op. Env rep r op -> Space
envDefaultSpace
      VName
memout <- case MemReturn
dec of
        ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
          VName
memout <- ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall {a}.
ImpM rep r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
imp (ImpM rep r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      VName)
-> ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM rep r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_mem"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout
            )
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
        ReturnsInBlock VName
memout ExtIxFun
_ ->
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
      [SubExp]
resultshape <- (Ext SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      SubExp)
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
inspectExtSize ([Ext SubExp]
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      [SubExp])
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext SubExp) -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
shape
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( VName -> Space -> PrimType -> Signedness -> [SubExp] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [SubExp]
resultshape,
          Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
        )

    inspectExtSize :: Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
inspectExtSize (Ext Int
x) = do
      (Map Any Any
memseen, Map Int VName
arrseen) <- StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
  (Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
      case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
        Maybe VName
Nothing -> do
          VName
out <- ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall {a}.
ImpM rep r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     a
imp (ImpM rep r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      VName)
-> ImpM rep r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM rep r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_arrsize"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int64],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out
            )
          (Map Any Any, Map Int VName)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
        Just VName
out ->
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
    inspectExtSize (Free SubExp
se) =
      SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM rep r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Compile one function definition into an 'Imp.Function' and register
-- it under the function's name with 'emitFunction'.
--
-- While the function is compiled, 'envFunction' is set to the entry
-- point name if there is one, and otherwise to the function's own name.
-- A function that is not an entry point treats every parameter and
-- every return value as 'TypeDirect' with 'mempty' uniqueness.
compileFunDef ::
  Mem rep =>
  FunDef rep ->
  ImpM rep r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local withFunction $ do
    ((outparams, inparams, results, args), fun_code) <- collect' compileFunction
    emitFunction fname $
      Imp.Function name_entry outparams inparams fun_code results args
  where
    -- Record which function we are currently inside.
    withFunction env = env {envFunction = name_entry `mplus` Just fname}

    defaultEntry = TypeDirect mempty

    -- Entry points carry explicit external types; everything else gets
    -- the direct default for each parameter and each result.
    (name_entry, params_entry, ret_entry) =
      case entry of
        Just (ename, param_epts, ret_epts) -> (Just ename, param_epts, ret_epts)
        Nothing ->
          ( Nothing,
            map (const defaultEntry) params,
            map (const defaultEntry) rettype
          )

    compileFunction = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <- compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
      -- Compile the body, then copy each result into its output destination.
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $ \(dest, se) -> copyDWIMDest dest [] se []

      pure (outparams, inparams, results, args)

-- | Compile a body whose results are stored in the destinations obtained
-- from the given pattern.
--
-- The body's statements are compiled first. Then each result is copied,
-- positionally, into the corresponding pattern destination.
compileBody :: (Mem rep) => Pattern rep -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ bnds ses) = do
  Destination _ dests <- destinationFromPattern pat
  let copyResult dest se = copyDWIMDest dest [] se []
  compileStms (freeIn ses) bnds $
    sequence_ (zipWith copyResult dests ses)

-- | Compile a body and copy each of its results into the parameter at the
-- same position, using 'copyDWIM' on the parameter's name.
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ bnds ses) =
  compileStms (freeIn ses) bnds $
    mapM_ copyToParam (zip params ses)
  where
    copyToParam (param, se) = copyDWIM (paramName param) [] se []

-- | Compile a loop body and write its results back into the merge
-- parameters.
--
-- We cannot write the results to the merge parameters immediately,
-- as some of the results may actually *be* merge parameters, and
-- would thus be clobbered.  Therefore, we first copy to new
-- variables mirroring the merge parameters, and then copy this
-- buffer to the merge parameters.  This is efficient, because the
-- operations are all scalar operations.
--
-- Only two kinds of merge parameter are written back here:
--
-- * parameters of primitive type;
-- * parameters of memory type whose result is a variable.
--
-- Every other merge parameter is left untouched by this function.
compileLoopBody :: Typed dec => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ bnds ses) = do
  tmpnames <- mapM (newVName . tmpName) mergeparams
  compileStms (freeIn ses) bnds $ do
    -- First pass: stage every result in its temporary and collect the
    -- deferred write-back actions.
    writebacks <- sequence $ zipWith3 stage mergeparams tmpnames ses
    -- Second pass: only now overwrite the merge parameters.
    sequence_ writebacks
  where
    tmpName p = baseString (paramName p) ++ "_tmp"

    -- Emit the declaration and initialisation of the temporary, and
    -- return (without running it) the action that copies the temporary
    -- into the merge parameter.
    stage p tmp se =
      case (typeOf p, se) of
        (Prim pt, _) -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        (Mem space, Var v) -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure $ pure ()

-- | Compile a sequence of statements followed by the action @m@.
--
-- This does not compile anything itself. It delegates to the statement
-- compiler currently installed in the environment ('envStmsCompiler'),
-- passing all three arguments through unchanged. The 'Names' argument
-- presumably lists the names that must stay alive after the statements
-- (inferred from its name; verify against the installed compiler).
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \compiler -> compiler alive_after_stms all_stms m

-- | The default 'StmsCompiler'.
--
-- For each statement, the variables bound by its pattern are declared and
-- its expression is compiled with the statement's attributes in effect.
-- The remaining statements are then compiled recursively. Once the
-- statement list is exhausted, the continuation @m@ is compiled and emitted.
--
-- We keep track of any memory blocks produced by the statements, and
-- after the last time that memory block is used, we insert an 'Imp.Free'.
-- This is very conservative, but can cut down on lifetimes in some cases.
-- Names in @alive_after_stms@ are part of every liveness set computed
-- here, so memory blocks listed there are never freed by this function.
defCompileStms ::
  (Mem rep, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- Compile the given statements. @allocs@ holds the memory blocks
    -- (with their spaces) bound by the preceding statements. Returns the
    -- names free in the generated code, together with everything live
    -- after it.
    compileStms' allocs (Let pat aux e : bs) = do
      dVars (Just e) (patternElements pat)

      e_code <-
        localAttrs (stmAuxAttrs aux) $
          collect $ compileExp pat e
      -- The rest of the statements are compiled (but not yet emitted)
      -- first, because we need to know what is live after this statement
      -- before deciding which memory blocks can be freed here.
      (live_after, bs_code) <-
        collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let dies_here v =
            not (v `nameIn` live_after)
              && v `nameIn` freeIn e_code
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      return $ freeIn e_code <> live_after
    compileStms' _ [] = do
      code <- collect m
      emit code
      return $ freeIn code <> alive_after_stms

    -- The memory blocks bound directly by a pattern, paired with their space.
    patternAllocs = S.fromList . mapMaybe isMemPatElem . patternElements

    isMemPatElem pe = case patElemType pe of
      Mem space -> Just (patElemName pe, space)
      _ -> Nothing

-- | Compile the expression @e@ with the pattern @pat@ by delegating to the
-- 'ExpCompiler' stored in the environment ('envExpCompiler').
compileExp :: Pattern rep -> Exp rep -> ImpM rep r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | The default 'ExpCompiler'.
--
--   * 'If': compiles both branches into the pattern and selects between
--     them on the condition.
--   * 'Apply': emits an 'Imp.Call'. The call targets are derived from the
--     pattern. Only arguments of primitive or memory type are passed;
--     any other argument produces no 'Imp.Arg'.
--   * 'BasicOp': delegates to 'defCompileBasicOp'.
--   * 'DoLoop': declares the merge parameters and initialises the
--     non-array ones. It then compiles the loop and finally copies the
--     merge parameters into the pattern.
--   * 'WithAcc': declares the lambda parameters and records each
--     accumulator parameter in 'stateAccs'. It then compiles the lambda
--     body and copies the non-accumulator results into the trailing
--     pattern elements.
--   * 'Op': delegates to the 'OpCompiler' in the environment.
defCompileExp ::
  (Mem rep) =>
  Pattern rep ->
  Exp rep ->
  ImpM rep r op ()
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (compileBody pat tbranch) (compileBody pat fbranch)
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    -- Primitive values are passed as expressions and memory blocks by
    -- name. Anything else yields no argument.
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> return $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> return $ Just $ Imp.MemArg v
        _ -> return Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
  dFParams mergepat
  -- Only merge parameters of rank 0 are initialised by copying here.
  forM_ merge $ \(p, se) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody mergepat body

  case form of
    ForLoop i _ bound loopvars -> do
      -- A loop parameter of primitive type is bound to element @i@ of its
      -- array at the start of each iteration. Other loop parameters are
      -- not copied here.
      let setLoopParam (p, a)
            | Prim _ <- paramType p =
              copyDWIM (paramName p) [] (Var a) [DimFix $ Imp.vi64 i]
            | otherwise =
              return ()

      bound' <- toExp bound

      dLParams $ map fst loopvars
      sFor' i bound' $
        mapM_ setLoopParam loopvars >> doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  -- The final values of the merge parameters are the results of the loop.
  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    merge = ctx ++ val
    mergepat = map fst merge
defCompileExp pat (WithAcc inputs lam) = do
  dLParams $ lambdaParams lam
  -- The leading lambda parameters correspond one-to-one to the inputs.
  forM_ (zip inputs $ lambdaParams lam) $ \((_, arrs, acc_op), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, acc_op) $ stateAccs s}
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let nonacc_res = drop num_accs (bodyResult (lambdaBody lam))
        nonacc_pat_names = takeLast (length nonacc_res) (patternNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, se) ->
      copyDWIM v [] se []
  where
    num_accs = length inputs
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op

defCompileBasicOp ::
  Mem rep =>
  Pattern rep ->
  BasicOp ->
  ImpM rep r op ()
defCompileBasicOp :: forall rep r op.
Mem rep =>
Pattern rep -> BasicOp -> ImpM rep r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (SubExp SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (Opaque SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (UnOp UnOp
op SubExp
e) = do
  Exp
e' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
  Exp
e' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe VName -> Exp -> ImpM rep r op ()
forall rep r op. VName -> Exp -> ImpM rep r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp PatternT (LetDec rep)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
  Exp
e' <- SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
  ErrorMsg Exp
msg' <- (SubExp -> ImpM rep r op Exp)
-> ErrorMsg SubExp -> ImpM rep r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM rep r op Exp
forall a rep r op. ToExp a => a -> ImpM rep r op Exp
toExp ErrorMsg SubExp
msg
  Code op -> ImpM rep r op ()
forall op rep r. Code op -> ImpM rep r op ()
emit (Code op -> ImpM rep r op ()) -> Code op -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc

  Attrs
attrs <- ImpM rep r op Attrs
forall rep r op. ImpM rep r op Attrs
askAttrs
  Bool -> ImpM rep r op () -> ImpM rep r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM rep r op () -> ImpM rep r op ())
-> ImpM rep r op () -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$
    (SrcLoc -> [SrcLoc] -> [Char] -> ImpM rep r op ())
-> (SrcLoc, [SrcLoc]) -> [Char] -> ImpM rep r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> [Char] -> ImpM rep r op ()
forall loc rep r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM rep r op ()
warn (SrcLoc, [SrcLoc])
loc [Char]
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (Index VName
src Slice SubExp
slice)
  | Just [SubExp]
idxs <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
    VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
forall rep r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM rep r op ())
-> [DimIndex (TExp Int64)] -> ImpM rep r op ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex (TExp Int64))
-> [SubExp] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (SubExp -> TExp Int64) -> SubExp -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
idxs
defCompileBasicOp PatternT (LetDec rep)
_ Index {} =
  () -> ImpM rep r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec rep)]
_ [PatElemT (LetDec rep)
pe]) (Update VName
_ Slice SubExp
slice SubExp
se) =
  VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM rep r op ()
forall rep r op.
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM rep r op ()
sUpdate (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec rep)
PatElemT LParamMem
pe) ((DimIndex SubExp -> DimIndex (TExp Int64))
-> Slice SubExp -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TExp Int64) -> DimIndex SubExp -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice SubExp
slice) SubExp
se
-- Replicate: write @se@ into every position of the result by wrapping a
-- single element copy in one 'Imp.For' loop per dimension of the
-- replication shape.
defCompileBasicOp (Pattern _ [pe]) (Replicate (Shape ds) se) = do
  ds' <- mapM toExp ds
  is <- replicateM (length ds) (newVName "i")
  copy_elem <- collect $ copyDWIM (patElemName pe) (map (DimFix . Imp.vi64) is) se []
  -- The first dimension becomes the outermost loop.  'foldr' applies the
  -- loop constructors right-to-left, producing the same nesting as the
  -- composition @For i0 d0 . For i1 d1 . ...@ without the 'foldl'
  -- anti-pattern.
  emit $ foldr ($) copy_elem (zipWith Imp.For is ds')
-- Scratch: no code is generated for this operation here.
defCompileBasicOp _ Scratch {} = pure ()
-- Iota: for each index @i < n@, store @e + i * s@ (computed in integer
-- type @it@, with undefined overflow) at position @i@ of the result.
defCompileBasicOp (Pattern [] [pe]) (Iota n e s it) = do
  start <- toExp e
  stride <- toExp s
  sFor "i" (toInt64Exp n) $ \i -> do
    let i_ext = sExt it (untyped i)
        scaled = BinOpExp (Mul it OverflowUndef) i_ext stride
        value = BinOpExp (Add it OverflowUndef) start scaled
    x <- dPrimV "x" (TPrimExp value)
    copyDWIM (patElemName pe) [DimFix i] (Var (tvVar x)) []
-- Copy: copy the whole of @src@ into the pattern element.
defCompileBasicOp (Pattern _ [pe]) (Copy src) =
  copyDWIM dest [] (Var src) []
  where
    dest = patElemName pe
-- Manifest: copy the whole of @src@ into the pattern element; the
-- permutation argument is not consulted here.
defCompileBasicOp (Pattern _ [pe]) (Manifest _ src) =
  copyDWIM dest [] (Var src) []
  where
    dest = patElemName pe
-- Concat: copy each operand of @x : ys@ into the destination slab that
-- starts at a running offset along dimension @i@, then advance the offset
-- by that operand's extent in dimension @i@.
defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  offs_glb <- dPrimV "tmp_offs" 0

  forM_ (x : ys) $ \y -> do
    y_dims <- arrayDims <$> lookupType y
    let (skip_dims, concat_dims) = splitAt i y_dims
        -- Extent of @y@ along the concatenation dimension.
        rows = case concat_dims of
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ pretty y
          r : _ -> toInt64Exp r
        -- Dimensions in front of @i@ are taken in full.
        skip_slices = [DimSlice 0 (toInt64Exp d) 1 | d <- skip_dims]
        destslice = skip_slices ++ [DimSlice (tvExp offs_glb) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    offs_glb <-- tvExp offs_glb + rows
-- ArrayLit: if every element is a constant, emit a static array in the
-- destination's memory space and copy it into place in one go; otherwise
-- store the elements one by one.
defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _) =
  case traverse isLiteral es of
    Just vs@(v : _) -> do
      dest_mem <- entryArrayLocation <$> lookupArray dest
      dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
      let t = primValueType v
      static_array <- newVNameForFun "static_array"
      emit $ Imp.DeclareArray static_array dest_space t $ Imp.ArrayValues vs
      let num_elems = length es
          static_src =
            MemLocation static_array [intConst Int64 $ fromIntegral num_elems] $
              IxFun.iota [fromIntegral num_elems]
      addVar static_array $ MemVar Nothing $ MemEntry dest_space
      let slice = [DimSlice 0 (genericLength es) 1]
      copy t dest_mem slice static_src slice
    _ ->
      -- Empty, or at least one non-constant element.
      forM_ (zip [0 ..] es) $ \(i, e) ->
        copyDWIM dest [DimFix $ fromInteger i] e []
  where
    dest = patElemName pe
    isLiteral = \case
      Constant lit -> Just lit
      _ -> Nothing
-- Rearrange, Rotate and Reshape emit no code here.  NOTE(review):
-- presumably they are realised through the result's index function rather
-- than by moving data -- confirm against the memory representation.
defCompileBasicOp _ Rearrange {} = pure ()
defCompileBasicOp _ Rotate {} = pure ()
defCompileBasicOp _ Reshape {} = pure ()
-- UpdateAcc: update accumulator @acc@ at index @is@ with values @vs@,
-- either by plain overwriting (scatter-like) or by combining with the
-- accumulator's operator lambda (generalised reduction).  Out-of-bounds
-- indices are skipped.
defCompileBasicOp _ (UpdateAcc acc is vs) = sComment "UpdateAcc" $ do
  -- We are abusing the comment mechanism to wrap the operator in
  -- braces when we end up generating code.  This is necessary because
  -- we might otherwise end up declaring lambda parameters (if any)
  -- multiple times, as they are duplicated every time we do an
  -- UpdateAcc for the same accumulator.
  let is' = map toInt64Exp is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (_, _, arrs, dims, op) <- lookupAcc acc is'

  sWhen (inBounds (map DimFix is') dims) $
    case op of
      Nothing ->
        -- Scatter-like: overwrite each accumulator array at the index.
        forM_ (zip arrs vs) $ \(arr, v) ->
          copyDWIMFix arr is' v []
      Just lam -> do
        -- Generalised reduction.
        let params = lambdaParams lam
            body = lambdaBody lam
            (x_params, y_params) = splitAt (length vs) $ map paramName params
        dLParams params

        -- Left operands: current contents of the arrays at the index.
        forM_ (zip x_params arrs) $ \(xp, arr) ->
          copyDWIMFix xp [] (Var arr) is'

        -- Right operands: the values supplied to this update.
        forM_ (zip y_params vs) $ \(yp, v) ->
          copyDWIM yp [] v []

        -- Run the operator and write its results back at the index.
        compileStms mempty (bodyStms body) $
          forM_ (zip arrs (bodyResult body)) $ \(arr, se) ->
            copyDWIMFix arr is' se []
-- Catch-all: any pattern/expression combination not handled above is an
-- internal compiler error.
defCompileBasicOp pat e =
  error . concat $
    [ "ImpGen.defCompileBasicOp: Invalid pattern\n  ",
      pretty pat,
      "\nfor expression\n  ",
      pretty e
    ]

-- | Record each 'ArrayDecl' via 'addVar' as an array variable with no
-- defining expression, using the declared location and element type.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name . ArrayVar Nothing $
      ArrayEntry
        { entryArrayLocation = location,
          entryArrayElemType = bt
        }

-- | Like 'dFParams', but does not create new declarations: each parameter
-- is only recorded via 'addVar', with its uniqueness information dropped.
--
-- Note: a hack to be used only for functions.
addFParams :: Mem rep => [FParam rep] -> ImpM rep r op ()
addFParams params =
  forM_ params $ \fparam ->
    addVar (paramName fparam) . memBoundToVarEntry Nothing . noUniquenessReturns $
      paramDec fparam

-- | Record @i@ via 'addVar' as a scalar variable of integer type @it@,
-- without emitting a declaration for it.
--
-- Another hack.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i = addVar i . ScalarVar Nothing . ScalarEntry . IntType

-- | Declare every pattern element, in order, by passing its singleton
-- scope to 'dScope' together with the (optional) defining expression.
dVars ::
  Mem rep =>
  Maybe (Exp rep) ->
  [PatElem rep] ->
  ImpM rep r op ()
dVars e pes = forM_ pes $ dScope e . scopeOfPatElem

-- | Declare function parameters via 'dScope', with no defining expression.
dFParams :: Mem rep => [FParam rep] -> ImpM rep r op ()
dFParams params = dScope Nothing (scopeOfFParams params)

-- | Declare lambda parameters via 'dScope', with no defining expression.
dLParams :: Mem rep => [LParam rep] -> ImpM rep r op ()
dLParams params = dScope Nothing (scopeOfLParams params)

-- | Declare a fresh /volatile/ scalar of type @t@ (named after @name@),
-- record it as a scalar variable, assign it @e@, and return it.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit $ Imp.DeclareScalar v Imp.Volatile t
  addVar v . ScalarVar Nothing $ ScalarEntry t
  v <~~ untyped e
  pure (TV v t)

-- | Emit a non-volatile scalar declaration for the given name and record
-- it as a scalar variable of type @t@.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t = do
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
  addVar name (ScalarVar Nothing (ScalarEntry t))

-- | Declare a fresh scalar of the given 'PrimType', named after @name@.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM rep r op (TV t)
dPrim name t = do
  v <- newVName name
  dPrim_ v t
  pure (TV v t)

-- | Declare the given name as a scalar whose type is that of @e@, then
-- assign @e@ to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  let t = primExpType (untyped e)
  dPrim_ name t
  TV name t <-- e

-- | Declare a fresh scalar (named after @name@) with the type of @e@,
-- assign @e@ to it, and return the new variable.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  tv <- dPrim name $ primExpType $ untyped e
  tv <-- e
  pure tv

-- | Like 'dPrimV', but return the new variable as an expression.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e

-- | Turn a memory-annotated type into a symbol-table entry, attaching
-- the optional expression to whichever 'VarEntry' constructor results.
-- For arrays, the index function's variables are wrapped as
-- 'Imp.ScalarVar' leaves.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e = \case
  MemPrim bt ->
    ScalarVar e $ ScalarEntry bt
  MemMem space ->
    MemVar e $ MemEntry space
  MemAcc acc ispace ts _ ->
    AccVar e (acc, ispace, ts)
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    let ixfun' = fmap (fmap Imp.ScalarVar) ixfun
        location = MemLocation mem (shapeDims shape) ixfun'
     in ArrayVar e $ ArrayEntry location bt

-- | Extract the (uniqueness-free) memory information from any kind of
-- name binding.  Index names become primitive integers of their type.
infoDec ::
  Mem rep =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec = \case
  LetName dec -> dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim $ IntType it

-- | Record a name in the symbol table according to its 'NameInfo'.
-- Memory blocks get an 'Imp.DeclareMem' and scalars a non-volatile
-- 'Imp.DeclareScalar' emitted first; arrays and accumulators are only
-- added to the table.
dInfo ::
  Mem rep =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  declare entry
  addVar name entry
  where
    entry = memBoundToVarEntry e $ infoDec info

    declare (MemVar _ mem) =
      emit $ Imp.DeclareMem name $ entryMemSpace mem
    declare (ScalarVar _ scalar) =
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType scalar
    declare ArrayVar {} = pure ()
    declare AccVar {} = pure ()

-- | Apply 'dInfo' to every binding in the scope, in ascending key
-- order of the underlying map.
dScope ::
  Mem rep =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e scope =
  mapM_ (\(name, info) -> dInfo e name info) (M.toList scope)

-- | Add an array to the symbol table, given its element type, shape
-- and memory binding.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM rep r op ()
dArray name bt shape membind =
  addVar name $ memBoundToVarEntry Nothing arrayInfo
  where
    arrayInfo = MemArray bt shape NoUniqueness membind

-- | Run an action with 'envVolatility' set to 'Imp.Volatile'.
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local makeVolatile
  where
    makeVolatile env = env {envVolatility = Imp.Volatile}

-- | The call targets described by a 'Destination': the names of its
-- scalar and memory destinations, in order.  Array destinations are
-- removed, as they contribute no target name.
--
-- The computation itself is pure; the result is merely returned in
-- 'ImpM' to keep the existing interface.
funcallTargets :: Destination -> ImpM rep r op [VName]
funcallTargets (Destination _ dests) =
  pure $ concatMap funcallTarget dests
  where
    funcallTarget :: ValueDestination -> [VName]
    funcallTarget (ScalarDestination name) = [name]
    funcallTarget (ArrayDestination _) = []
    funcallTarget (MemoryDestination name) = [name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
--
-- The 'VName' is the variable itself and the 'PrimType' its dynamic
-- type; the phantom parameter @t@ is the static type, which is not
-- checked against the dynamic one (see 'mkTV').
data TV t = TV VName PrimType

-- | Build a typed variable from a name and its dynamic type.  Nothing
-- ensures that the dynamic type matches the static type @t@ that is
-- inferred at use sites; the static type only has to be used
-- consistently.
mkTV :: VName -> PrimType -> TV t
mkTV name pt = TV name pt

-- | Use a typed variable as a size, i.e. as a 'Var' 'SubExp'.
tvSize :: TV t -> Imp.DimSize
tvSize (TV v _) = Var v

-- | Read a typed variable as an expression carrying the same static
-- type, using the variable's dynamic type for 'Imp.var'.
tvExp :: TV t -> Imp.TExp t
tvExp (TV name pt) = Imp.TPrimExp (Imp.var name pt)

-- | The variable name wrapped by a typed variable.
tvVar :: TV t -> VName
tvVar (TV name _) = name

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile to a typed 64-bit integer expression, by calling
  -- 'toExp'' with the type 'int64'.
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp = TPrimExp . toExp' int64

  -- | Compile to a typed boolean expression, by calling 'toExp'' with
  -- the primitive type 'Bool'.
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp = TPrimExp . toExp' Bool

-- | Constants become value expressions.  Variables are looked up in the
-- symbol table by 'toExp' (and must be scalars), whereas 'toExp'' takes
-- their type from its argument.
instance ToExp SubExp where
  toExp se = case se of
    Constant v -> pure $ Imp.ValueExp v
    Var v -> do
      entry <- lookupVar v
      case entry of
        ScalarVar _ (ScalarEntry pt) -> pure $ Imp.var v pt
        _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' t (Var v) = Imp.var v t

-- | Every variable leaf is wrapped as an 'Imp.ScalarVar'; the requested
-- type passed to 'toExp'' is ignored.
instance ToExp (PrimExp VName) where
  toExp = return . fmap Imp.ScalarVar
  toExp' = const $ fmap Imp.ScalarVar

-- | Insert (or overwrite) an entry in the current symbol table.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry = modify insertEntry
  where
    insertEntry s = s {stateVTable = M.insert name entry (stateVTable s)}

-- | Run an action with 'envDefaultSpace' set to the given space.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local withSpace
  where
    withSpace env = env {envDefaultSpace = space}

-- | The function name recorded in the environment ('envFunction'), if
-- there is one.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = asks envFunction

-- | Generate a 'VName' from the given string, prefixed with the
-- function name from 'askFunction' and a dot when one is present.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  prefix <- maybe "" ((++ ".") . nameToString) <$> askFunction
  newVName $ prefix ++ s

-- | Build a 'Name' from the given string, prefixed with the function
-- name from 'askFunction' and a dot when one is present.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = do
  prefix <- maybe "" (<> ".") <$> askFunction
  pure $ prefix <> nameFromString s

-- | The custom environment component @r@ ('envEnv') of the 'ImpM'
-- environment.
askEnv :: ImpM rep r op r
askEnv = asks envEnv

-- | Run an action with the custom environment component ('envEnv')
-- transformed by the given function.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local updateCustomEnv
  where
    updateCustomEnv env = env {envEnv = f (envEnv env)}

-- | The attributes currently in effect ('envAttrs'), including those
-- of the statement being compiled.
askAttrs :: ImpM rep r op Attrs
askAttrs = asks envAttrs

-- | Run an action with additional attributes, which are then also
-- returned by 'askAttrs' inside that action.  The new attributes are
-- combined with the existing ones using '<>', with the new ones on the
-- left.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with the expression, statement, copy, operation and
-- allocation compilers taken from the given 'Operations'.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local installOps
  where
    installOps env =
      env
        { envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envOpCompiler = opsOpCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }

-- | Fetch the symbol table from the compiler state.
getVTable :: ImpM rep r op (VTable rep)
getVTable = gets stateVTable

-- | Replace the symbol table in the compiler state.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify setVTable
  where
    setVTable s = s {stateVTable = vtable}

-- | Run an action with the symbol table transformed by the given
-- function.  Once the action finishes, the table is restored to what it
-- was beforehand, discarding every change made in between.
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  saved <- getVTable
  putVTable $ f saved
  m >>= \result -> putVTable saved >> pure result

-- | Find a name in the symbol table; calls 'error' if it is absent.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name =
  gets (M.lookup name . stateVTable) >>= \case
    Nothing -> error $ "Unknown variable: " ++ pretty name
    Just entry -> pure entry

-- | Find an array in the symbol table; calls 'error' if the name is
-- bound to something other than an array.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Find a memory block in the symbol table; calls 'error' if the name
-- is bound to something other than memory.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    _ -> error $ "Unknown memory block: " ++ pretty name

-- | The memory space of the block that holds the given array.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace arr = do
  entry <- lookupArray arr
  let mem = memLocationName $ entryArrayLocation entry
  entryMemSpace <$> lookupMemory mem

-- | Look up an accumulator.  Returns the accumulator's own name, the
-- memory space of its first array, its arrays, its index space
-- dimensions as 64-bit expressions, and its combining lambda (if any).
--
-- In the case of a histogram-like accumulator (one with a lambda), also
-- sets the index parameters: the lambda's leading parameters are bound
-- to the given indices with 'dPrimV_', and the returned lambda has those
-- parameters removed.
--
-- Calls 'error' if the name is not an accumulator, if the accumulator
-- is not listed in the state, or if it has no arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is = do
  res <- lookupVar name
  case res of
    AccVar _ (acc, ispace, _) -> do
      acc' <- gets $ M.lookup acc . stateAccs
      case acc' of
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ pretty name
        Just (arrs@(arr : _), lam_info) -> do
          -- The space is looked up before any index parameter is bound,
          -- matching the original order of effects.
          space <- lookupArraySpace arr
          lam' <- traverse (bindIndexParams . fst) lam_info
          return (acc, space, arrs, map toInt64Exp (shapeDims ispace), lam')
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ pretty name
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ pretty name
  where
    -- Bind the first (length is) lambda parameters to the indices and
    -- return the lambda with only the remaining parameters.
    bindIndexParams lam = do
      let (i_params, ps) = splitAt (length is) $ lambdaParams lam
      zipWithM_ dPrimV_ (map paramName i_params) is
      return lam {lambdaParams = ps}

-- | Compute a 'Destination' for every element of a pattern.
--
-- Each pattern element is looked up in the variable table:
--
-- * array and accumulator variables get an 'ArrayDestination' with no
--   memory location;
--
-- * memory variables get a 'MemoryDestination' naming the element;
--
-- * scalar variables get a 'ScalarDestination' naming the element.
--
-- The resulting 'Destination' is tagged with the 'baseTag' of the first
-- name in the pattern, or 'Nothing' if the pattern is empty.
destinationFromPattern :: Mem rep => Pattern rep -> ImpM rep r op Destination
destinationFromPattern pat =
  Destination (baseTag <$> maybeHead (patternNames pat))
    <$> mapM (valueDestination . patElemName) (patternElements pat)
  where
    valueDestination name =
      lookupVar name >>= \case
        ArrayVar _ (ArrayEntry MemLocation {} _) ->
          pure $ ArrayDestination Nothing
        MemVar {} ->
          pure $ MemoryDestination name
        ScalarVar {} ->
          pure $ ScalarDestination name
        AccVar {} ->
          pure $ ArrayDestination Nothing

-- | Look up the named array and index it with the given indices.
--
-- Returns the underlying memory block, its space, and the element
-- offset computed by 'fullyIndexArray'' on the array's 'MemLocation'.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  location <- entryArrayLocation <$> lookupArray name
  fullyIndexArray' location indices

-- | Index a 'MemLocation' directly.
--
-- The element offset is obtained by applying the location's index
-- function to the indices ('IxFun.index'). The memory space is taken
-- from the variable-table entry of the location's memory block.
fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  mem_entry <- lookupMemory mem
  pure (mem, entryMemSpace mem_entry, elements (IxFun.index ixfun indices))

-- More complicated read/write operations that use index functions.

-- | Copy between two (sliced) memory locations by delegating to the
-- 'CopyCompiler' stored in the current environment.
copy :: CopyCompiler rep r op
copy bt dest destslice src srcslice =
  asks envCopyCompiler >>= \compileCopy ->
    compileCopy bt dest destslice src srcslice

-- | Is this copy really a mapping with transpose?
--
-- Two shapes are recognised:
--
-- * the sliced destination index function is a rearrangement with an
--   offset and the sliced source is linear with an offset;
--
-- * the sliced destination is linear and the sliced source is a
--   rearrangement.
--
-- In both cases the permutation must be accepted by 'isMapTranspose'.
-- On success the result is
-- @(dest_offset, src_offset, num_arrays, size_x, size_y)@:
--
-- * @num_arrays@ is the product of the outer (mapped) dimensions;
--
-- * @size_x@ and @size_y@ are the products of the two transposed
--   dimension groups.
--
-- The element byte size is passed to the index-function offset
-- helpers.
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy
  bt
  (MemLocation _ _ destIxFun)
  destslice
  (MemLocation _ _ srcIxFun)
  srcslice
    -- Permuted destination, linear source: the transposed dimension
    -- groups are read off the destination shape in swapped order.
    | Just (dest_offset, dest_perm_and_shape) <-
        IxFun.rearrangeWithOffset sliced_dest elem_size,
      (perm, destshape) <- unzip dest_perm_and_shape,
      Just src_offset <- IxFun.linearWithOffset sliced_src elem_size,
      Just (r1, r2, _) <- isMapTranspose perm =
      describe destshape (\(a, b) -> (b, a)) r1 r2 dest_offset src_offset
    -- Linear destination, permuted source.
    | Just dest_offset <- IxFun.linearWithOffset sliced_dest elem_size,
      Just (src_offset, src_perm_and_shape) <-
        IxFun.rearrangeWithOffset sliced_src elem_size,
      (perm, srcshape) <- unzip src_perm_and_shape,
      Just (r1, r2, _) <- isMapTranspose perm =
      describe srcshape id r1 r2 dest_offset src_offset
    | otherwise = Nothing
    where
      elem_size = primByteSize bt
      sliced_dest = IxFun.slice destIxFun destslice
      sliced_src = IxFun.slice srcIxFun srcslice

      -- Split off the first r1 (mapped) dimensions, then split the
      -- remainder at r2; 'reorder' decides which half is which.
      describe shape reorder r1 r2 dest_offset src_offset =
        let (mapped, rest) = splitAt r1 shape
            (group_x, group_y) = reorder $ splitAt r2 rest
         in Just
              ( dest_offset,
                src_offset,
                product mapped,
                product group_x,
                product group_y
              )

-- | Name of the map-transpose function for an element type: the
-- prefix @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("map_transpose_" ++) . pretty

-- | Get the name of the builtin map-transpose function for an element
-- type (@builtin#@ followed by 'mapTransposeName').
--
-- If no function with that name has been emitted yet, its definition
-- ('mapTransposeFunction') is emitted first.
mapTransposeForType :: PrimType -> ImpM rep r op Name
mapTransposeForType bt = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (mapTransposeFunction fname bt)
  pure fname
  where
    fname = nameFromString ("builtin#" <> mapTransposeName bt)

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
--
-- Strategies, tried in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, call the builtin
--    map-transpose function for the element type
--    (see 'mapTransposeForType').
--
-- 2. If both sliced index functions are linear with an offset, emit a
--    single bulk 'Imp.Copy' of all elements of the source slice. If
--    either memory block is in a 'ScalarSpace', fall back to
--    'copyElementWise' instead.
--
-- 3. Otherwise, use 'copyElementWise'.
defaultCopy :: CopyCompiler rep r op
defaultCopy pt dest destslice src srcslice
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy pt dest destslice src srcslice = do
    transpose_fun <- mapTransposeForType pt
    emit . Imp.Call [] transpose_fun $
      transposeArgs
        pt
        destmem
        (bytes destoffset)
        srcmem
        (bytes srcoffset)
        num_arrays
        size_x
        size_y
  | Just destoffset <- linearOffset dest_ixfun destslice,
    Just srcoffset <- linearOffset src_ixfun srcslice = do
    srcspace <- memSpaceOf srcmem
    destspace <- memSpaceOf destmem
    if any isScalarSpace [srcspace, destspace]
      then copyElementWise pt dest destslice src srcslice
      else
        emit $
          Imp.Copy
            destmem
            (bytes destoffset)
            destspace
            srcmem
            (bytes srcoffset)
            srcspace
            (num_elems `withElemType` pt)
  | otherwise =
    copyElementWise pt dest destslice src srcslice
  where
    MemLocation destmem _ dest_ixfun = dest
    MemLocation srcmem _ src_ixfun = src

    pt_size = primByteSize pt
    num_elems = Imp.elements $ product $ sliceDims srcslice

    -- Offset of a sliced index function, if it is linear.
    linearOffset ixfun slice =
      IxFun.linearWithOffset (IxFun.slice ixfun slice) pt_size

    memSpaceOf mem = entryMemSpace <$> lookupMemory mem

    isScalarSpace ScalarSpace {} = True
    isScalarSpace _ = False

-- | Copy one element at a time.
--
-- A loop nest is generated with one 'Imp.For' per dimension of the
-- source slice (outermost loop first). The innermost statement reads
-- one element from the source and writes it to the destination. The
-- environment's volatility is used for both the read and the write.
copyElementWise :: CopyCompiler rep r op
copyElementWise bt dest destslice src srcslice = do
  let dims = sliceDims srcslice
  loop_vars <- replicateM (length dims) (newVName "i")
  let point = map Imp.vi64 loop_vars
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest (fixSlice destslice point)
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src (fixSlice srcslice point)
  vol <- asks envVolatility
  let body =
        Imp.Write destmem destidx bt destspace vol $
          Imp.index srcmem srcidx bt srcspace vol
      loops = zipWith Imp.For loop_vars (map untyped dims)
  -- Wrap the body so that the first loop ends up outermost.
  emit $ foldr ($) body loops

-- | Copy from here to there; both destination and source may be
-- indexed.
--
-- Scalar case: if both slices consist only of 'DimFix' indices, with
-- one index per dimension of the respective shape, the result is a
-- single element read/write.
--
-- General case: each slice is passed through 'fullSliceNum' together
-- with its location's shape. Then:
--
-- * if the resulting ranks differ, the function calls 'error';
--
-- * if source and destination are the same location with the same
--   slice, it returns empty code;
--
-- * otherwise the code produced by 'copy' is collected and returned.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLocation _ destshape _)
  destslice
  srclocation@(MemLocation _ srcshape _)
  srcslice
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
      (targetmem, destspace, targetoffset) <-
        fullyIndexArray' destlocation destis
      (srcmem, srcspace, srcoffset) <-
        fullyIndexArray' srclocation srcis
      vol <- asks envVolatility
      pure . Imp.Write targetmem targetoffset bt destspace vol $
        Imp.index srcmem srcoffset bt srcspace vol
    | otherwise =
      copyFullSlices
        (fullSliceNum (map toInt64Exp destshape) destslice)
        (fullSliceNum (map toInt64Exp srcshape) srcslice)
    where
      -- General case, operating on the slices after 'fullSliceNum'.
      copyFullSlices destslice' srcslice'
        | destrank /= srcrank =
          error $
            "copyArrayDWIM: cannot copy to "
              ++ pretty (memLocationName destlocation)
              ++ " from "
              ++ pretty (memLocationName srclocation)
              ++ " because ranks do not match ("
              ++ pretty destrank
              ++ " vs "
              ++ pretty srcrank
              ++ ")"
        | destlocation == srclocation && destslice' == srcslice' =
          pure mempty -- Copy would be no-op.
        | otherwise =
          collect $ copy bt destlocation destslice' srclocation srcslice'
        where
          destrank = length $ sliceDims destslice'
          srcrank = length $ sliceDims srcslice'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- A 'Constant' source must be unindexed.  It can only be stored into
-- a 'ScalarDestination' or into an element of an 'ArrayDestination'
-- selected by an all-'DimFix' destination slice.
--
-- A 'Var' source is dispatched on the pair (destination, variable
-- entry of the source).  The alternatives are tried in order, so an
-- earlier alternative can reject a combination that a later one would
-- otherwise accept.  Unsupported combinations are reported with
-- 'error'; the last two alternatives emit no code at all.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
copyDWIMDest target dest_slice (Constant v) [] =
  case (mapM dimFix dest_slice, target) of
    (Nothing, _) ->
      constantError "with slice destination."
    (Just _, ScalarDestination name) ->
      emit . Imp.SetScalar name $ Imp.ValueExp v
    (Just _, MemoryDestination {}) ->
      constantError "cannot be written to memory destination."
    (Just is, ArrayDestination (Just loc)) -> do
      (mem, space, offset) <- fullyIndexArray' loc is
      volatility <- asks envVolatility
      emit . Imp.Write mem offset (primValueType v) space volatility $
        Imp.ValueExp v
    (Just _, ArrayDestination Nothing) ->
      error "copyDWIMDest: ArrayDestination Nothing"
  where
    -- Every constant-source error shares the same prefix.
    constantError what =
      error $ unwords ["copyDWIMDest: constant source", pretty v, what]
copyDWIMDest target dest_slice (Var src) src_slice = do
  entry <- lookupVar src
  case (target, entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", pretty src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not (null src_slice) ->
        error $
          unwords ["copyDWIMDest: prim-typed source", pretty src, "with slice", pretty src_slice]
    (ScalarDestination name, _)
      | not (null dest_slice) ->
        error $
          unwords ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit . Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      -- Reading a scalar out of an array requires a full, all-'DimFix'
      -- index into it.
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
        (mem, space, offset) <-
          fullyIndexArray' (entryArrayLocation arr) src_is
        volatility <- asks envVolatility
        emit . Imp.SetScalar name $
          Imp.index mem offset (entryArrayElemType arr) space volatility
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: prim-typed target",
              pretty name,
              "and array-typed source",
              pretty src,
              "with slice",
              pretty src_slice
            ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) ->
      copyArrayDWIM
        (entryArrayElemType src_arr)
        dest_loc
        dest_slice
        (entryArrayLocation src_arr)
        src_slice
        >>= emit
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry pt))
      | Just dest_is <- mapM dimFix dest_slice -> do
        (mem, space, offset) <- fullyIndexArray' dest_loc dest_is
        volatility <- asks envVolatility
        emit $ Imp.Write mem offset pt space volatility $ Imp.var src pt
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: array-typed target and prim-typed source",
              pretty src,
              "with slice",
              pretty dest_slice
            ]
    (ArrayDestination Nothing, _) ->
      -- Nothing to do; something else set some memory somewhere.
      pure ()
    (_, AccVar {}) ->
      -- Nothing to do; accumulators are phantoms.
      pure ()

-- | Copy from here to there; both destination and source may be
-- indexed.  If so, they had better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination variable is looked up and turned into a
-- 'ValueDestination', after which the work is done by 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- asDestination <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    asDestination (ScalarVar _ _) = ScalarDestination dest
    asDestination (ArrayVar _ arr) =
      ArrayDestination $ Just $ entryArrayLocation arr
    asDestination (MemVar _ _) = MemoryDestination dest
    -- Does not matter; accumulators are phantoms.
    asDestination AccVar {} = ArrayDestination Nothing

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes: every
-- destination and source index selects a single position.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixAll dest_is) src (fixAll src_is)
  where
    fixAll = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@, binding the result to the single memory element of @pat@.
-- The pattern must have the shape @Pattern [] [mem]@; anything else
-- is reported with 'error'.
compileAlloc ::
  Mem rep =>
  Pattern rep ->
  SubExp ->
  Space ->
  ImpM rep r op ()
compileAlloc (Pattern [] [mem]) size space =
  -- 'sAlloc_' consults the per-space allocator table and falls back
  -- to a plain 'Imp.Allocate'; delegating avoids duplicating that
  -- lookup here.
  sAlloc_ (patElemName mem) (Imp.bytes $ toInt64Exp size) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression:
-- the byte size of the element type times the product of all
-- dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes (elem_bytes * elem_count)
  where
    elem_bytes = primByteSize (elemType t)
    elem_count = product (map toInt64Exp (arrayDims t))

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- * A 'DimFix' index @i@ is in bounds for a dimension of size @d@ when
--   @0 <= i < d@.
-- * A 'DimSlice' @i n s@ touches @i, i+s, ..., i+(n-1)*s@, so its
--   last element must satisfy @i + (n-1)*s < d@.
--
-- Indexes and dimensions are paired with 'zipWith', so any excess on
-- either side is ignored.  An empty index list is trivially in
-- bounds.
--
-- NOTE(review): the slice check assumes a non-negative stride @s@ —
-- confirm whether callers can produce negative strides.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds slice dims =
  case zipWith condInBounds slice dims of
    -- Previously 'foldl1', which crashes on an empty list.
    [] -> TPrimExp $ Imp.ValueExp $ BoolValue True
    c : cs -> foldl (.&&.) c cs
  where
    condInBounds (DimFix i) d =
      0 .<=. i .&&. i .<. d
    condInBounds (DimSlice i n s) d =
      -- The last element is at i + (n-1)*s.  Checking i + n*s instead
      -- would wrongly reject e.g. the full slice [0:d].
      0 .<=. i .&&. i + (n - 1) * s .<. d

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop over the loop variable @i@ with the given
-- bound, whose body is the code generated by @body@.  The loop
-- variable is registered with the integer type of @bound@; a
-- non-integer bound is reported with 'error'.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  addLoopVar i loop_int_type
  code <- collect body
  emit $ Imp.For i bound code
  where
    loop_int_type =
      case primExpType bound of
        IntType int_t -> int_t
        other ->
          error $
            "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty other

-- | Like 'sFor'', but creates a fresh loop variable from the given
-- name.  The variable is passed, typed, to the body-generating
-- function; it has the same primitive type as @bound@.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor desc bound body = do
  i <- newVName desc
  let bound_exp = untyped bound
      i_exp = TPrimExp $ Imp.var i $ primExpType bound_exp
  sFor' i bound_exp $ body i_exp

-- | Emit an 'Imp.While' loop with the given condition, whose body is
-- the code generated by the given action.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = collect body >>= emit . Imp.While cond

-- | Emit the code generated by the given action wrapped in an
-- 'Imp.Comment' carrying the given text.
sComment :: String -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = emit . Imp.Comment s =<< collect code

-- | Emit an 'Imp.If' whose branches are the code generated by the two
-- actions.  The true branch is generated before the false branch.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch =
  emit =<< Imp.If cond <$> collect tbranch <*> collect fbranch

-- | An 'sIf' with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())

-- | An 'sIf' with an empty true branch: the action's code runs only
-- when the condition is false.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond = sIf cond (pure ())

-- | Emit a single operation of type @op@, wrapped in 'Imp.Op'.
--
-- Kept point-free so the operation is not forced any earlier than
-- 'emit' itself would force it (the module enables @Strict@).
sOp :: op -> ImpM rep r op ()
sOp = emit . Imp.Op

-- | Declare a fresh memory block in the given space and return its
-- name.  The name is derived from the given string.  The block is
-- emitted as an 'Imp.DeclareMem' and registered in the variable table
-- as a 'MemVar'; no allocation is performed.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate the given number of bytes into the already-declared
-- memory block @mem@ in @space@.  If the environment's
-- 'envAllocCompilers' has an allocator registered for @space@, that
-- allocator is invoked; otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ mem size space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Nothing -> emit $ Imp.Allocate mem size space
    Just allocator -> allocator mem size

-- | Declare a fresh memory block (named after @name@) in @space@ and
-- allocate @size@ bytes for it.  Returns the name of the new block.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space

-- | Declare an array with a fresh name derived from @name@, element type
-- @bt@, the given shape and memory binding.  Returns the array's name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM rep r op VName
sArray name bt shape membind =
  newVName name >>= \arr -> dArray arr bt shape membind >> pure arr

-- | Declare an array in row-major order in the given memory block.
--
-- The index function is 'IxFun.iota' over the array's dimensions, each
-- converted to a 64-bit integer expression.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape (ArrayIn mem (IxFun.iota dims))
  where
    dims = [isInt64 (primExpFromSubExp int64 d) | d <- shapeDims shape]

-- | Like 'sAllocArray', but permute the in-memory representation of the indices as specified.
--
-- A memory block named after @name ++ "_mem"@ is allocated in @space@,
-- sized by 'typeSize' of the whole array.  The index function is an
-- 'IxFun.iota' over the dimensions rearranged by @perm@, permuted back
-- with the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM rep r op VName
sAllocArrayPerm name pt shape space perm = do
  mem <- sAlloc (name ++ "_mem") bytes space
  sArray name pt shape $ ArrayIn mem ixfun
  where
    bytes = typeSize $ Array pt shape NoUniqueness
    toI64 = isInt64 . primExpFromSubExp int64
    ixfun =
      IxFun.permute
        (IxFun.iota $ map toI64 $ rearrangeShape perm $ shapeDims shape)
        (rearrangeInverse perm)

-- | Uses linear/iota index function.
--
-- This is 'sAllocArrayPerm' with the identity permutation
-- @[0 .. rank - 1]@.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identityPerm
  where
    identityPerm = take (shapeRank shape) [0 ..]

-- | Uses linear/iota index function.
--
-- Declares a one-dimensional array of element type @pt@ whose contents
-- @vs@ are known at compile time.  The backing memory is declared with
-- 'Imp.DeclareArray' under a name obtained from 'newVNameForFun'
-- applied to @name ++ "_mem"@.  That memory is recorded as a memory
-- variable, and the array itself is declared on top of it.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name space pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem $ MemVar Nothing $ MemEntry space
  sArray name pt (Shape [intConst Int64 $ toInteger len]) $
    ArrayIn mem $
      IxFun.iota [fromIntegral len]
  where
    -- Number of elements, whether given explicitly or as a zero count.
    len = case vs of
      Imp.ArrayValues vs' -> length vs'
      Imp.ArrayZeros n -> fromIntegral n

-- | Emit a write of the scalar expression @v@ into array @arr@ at the
-- full index @is@.  The element type written is taken from @v@
-- ('primExpType').  The volatility is read from the environment.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  asks envVolatility >>= \vol ->
    emit (Imp.Write mem offset (primExpType v) space vol v)

-- | Copy @v@ into the part of @arr@ selected by @slice@ via 'copyDWIM'.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr slice v =
  -- The source side is given an empty slice, i.e. @v@ itself is not indexed.
  copyDWIM arr slice v []

-- | Generate a nest of 'sFor' loops (each with index name @nest_i@), one
-- per dimension of the shape with the first dimension outermost.  The
-- innermost body receives the loop indices in dimension order.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest shape body = go [] (shapeDims shape)
  where
    -- Indices are accumulated innermost-first, hence the final 'reverse'.
    go acc [] = body (reverse acc)
    go acc (d : ds) =
      sFor "nest_i" (toInt64Exp d) $ \i -> go (i : acc) ds

-- | Untyped assignment: emit a statement setting the scalar variable
-- @x@ to the expression @e@.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) x e = emit (Imp.SetScalar x e)

infixl 3 <~~

-- | Typed assignment: emit a statement setting the variable wrapped by the
-- 'TV' to the (untyped form of the) expression @e@.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
(<--) (TV x _) e = emit . Imp.SetScalar x $ untyped e

infixl 3 <--

-- | Constructing an ad-hoc function that does not
-- correspond to any of the IR functions in the input program.
--
-- While @m@ runs, 'envFunction' is set to @fname@ and every output and
-- input parameter is registered as a variable.  The collected code is
-- then emitted as function @fname@.  Its leading @Maybe Name@ field
-- (presumably an entry-point name; verify in ImpCode) is 'Nothing', and
-- its two external-value lists are empty.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m = local inFunction $ do
  body <- collect $ mapM_ addParam (outputs ++ inputs) >> m
  emitFunction fname $ Imp.Function Nothing outputs inputs body [] []
  where
    inFunction env = env {envFunction = Just fname}
    addParam = \case
      Imp.MemParam name space ->
        addVar name $ MemVar Nothing $ MemEntry space
      Imp.ScalarParam name bt ->
        addVar name $ ScalarVar Nothing $ ScalarEntry bt

-- | Given dimension sizes @[d_1, ..., d_n]@, return the row-major strides
-- @[d_2*...*d_n, ..., d_n, 1]@ (empty for no dimensions).
--
-- Each suffix product, including the full product @d_1*...*d_n@ (which
-- is then discarded), is bound to a fresh variable via
-- @dPrimVE "slice"@.  The final stride is the literal @1@.
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices dims = drop 1 . snd <$> go dims
  where
    go [] = pure (1, [1])
    go (n : ns) = do
      (inner, strides) <- go ns
      here <- dPrimVE "slice" $ n * inner
      pure (here, here : strides)

-- | @dIndexSpace vs_ds j@ unflattens the flat row-major index @j@ into an
-- array whose dimension sizes are the second components of @vs_ds@.
--
-- For each pair @(v, d)@, in order (first dimension outermost), the
-- variable @v@ is declared and bound to the index along that dimension.
-- Nothing is returned.  The caller reads the indices through the names
-- it supplied.  Each intermediate remainder is bound to a fresh
-- @remnant@ variable.
--
-- No bounds check is performed: this presumably assumes
-- @0 <= j < product of the dimension sizes@ (the size of the first
-- dimension is not otherwise consulted).
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  slices <- dSlices (map snd vs_ds)
  loop (zip (map fst vs_ds) slices) j
  where
    -- Peel off one dimension at a time: dividing by that dimension's
    -- stride gives its index, and the remainder is carried inward.
    loop ((v, stride) : rest) i = do
      dPrimV_ v (i `quot` stride)
      i' <- dPrimVE "remnant" $ i - Imp.vi64 v * stride
      loop rest i'
    loop _ _ = pure ()