{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLocation (..),
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    inBounds,
    isMapTransposeCopy,

    -- * Constructing code.
    dLParams,
    dFParams,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    dIndexSpace,
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    (<--),
    (<~~),
    function,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.String
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings
import Prelude hiding (quot)

-- | How to compile an t'Op'.
type OpCompiler lore r op = Pattern lore -> Op lore -> ImpM lore r op ()

-- | How to compile some 'Stms'.
type StmsCompiler lore r op = Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | How to compile an 'Exp'.
type ExpCompiler lore r op = Pattern lore -> Exp lore -> ImpM lore r op ()

type CopyCompiler lore r op =
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | An alternate way of compiling an allocation.
type AllocCompiler lore r op = VName -> Count Bytes (Imp.TExp Int64) -> ImpM lore r op ()

data Operations lore r op = Operations
  { forall lore r op. Operations lore r op -> ExpCompiler lore r op
opsExpCompiler :: ExpCompiler lore r op,
    forall lore r op. Operations lore r op -> OpCompiler lore r op
opsOpCompiler :: OpCompiler lore r op,
    forall lore r op. Operations lore r op -> StmsCompiler lore r op
opsStmsCompiler :: StmsCompiler lore r op,
    forall lore r op. Operations lore r op -> CopyCompiler lore r op
opsCopyCompiler :: CopyCompiler lore r op,
    forall lore r op.
Operations lore r op -> Map Space (AllocCompiler lore r op)
opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  }

-- | An operations set for which the expression compiler always
-- returns 'defCompileExp'.
defaultOperations ::
  (Mem lore, FreeIn op) =>
  OpCompiler lore r op ->
  Operations lore r op
defaultOperations :: forall lore op r.
(Mem lore, FreeIn op) =>
OpCompiler lore r op -> Operations lore r op
defaultOperations OpCompiler lore r op
opc =
  Operations :: forall lore r op.
ExpCompiler lore r op
-> OpCompiler lore r op
-> StmsCompiler lore r op
-> CopyCompiler lore r op
-> Map Space (AllocCompiler lore r op)
-> Operations lore r op
Operations
    { opsExpCompiler :: ExpCompiler lore r op
opsExpCompiler = ExpCompiler lore r op
forall lore r op.
Mem lore =>
Pattern lore -> Exp lore -> ImpM lore r op ()
defCompileExp,
      opsOpCompiler :: OpCompiler lore r op
opsOpCompiler = OpCompiler lore r op
opc,
      opsStmsCompiler :: StmsCompiler lore r op
opsStmsCompiler = StmsCompiler lore r op
forall lore op r.
(Mem lore, FreeIn op) =>
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms,
      opsCopyCompiler :: CopyCompiler lore r op
opsCopyCompiler = CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
defaultCopy,
      opsAllocCompilers :: Map Space (AllocCompiler lore r op)
opsAllocCompilers = Map Space (AllocCompiler lore r op)
forall a. Monoid a => a
mempty
    }

-- | When an array is dared, this is where it is stored.
data MemLocation = MemLocation
  { MemLocation -> VName
memLocationName :: VName,
    MemLocation -> [SubExp]
memLocationShape :: [Imp.DimSize],
    MemLocation -> IxFun (TExp Int64)
memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (MemLocation -> MemLocation -> Bool
(MemLocation -> MemLocation -> Bool)
-> (MemLocation -> MemLocation -> Bool) -> Eq MemLocation
forall a. (a -> a -> Bool) -> (a -> a -> Bool) -> Eq a
/= :: MemLocation -> MemLocation -> Bool
$c/= :: MemLocation -> MemLocation -> Bool
== :: MemLocation -> MemLocation -> Bool
$c== :: MemLocation -> MemLocation -> Bool
Eq, Int -> MemLocation -> ShowS
[MemLocation] -> ShowS
MemLocation -> [Char]
(Int -> MemLocation -> ShowS)
-> (MemLocation -> [Char])
-> ([MemLocation] -> ShowS)
-> Show MemLocation
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [MemLocation] -> ShowS
$cshowList :: [MemLocation] -> ShowS
show :: MemLocation -> [Char]
$cshow :: MemLocation -> [Char]
showsPrec :: Int -> MemLocation -> ShowS
$cshowsPrec :: Int -> MemLocation -> ShowS
Show)

data ArrayEntry = ArrayEntry
  { ArrayEntry -> MemLocation
entryArrayLocation :: MemLocation,
    ArrayEntry -> PrimType
entryArrayElemType :: PrimType
  }
  deriving (Int -> ArrayEntry -> ShowS
[ArrayEntry] -> ShowS
ArrayEntry -> [Char]
(Int -> ArrayEntry -> ShowS)
-> (ArrayEntry -> [Char])
-> ([ArrayEntry] -> ShowS)
-> Show ArrayEntry
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [ArrayEntry] -> ShowS
$cshowList :: [ArrayEntry] -> ShowS
show :: ArrayEntry -> [Char]
$cshow :: ArrayEntry -> [Char]
showsPrec :: Int -> ArrayEntry -> ShowS
$cshowsPrec :: Int -> ArrayEntry -> ShowS
Show)

entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape :: ArrayEntry -> [SubExp]
entryArrayShape = MemLocation -> [SubExp]
memLocationShape (MemLocation -> [SubExp])
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation

newtype MemEntry = MemEntry {MemEntry -> Space
entryMemSpace :: Imp.Space}
  deriving (Int -> MemEntry -> ShowS
[MemEntry] -> ShowS
MemEntry -> [Char]
(Int -> MemEntry -> ShowS)
-> (MemEntry -> [Char]) -> ([MemEntry] -> ShowS) -> Show MemEntry
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [MemEntry] -> ShowS
$cshowList :: [MemEntry] -> ShowS
show :: MemEntry -> [Char]
$cshow :: MemEntry -> [Char]
showsPrec :: Int -> MemEntry -> ShowS
$cshowsPrec :: Int -> MemEntry -> ShowS
Show)

newtype ScalarEntry = ScalarEntry
  { ScalarEntry -> PrimType
entryScalarType :: PrimType
  }
  deriving (Int -> ScalarEntry -> ShowS
[ScalarEntry] -> ShowS
ScalarEntry -> [Char]
(Int -> ScalarEntry -> ShowS)
-> (ScalarEntry -> [Char])
-> ([ScalarEntry] -> ShowS)
-> Show ScalarEntry
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [ScalarEntry] -> ShowS
$cshowList :: [ScalarEntry] -> ShowS
show :: ScalarEntry -> [Char]
$cshow :: ScalarEntry -> [Char]
showsPrec :: Int -> ScalarEntry -> ShowS
$cshowsPrec :: Int -> ScalarEntry -> ShowS
Show)

-- | Every non-scalar variable must be associated with an entry.
data VarEntry lore
  = ArrayVar (Maybe (Exp lore)) ArrayEntry
  | ScalarVar (Maybe (Exp lore)) ScalarEntry
  | MemVar (Maybe (Exp lore)) MemEntry
  | AccVar (Maybe (Exp lore)) (VName, Shape, [Type])
  deriving (Int -> VarEntry lore -> ShowS
[VarEntry lore] -> ShowS
VarEntry lore -> [Char]
(Int -> VarEntry lore -> ShowS)
-> (VarEntry lore -> [Char])
-> ([VarEntry lore] -> ShowS)
-> Show (VarEntry lore)
forall lore. Decorations lore => Int -> VarEntry lore -> ShowS
forall lore. Decorations lore => [VarEntry lore] -> ShowS
forall lore. Decorations lore => VarEntry lore -> [Char]
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [VarEntry lore] -> ShowS
$cshowList :: forall lore. Decorations lore => [VarEntry lore] -> ShowS
show :: VarEntry lore -> [Char]
$cshow :: forall lore. Decorations lore => VarEntry lore -> [Char]
showsPrec :: Int -> VarEntry lore -> ShowS
$cshowsPrec :: forall lore. Decorations lore => Int -> VarEntry lore -> ShowS
Show)

-- | When compiling an expression, this is a description of where the
-- result should end up.  The integer is a reference to the construct
-- that gave rise to this destination (for patterns, this will be the
-- tag of the first name in the pattern).  This can be used to make
-- the generated code easier to relate to the original code.
data Destination = Destination
  { Destination -> Maybe Int
destinationTag :: Maybe Int,
    Destination -> [ValueDestination]
valueDestinations :: [ValueDestination]
  }
  deriving (Int -> Destination -> ShowS
[Destination] -> ShowS
Destination -> [Char]
(Int -> Destination -> ShowS)
-> (Destination -> [Char])
-> ([Destination] -> ShowS)
-> Show Destination
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [Destination] -> ShowS
$cshowList :: [Destination] -> ShowS
show :: Destination -> [Char]
$cshow :: Destination -> [Char]
showsPrec :: Int -> Destination -> ShowS
$cshowsPrec :: Int -> Destination -> ShowS
Show)

data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | The 'MemLocation' is 'Just' if a copy if
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLocation)
  deriving (Int -> ValueDestination -> ShowS
[ValueDestination] -> ShowS
ValueDestination -> [Char]
(Int -> ValueDestination -> ShowS)
-> (ValueDestination -> [Char])
-> ([ValueDestination] -> ShowS)
-> Show ValueDestination
forall a.
(Int -> a -> ShowS) -> (a -> [Char]) -> ([a] -> ShowS) -> Show a
showList :: [ValueDestination] -> ShowS
$cshowList :: [ValueDestination] -> ShowS
show :: ValueDestination -> [Char]
$cshow :: ValueDestination -> [Char]
showsPrec :: Int -> ValueDestination -> ShowS
$cshowsPrec :: Int -> ValueDestination -> ShowS
Show)

data Env lore r op = Env
  { forall lore r op. Env lore r op -> ExpCompiler lore r op
envExpCompiler :: ExpCompiler lore r op,
    forall lore r op. Env lore r op -> StmsCompiler lore r op
envStmsCompiler :: StmsCompiler lore r op,
    forall lore r op. Env lore r op -> OpCompiler lore r op
envOpCompiler :: OpCompiler lore r op,
    forall lore r op. Env lore r op -> CopyCompiler lore r op
envCopyCompiler :: CopyCompiler lore r op,
    forall lore r op.
Env lore r op -> Map Space (AllocCompiler lore r op)
envAllocCompilers :: M.Map Space (AllocCompiler lore r op),
    forall lore r op. Env lore r op -> Space
envDefaultSpace :: Imp.Space,
    forall lore r op. Env lore r op -> Volatility
envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    forall lore r op. Env lore r op -> r
envEnv :: r,
    -- | Name of the function we are compiling, if any.
    forall lore r op. Env lore r op -> Maybe Name
envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    forall lore r op. Env lore r op -> Attrs
envAttrs :: Attrs
  }

newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv :: forall r lore op.
r -> Operations lore r op -> Space -> Env lore r op
newEnv r
r Operations lore r op
ops Space
ds =
  Env :: forall lore r op.
ExpCompiler lore r op
-> StmsCompiler lore r op
-> OpCompiler lore r op
-> CopyCompiler lore r op
-> Map Space (AllocCompiler lore r op)
-> Space
-> Volatility
-> r
-> Maybe Name
-> Attrs
-> Env lore r op
Env
    { envExpCompiler :: ExpCompiler lore r op
envExpCompiler = Operations lore r op -> ExpCompiler lore r op
forall lore r op. Operations lore r op -> ExpCompiler lore r op
opsExpCompiler Operations lore r op
ops,
      envStmsCompiler :: StmsCompiler lore r op
envStmsCompiler = Operations lore r op -> StmsCompiler lore r op
forall lore r op. Operations lore r op -> StmsCompiler lore r op
opsStmsCompiler Operations lore r op
ops,
      envOpCompiler :: OpCompiler lore r op
envOpCompiler = Operations lore r op -> OpCompiler lore r op
forall lore r op. Operations lore r op -> OpCompiler lore r op
opsOpCompiler Operations lore r op
ops,
      envCopyCompiler :: CopyCompiler lore r op
envCopyCompiler = Operations lore r op -> CopyCompiler lore r op
forall lore r op. Operations lore r op -> CopyCompiler lore r op
opsCopyCompiler Operations lore r op
ops,
      envAllocCompilers :: Map Space (AllocCompiler lore r op)
envAllocCompilers = Map Space (AllocCompiler lore r op)
forall a. Monoid a => a
mempty,
      envDefaultSpace :: Space
envDefaultSpace = Space
ds,
      envVolatility :: Volatility
envVolatility = Volatility
Imp.Nonvolatile,
      envEnv :: r
envEnv = r
r,
      envFunction :: Maybe Name
envFunction = Maybe Name
forall a. Maybe a
Nothing,
      envAttrs :: Attrs
envAttrs = Attrs
forall a. Monoid a => a
mempty
    }

-- | The symbol table used during compilation.
type VTable lore = M.Map VName (VarEntry lore)

data ImpState lore r op = ImpState
  { forall lore r op. ImpState lore r op -> VTable lore
stateVTable :: VTable lore,
    forall lore r op. ImpState lore r op -> Functions op
stateFunctions :: Imp.Functions op,
    forall lore r op. ImpState lore r op -> Code op
stateCode :: Imp.Code op,
    forall lore r op. ImpState lore r op -> Warnings
stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator is not in this mapping, the
    -- accumulator is scatter-like.
    forall lore r op.
ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs :: M.Map VName ([VName], Maybe (Lambda lore, [SubExp])),
    forall lore r op. ImpState lore r op -> VNameSource
stateNameSource :: VNameSource
  }

newState :: VNameSource -> ImpState lore r op
newState :: forall lore r op. VNameSource -> ImpState lore r op
newState = VTable lore
-> Functions op
-> Code op
-> Warnings
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> VNameSource
-> ImpState lore r op
forall lore r op.
VTable lore
-> Functions op
-> Code op
-> Warnings
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> VNameSource
-> ImpState lore r op
ImpState VTable lore
forall a. Monoid a => a
mempty Functions op
forall a. Monoid a => a
mempty Code op
forall a. Monoid a => a
mempty Warnings
forall a. Monoid a => a
mempty Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall a. Monoid a => a
mempty

newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( (forall a b. (a -> b) -> ImpM lore r op a -> ImpM lore r op b)
-> (forall a b. a -> ImpM lore r op b -> ImpM lore r op a)
-> Functor (ImpM lore r op)
forall a b. a -> ImpM lore r op b -> ImpM lore r op a
forall a b. (a -> b) -> ImpM lore r op a -> ImpM lore r op b
forall lore r op a b. a -> ImpM lore r op b -> ImpM lore r op a
forall lore r op a b.
(a -> b) -> ImpM lore r op a -> ImpM lore r op b
forall (f :: * -> *).
(forall a b. (a -> b) -> f a -> f b)
-> (forall a b. a -> f b -> f a) -> Functor f
<$ :: forall a b. a -> ImpM lore r op b -> ImpM lore r op a
$c<$ :: forall lore r op a b. a -> ImpM lore r op b -> ImpM lore r op a
fmap :: forall a b. (a -> b) -> ImpM lore r op a -> ImpM lore r op b
$cfmap :: forall lore r op a b.
(a -> b) -> ImpM lore r op a -> ImpM lore r op b
Functor,
      Functor (ImpM lore r op)
Functor (ImpM lore r op)
-> (forall a. a -> ImpM lore r op a)
-> (forall a b.
    ImpM lore r op (a -> b) -> ImpM lore r op a -> ImpM lore r op b)
-> (forall a b c.
    (a -> b -> c)
    -> ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op c)
-> (forall a b.
    ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b)
-> (forall a b.
    ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op a)
-> Applicative (ImpM lore r op)
forall a. a -> ImpM lore r op a
forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op a
forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
forall a b.
ImpM lore r op (a -> b) -> ImpM lore r op a -> ImpM lore r op b
forall lore r op. Functor (ImpM lore r op)
forall a b c.
(a -> b -> c)
-> ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op c
forall lore r op a. a -> ImpM lore r op a
forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op a
forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
forall lore r op a b.
ImpM lore r op (a -> b) -> ImpM lore r op a -> ImpM lore r op b
forall lore r op a b c.
(a -> b -> c)
-> ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op c
forall (f :: * -> *).
Functor f
-> (forall a. a -> f a)
-> (forall a b. f (a -> b) -> f a -> f b)
-> (forall a b c. (a -> b -> c) -> f a -> f b -> f c)
-> (forall a b. f a -> f b -> f b)
-> (forall a b. f a -> f b -> f a)
-> Applicative f
<* :: forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op a
$c<* :: forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op a
*> :: forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
$c*> :: forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
liftA2 :: forall a b c.
(a -> b -> c)
-> ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op c
$cliftA2 :: forall lore r op a b c.
(a -> b -> c)
-> ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op c
<*> :: forall a b.
ImpM lore r op (a -> b) -> ImpM lore r op a -> ImpM lore r op b
$c<*> :: forall lore r op a b.
ImpM lore r op (a -> b) -> ImpM lore r op a -> ImpM lore r op b
pure :: forall a. a -> ImpM lore r op a
$cpure :: forall lore r op a. a -> ImpM lore r op a
Applicative,
      Applicative (ImpM lore r op)
Applicative (ImpM lore r op)
-> (forall a b.
    ImpM lore r op a -> (a -> ImpM lore r op b) -> ImpM lore r op b)
-> (forall a b.
    ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b)
-> (forall a. a -> ImpM lore r op a)
-> Monad (ImpM lore r op)
forall a. a -> ImpM lore r op a
forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
forall a b.
ImpM lore r op a -> (a -> ImpM lore r op b) -> ImpM lore r op b
forall lore r op. Applicative (ImpM lore r op)
forall lore r op a. a -> ImpM lore r op a
forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
forall lore r op a b.
ImpM lore r op a -> (a -> ImpM lore r op b) -> ImpM lore r op b
forall (m :: * -> *).
Applicative m
-> (forall a b. m a -> (a -> m b) -> m b)
-> (forall a b. m a -> m b -> m b)
-> (forall a. a -> m a)
-> Monad m
return :: forall a. a -> ImpM lore r op a
$creturn :: forall lore r op a. a -> ImpM lore r op a
>> :: forall a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
$c>> :: forall lore r op a b.
ImpM lore r op a -> ImpM lore r op b -> ImpM lore r op b
>>= :: forall a b.
ImpM lore r op a -> (a -> ImpM lore r op b) -> ImpM lore r op b
$c>>= :: forall lore r op a b.
ImpM lore r op a -> (a -> ImpM lore r op b) -> ImpM lore r op b
Monad,
      MonadState (ImpState lore r op),
      MonadReader (Env lore r op)
    )

instance MonadFreshNames (ImpM lore r op) where
  getNameSource :: ImpM lore r op VNameSource
getNameSource = (ImpState lore r op -> VNameSource) -> ImpM lore r op VNameSource
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ImpState lore r op -> VNameSource
forall lore r op. ImpState lore r op -> VNameSource
stateNameSource
  putNameSource :: VNameSource -> ImpM lore r op ()
putNameSource VNameSource
src = (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateNameSource :: VNameSource
stateNameSource = VNameSource
src}

-- Cannot be an KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).
instance HasScope SOACS (ImpM lore r op) where
  askScope :: ImpM lore r op (Scope SOACS)
askScope = (ImpState lore r op -> Scope SOACS) -> ImpM lore r op (Scope SOACS)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((ImpState lore r op -> Scope SOACS)
 -> ImpM lore r op (Scope SOACS))
-> (ImpState lore r op -> Scope SOACS)
-> ImpM lore r op (Scope SOACS)
forall a b. (a -> b) -> a -> b
$ (VarEntry lore -> NameInfo SOACS)
-> Map VName (VarEntry lore) -> Scope SOACS
forall a b k. (a -> b) -> Map k a -> Map k b
M.map (Type -> NameInfo SOACS
forall lore. LetDec lore -> NameInfo lore
LetName (Type -> NameInfo SOACS)
-> (VarEntry lore -> Type) -> VarEntry lore -> NameInfo SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VarEntry lore -> Type
forall {lore}. VarEntry lore -> Type
entryType) (Map VName (VarEntry lore) -> Scope SOACS)
-> (ImpState lore r op -> Map VName (VarEntry lore))
-> ImpState lore r op
-> Scope SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpState lore r op -> Map VName (VarEntry lore)
forall lore r op. ImpState lore r op -> VTable lore
stateVTable
    where
      entryType :: VarEntry lore -> Type
entryType (MemVar Maybe (Exp lore)
_ MemEntry
memEntry) =
        Space -> Type
forall shape u. Space -> TypeBase shape u
Mem (MemEntry -> Space
entryMemSpace MemEntry
memEntry)
      entryType (ArrayVar Maybe (Exp lore)
_ ArrayEntry
arrayEntry) =
        PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array
          (ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arrayEntry)
          ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> Shape) -> [SubExp] -> Shape
forall a b. (a -> b) -> a -> b
$ ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arrayEntry)
          NoUniqueness
NoUniqueness
      entryType (ScalarVar Maybe (Exp lore)
_ ScalarEntry
scalarEntry) =
        PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ ScalarEntry -> PrimType
entryScalarType ScalarEntry
scalarEntry
      entryType (AccVar Maybe (Exp lore)
_ (VName
acc, Shape
ispace, [Type]
ts)) =
        VName -> Shape -> [Type] -> NoUniqueness -> Type
forall shape u. VName -> Shape -> [Type] -> u -> TypeBase shape u
Acc VName
acc Shape
ispace [Type]
ts NoUniqueness
NoUniqueness

runImpM ::
  ImpM lore r op a ->
  r ->
  Operations lore r op ->
  Imp.Space ->
  ImpState lore r op ->
  (a, ImpState lore r op)
runImpM :: forall lore r op a.
ImpM lore r op a
-> r
-> Operations lore r op
-> Space
-> ImpState lore r op
-> (a, ImpState lore r op)
runImpM (ImpM ReaderT (Env lore r op) (State (ImpState lore r op)) a
m) r
r Operations lore r op
ops Space
space = State (ImpState lore r op) a
-> ImpState lore r op -> (a, ImpState lore r op)
forall s a. State s a -> s -> (a, s)
runState (ReaderT (Env lore r op) (State (ImpState lore r op)) a
-> Env lore r op -> State (ImpState lore r op) a
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT ReaderT (Env lore r op) (State (ImpState lore r op)) a
m (Env lore r op -> State (ImpState lore r op) a)
-> Env lore r op -> State (ImpState lore r op) a
forall a b. (a -> b) -> a -> b
$ r -> Operations lore r op -> Space -> Env lore r op
forall r lore op.
r -> Operations lore r op -> Space -> Env lore r op
newEnv r
r Operations lore r op
ops Space
space)

subImpM_ ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (Imp.Code op')
subImpM_ :: forall r' lore op' a r op.
r'
-> Operations lore r' op'
-> ImpM lore r' op' a
-> ImpM lore r op (Code op')
subImpM_ r'
r Operations lore r' op'
ops ImpM lore r' op' a
m = (a, Code op') -> Code op'
forall a b. (a, b) -> b
snd ((a, Code op') -> Code op')
-> ImpM lore r op (a, Code op') -> ImpM lore r op (Code op')
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> r'
-> Operations lore r' op'
-> ImpM lore r' op' a
-> ImpM lore r op (a, Code op')
forall r' lore op' a r op.
r'
-> Operations lore r' op'
-> ImpM lore r' op' a
-> ImpM lore r op (a, Code op')
subImpM r'
r Operations lore r' op'
ops ImpM lore r' op' a
m

subImpM ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (a, Imp.Code op')
subImpM :: forall r' lore op' a r op.
r'
-> Operations lore r' op'
-> ImpM lore r' op' a
-> ImpM lore r op (a, Code op')
subImpM r'
r Operations lore r' op'
ops (ImpM ReaderT (Env lore r' op') (State (ImpState lore r' op')) a
m) = do
  Env lore r op
env <- ImpM lore r op (Env lore r op)
forall r (m :: * -> *). MonadReader r m => m r
ask
  ImpState lore r op
s <- ImpM lore r op (ImpState lore r op)
forall s (m :: * -> *). MonadState s m => m s
get

  let env' :: Env lore r' op'
env' =
        Env lore r op
env
          { envExpCompiler :: ExpCompiler lore r' op'
envExpCompiler = Operations lore r' op' -> ExpCompiler lore r' op'
forall lore r op. Operations lore r op -> ExpCompiler lore r op
opsExpCompiler Operations lore r' op'
ops,
            envStmsCompiler :: StmsCompiler lore r' op'
envStmsCompiler = Operations lore r' op' -> StmsCompiler lore r' op'
forall lore r op. Operations lore r op -> StmsCompiler lore r op
opsStmsCompiler Operations lore r' op'
ops,
            envCopyCompiler :: CopyCompiler lore r' op'
envCopyCompiler = Operations lore r' op' -> CopyCompiler lore r' op'
forall lore r op. Operations lore r op -> CopyCompiler lore r op
opsCopyCompiler Operations lore r' op'
ops,
            envOpCompiler :: OpCompiler lore r' op'
envOpCompiler = Operations lore r' op' -> OpCompiler lore r' op'
forall lore r op. Operations lore r op -> OpCompiler lore r op
opsOpCompiler Operations lore r' op'
ops,
            envAllocCompilers :: Map Space (AllocCompiler lore r' op')
envAllocCompilers = Operations lore r' op' -> Map Space (AllocCompiler lore r' op')
forall lore r op.
Operations lore r op -> Map Space (AllocCompiler lore r op)
opsAllocCompilers Operations lore r' op'
ops,
            envEnv :: r'
envEnv = r'
r
          }
      s' :: ImpState lore r' op'
s' =
        ImpState :: forall lore r op.
VTable lore
-> Functions op
-> Code op
-> Warnings
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> VNameSource
-> ImpState lore r op
ImpState
          { stateVTable :: VTable lore
stateVTable = ImpState lore r op -> VTable lore
forall lore r op. ImpState lore r op -> VTable lore
stateVTable ImpState lore r op
s,
            stateFunctions :: Functions op'
stateFunctions = Functions op'
forall a. Monoid a => a
mempty,
            stateCode :: Code op'
stateCode = Code op'
forall a. Monoid a => a
mempty,
            stateNameSource :: VNameSource
stateNameSource = ImpState lore r op -> VNameSource
forall lore r op. ImpState lore r op -> VNameSource
stateNameSource ImpState lore r op
s,
            stateWarnings :: Warnings
stateWarnings = Warnings
forall a. Monoid a => a
mempty,
            stateAccs :: Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs = ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall lore r op.
ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs ImpState lore r op
s
          }
      (a
x, ImpState lore r' op'
s'') = State (ImpState lore r' op') a
-> ImpState lore r' op' -> (a, ImpState lore r' op')
forall s a. State s a -> s -> (a, s)
runState (ReaderT (Env lore r' op') (State (ImpState lore r' op')) a
-> Env lore r' op' -> State (ImpState lore r' op') a
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT ReaderT (Env lore r' op') (State (ImpState lore r' op')) a
m Env lore r' op'
env') ImpState lore r' op'
s'

  VNameSource -> ImpM lore r op ()
forall (m :: * -> *). MonadFreshNames m => VNameSource -> m ()
putNameSource (VNameSource -> ImpM lore r op ())
-> VNameSource -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ImpState lore r' op' -> VNameSource
forall lore r op. ImpState lore r op -> VNameSource
stateNameSource ImpState lore r' op'
s''
  Warnings -> ImpM lore r op ()
forall lore r op. Warnings -> ImpM lore r op ()
warnings (Warnings -> ImpM lore r op ()) -> Warnings -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ImpState lore r' op' -> Warnings
forall lore r op. ImpState lore r op -> Warnings
stateWarnings ImpState lore r' op'
s''
  (a, Code op') -> ImpM lore r op (a, Code op')
forall (m :: * -> *) a. Monad m => a -> m a
return (a
x, ImpState lore r' op' -> Code op'
forall lore r op. ImpState lore r op -> Code op
stateCode ImpState lore r' op'
s'')

-- | Execute a code generation action, returning the code that was
-- emitted.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect :: forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect = (((), Code op) -> Code op)
-> ImpM lore r op ((), Code op) -> ImpM lore r op (Code op)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((), Code op) -> Code op
forall a b. (a, b) -> b
snd (ImpM lore r op ((), Code op) -> ImpM lore r op (Code op))
-> (ImpM lore r op () -> ImpM lore r op ((), Code op))
-> ImpM lore r op ()
-> ImpM lore r op (Code op)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op () -> ImpM lore r op ((), Code op)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect'

collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' :: forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' ImpM lore r op a
m = do
  Code op
prev_code <- (ImpState lore r op -> Code op) -> ImpM lore r op (Code op)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ImpState lore r op -> Code op
forall lore r op. ImpState lore r op -> Code op
stateCode
  (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateCode :: Code op
stateCode = Code op
forall a. Monoid a => a
mempty}
  a
x <- ImpM lore r op a
m
  Code op
new_code <- (ImpState lore r op -> Code op) -> ImpM lore r op (Code op)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ImpState lore r op -> Code op
forall lore r op. ImpState lore r op -> Code op
stateCode
  (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateCode :: Code op
stateCode = Code op
prev_code}
  (a, Code op) -> ImpM lore r op (a, Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return (a
x, Code op
new_code)

-- | Execute a code generation action, wrapping the generated code
-- within a 'Imp.Comment' with the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment :: forall lore r op. [Char] -> ImpM lore r op () -> ImpM lore r op ()
comment [Char]
desc ImpM lore r op ()
m = do
  Code op
code <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
m
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Code op -> Code op
forall a. [Char] -> Code a -> Code a
Imp.Comment [Char]
desc Code op
code

-- | Emit some generated imperative code.
emit :: Imp.Code op -> ImpM lore r op ()
emit :: forall op lore r. Code op -> ImpM lore r op ()
emit Code op
code = (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateCode :: Code op
stateCode = ImpState lore r op -> Code op
forall lore r op. ImpState lore r op -> Code op
stateCode ImpState lore r op
s Code op -> Code op -> Code op
forall a. Semigroup a => a -> a -> a
<> Code op
code}

warnings :: Warnings -> ImpM lore r op ()
warnings :: forall lore r op. Warnings -> ImpM lore r op ()
warnings Warnings
ws = (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateWarnings :: Warnings
stateWarnings = Warnings
ws Warnings -> Warnings -> Warnings
forall a. Semigroup a => a -> a -> a
<> ImpState lore r op -> Warnings
forall lore r op. ImpState lore r op -> Warnings
stateWarnings ImpState lore r op
s}

-- | Emit a warning about something the user should be aware of.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn :: forall loc lore r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM lore r op ()
warn loc
loc [loc]
locs [Char]
problem =
  Warnings -> ImpM lore r op ()
forall lore r op. Warnings -> ImpM lore r op ()
warnings (Warnings -> ImpM lore r op ()) -> Warnings -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ SrcLoc -> [SrcLoc] -> Doc -> Warnings
singleWarning' (loc -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf loc
loc) ((loc -> SrcLoc) -> [loc] -> [SrcLoc]
forall a b. (a -> b) -> [a] -> [b]
map loc -> SrcLoc
forall a. Located a => a -> SrcLoc
srclocOf [loc]
locs) ([Char] -> Doc
forall a. IsString a => [Char] -> a
fromString [Char]
problem)

-- | Emit a function in the generated code.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction :: forall op lore r. Name -> Function op -> ImpM lore r op ()
emitFunction Name
fname Function op
fun = do
  Imp.Functions [(Name, Function op)]
fs <- (ImpState lore r op -> Functions op)
-> ImpM lore r op (Functions op)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ImpState lore r op -> Functions op
forall lore r op. ImpState lore r op -> Functions op
stateFunctions
  (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateFunctions :: Functions op
stateFunctions = [(Name, Function op)] -> Functions op
forall a. [(Name, Function a)] -> Functions a
Imp.Functions ([(Name, Function op)] -> Functions op)
-> [(Name, Function op)] -> Functions op
forall a b. (a -> b) -> a -> b
$ (Name
fname, Function op
fun) (Name, Function op)
-> [(Name, Function op)] -> [(Name, Function op)]
forall a. a -> [a] -> [a]
: [(Name, Function op)]
fs}

-- | Check if a function of a given name exists.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction :: forall lore r op. Name -> ImpM lore r op Bool
hasFunction Name
fname = (ImpState lore r op -> Bool) -> ImpM lore r op Bool
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((ImpState lore r op -> Bool) -> ImpM lore r op Bool)
-> (ImpState lore r op -> Bool) -> ImpM lore r op Bool
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s ->
  let Imp.Functions [(Name, Function op)]
fs = ImpState lore r op -> Functions op
forall lore r op. ImpState lore r op -> Functions op
stateFunctions ImpState lore r op
s
   in Maybe (Function op) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Function op) -> Bool) -> Maybe (Function op) -> Bool
forall a b. (a -> b) -> a -> b
$ Name -> [(Name, Function op)] -> Maybe (Function op)
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup Name
fname [(Name, Function op)]
fs

constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable :: forall lore. Mem lore => Stms lore -> VTable lore
constsVTable = (Stm lore -> Map VName (VarEntry lore))
-> Seq (Stm lore) -> Map VName (VarEntry lore)
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap Stm lore -> Map VName (VarEntry lore)
forall {lore}.
(LetDec lore ~ LParamMem) =>
Stm lore -> Map VName (VarEntry lore)
stmVtable
  where
    stmVtable :: Stm lore -> Map VName (VarEntry lore)
stmVtable (Let Pattern lore
pat StmAux (ExpDec lore)
_ Exp lore
e) =
      (PatElemT LParamMem -> Map VName (VarEntry lore))
-> [PatElemT LParamMem] -> Map VName (VarEntry lore)
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (Exp lore -> PatElemT LParamMem -> Map VName (VarEntry lore)
forall {lore}.
Exp lore -> PatElemT LParamMem -> Map VName (VarEntry lore)
peVtable Exp lore
e) ([PatElemT LParamMem] -> Map VName (VarEntry lore))
-> [PatElemT LParamMem] -> Map VName (VarEntry lore)
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
PatternT LParamMem
pat
    peVtable :: Exp lore -> PatElemT LParamMem -> Map VName (VarEntry lore)
peVtable Exp lore
e (PatElem VName
name LParamMem
dec) =
      VName -> VarEntry lore -> Map VName (VarEntry lore)
forall k a. k -> a -> Map k a
M.singleton VName
name (VarEntry lore -> Map VName (VarEntry lore))
-> VarEntry lore -> Map VName (VarEntry lore)
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> LParamMem -> VarEntry lore
forall lore. Maybe (Exp lore) -> LParamMem -> VarEntry lore
memBoundToVarEntry (Exp lore -> Maybe (Exp lore)
forall a. a -> Maybe a
Just Exp lore
e) LParamMem
dec

compileProg ::
  (Mem lore, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations lore r op ->
  Imp.Space ->
  Prog lore ->
  m (Warnings, Imp.Definitions op)
compileProg :: forall lore op (m :: * -> *) r.
(Mem lore, FreeIn op, MonadFreshNames m) =>
r
-> Operations lore r op
-> Space
-> Prog lore
-> m (Warnings, Definitions op)
compileProg r
r Operations lore r op
ops Space
space (Prog Stms lore
consts [FunDef lore]
funs) =
  (VNameSource -> ((Warnings, Definitions op), VNameSource))
-> m (Warnings, Definitions op)
forall (m :: * -> *) a.
MonadFreshNames m =>
(VNameSource -> (a, VNameSource)) -> m a
modifyNameSource ((VNameSource -> ((Warnings, Definitions op), VNameSource))
 -> m (Warnings, Definitions op))
-> (VNameSource -> ((Warnings, Definitions op), VNameSource))
-> m (Warnings, Definitions op)
forall a b. (a -> b) -> a -> b
$ \VNameSource
src ->
    let ([()]
_, [ImpState lore r op]
ss) =
          [((), ImpState lore r op)] -> ([()], [ImpState lore r op])
forall a b. [(a, b)] -> ([a], [b])
unzip ([((), ImpState lore r op)] -> ([()], [ImpState lore r op]))
-> [((), ImpState lore r op)] -> ([()], [ImpState lore r op])
forall a b. (a -> b) -> a -> b
$ Strategy ((), ImpState lore r op)
-> (FunDef lore -> ((), ImpState lore r op))
-> [FunDef lore]
-> [((), ImpState lore r op)]
forall b a. Strategy b -> (a -> b) -> [a] -> [b]
parMap Strategy ((), ImpState lore r op)
forall a. Strategy a
rpar (VNameSource -> FunDef lore -> ((), ImpState lore r op)
compileFunDef' VNameSource
src) [FunDef lore]
funs
        free_in_funs :: Names
free_in_funs =
          Functions op -> Names
forall a. FreeIn a => a -> Names
freeIn (Functions op -> Names) -> Functions op -> Names
forall a b. (a -> b) -> a -> b
$ [Functions op] -> Functions op
forall a. Monoid a => [a] -> a
mconcat ([Functions op] -> Functions op) -> [Functions op] -> Functions op
forall a b. (a -> b) -> a -> b
$ (ImpState lore r op -> Functions op)
-> [ImpState lore r op] -> [Functions op]
forall a b. (a -> b) -> [a] -> [b]
map ImpState lore r op -> Functions op
forall lore r op. ImpState lore r op -> Functions op
stateFunctions [ImpState lore r op]
ss
        (Constants op
consts', ImpState lore r op
s') =
          ImpM lore r op (Constants op)
-> r
-> Operations lore r op
-> Space
-> ImpState lore r op
-> (Constants op, ImpState lore r op)
forall lore r op a.
ImpM lore r op a
-> r
-> Operations lore r op
-> Space
-> ImpState lore r op
-> (a, ImpState lore r op)
runImpM (Names -> Stms lore -> ImpM lore r op (Constants op)
forall lore r op.
Names -> Stms lore -> ImpM lore r op (Constants op)
compileConsts Names
free_in_funs Stms lore
consts) r
r Operations lore r op
ops Space
space (ImpState lore r op -> (Constants op, ImpState lore r op))
-> ImpState lore r op -> (Constants op, ImpState lore r op)
forall a b. (a -> b) -> a -> b
$
            [ImpState lore r op] -> ImpState lore r op
forall {lore} {r} {op} {lore} {r}.
[ImpState lore r op] -> ImpState lore r op
combineStates [ImpState lore r op]
ss
     in ( ( ImpState lore r op -> Warnings
forall lore r op. ImpState lore r op -> Warnings
stateWarnings ImpState lore r op
s',
            Constants op -> Functions op -> Definitions op
forall a. Constants a -> Functions a -> Definitions a
Imp.Definitions Constants op
consts' (ImpState lore r op -> Functions op
forall lore r op. ImpState lore r op -> Functions op
stateFunctions ImpState lore r op
s')
          ),
          ImpState lore r op -> VNameSource
forall lore r op. ImpState lore r op -> VNameSource
stateNameSource ImpState lore r op
s'
        )
  where
    compileFunDef' :: VNameSource -> FunDef lore -> ((), ImpState lore r op)
compileFunDef' VNameSource
src FunDef lore
fdef =
      ImpM lore r op ()
-> r
-> Operations lore r op
-> Space
-> ImpState lore r op
-> ((), ImpState lore r op)
forall lore r op a.
ImpM lore r op a
-> r
-> Operations lore r op
-> Space
-> ImpState lore r op
-> (a, ImpState lore r op)
runImpM
        (FunDef lore -> ImpM lore r op ()
forall lore r op. Mem lore => FunDef lore -> ImpM lore r op ()
compileFunDef FunDef lore
fdef)
        r
r
        Operations lore r op
ops
        Space
space
        (VNameSource -> ImpState lore Any op
forall lore r op. VNameSource -> ImpState lore r op
newState VNameSource
src) {stateVTable :: VTable lore
stateVTable = Stms lore -> VTable lore
forall lore. Mem lore => Stms lore -> VTable lore
constsVTable Stms lore
consts}

    combineStates :: [ImpState lore r op] -> ImpState lore r op
combineStates [ImpState lore r op]
ss =
      let Imp.Functions [(Name, Function op)]
funs' = [Functions op] -> Functions op
forall a. Monoid a => [a] -> a
mconcat ([Functions op] -> Functions op) -> [Functions op] -> Functions op
forall a b. (a -> b) -> a -> b
$ (ImpState lore r op -> Functions op)
-> [ImpState lore r op] -> [Functions op]
forall a b. (a -> b) -> [a] -> [b]
map ImpState lore r op -> Functions op
forall lore r op. ImpState lore r op -> Functions op
stateFunctions [ImpState lore r op]
ss
          src :: VNameSource
src = [VNameSource] -> VNameSource
forall a. Monoid a => [a] -> a
mconcat ((ImpState lore r op -> VNameSource)
-> [ImpState lore r op] -> [VNameSource]
forall a b. (a -> b) -> [a] -> [b]
map ImpState lore r op -> VNameSource
forall lore r op. ImpState lore r op -> VNameSource
stateNameSource [ImpState lore r op]
ss)
       in (VNameSource -> ImpState lore Any op
forall lore r op. VNameSource -> ImpState lore r op
newState VNameSource
src)
            { stateFunctions :: Functions op
stateFunctions =
                [(Name, Function op)] -> Functions op
forall a. [(Name, Function a)] -> Functions a
Imp.Functions ([(Name, Function op)] -> Functions op)
-> [(Name, Function op)] -> Functions op
forall a b. (a -> b) -> a -> b
$ Map Name (Function op) -> [(Name, Function op)]
forall k a. Map k a -> [(k, a)]
M.toList (Map Name (Function op) -> [(Name, Function op)])
-> Map Name (Function op) -> [(Name, Function op)]
forall a b. (a -> b) -> a -> b
$ [(Name, Function op)] -> Map Name (Function op)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList [(Name, Function op)]
funs',
              stateWarnings :: Warnings
stateWarnings =
                [Warnings] -> Warnings
forall a. Monoid a => [a] -> a
mconcat ([Warnings] -> Warnings) -> [Warnings] -> Warnings
forall a b. (a -> b) -> a -> b
$ (ImpState lore r op -> Warnings)
-> [ImpState lore r op] -> [Warnings]
forall a b. (a -> b) -> [a] -> [b]
map ImpState lore r op -> Warnings
forall lore r op. ImpState lore r op -> Warnings
stateWarnings [ImpState lore r op]
ss
            }

compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts :: forall lore r op.
Names -> Stms lore -> ImpM lore r op (Constants op)
compileConsts Names
used_consts Stms lore
stms = do
  Code op
code <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
used_consts Stms lore
stms (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ () -> ImpM lore r op ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
  Constants op -> ImpM lore r op (Constants op)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Constants op -> ImpM lore r op (Constants op))
-> Constants op -> ImpM lore r op (Constants op)
forall a b. (a -> b) -> a -> b
$ ([Param] -> Code op -> Constants op)
-> ([Param], Code op) -> Constants op
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry [Param] -> Code op -> Constants op
forall a. [Param] -> Code a -> Constants a
Imp.Constants (([Param], Code op) -> Constants op)
-> ([Param], Code op) -> Constants op
forall a b. (a -> b) -> a -> b
$ (DList Param -> [Param])
-> (DList Param, Code op) -> ([Param], Code op)
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first DList Param -> [Param]
forall a. DList a -> [a]
DL.toList ((DList Param, Code op) -> ([Param], Code op))
-> (DList Param, Code op) -> ([Param], Code op)
forall a b. (a -> b) -> a -> b
$ Code op -> (DList Param, Code op)
extract Code op
code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions.
    extract :: Code op -> (DList Param, Code op)
extract (Code op
x Imp.:>>: Code op
y) =
      Code op -> (DList Param, Code op)
extract Code op
x (DList Param, Code op)
-> (DList Param, Code op) -> (DList Param, Code op)
forall a. Semigroup a => a -> a -> a
<> Code op -> (DList Param, Code op)
extract Code op
y
    extract (Imp.DeclareMem VName
name Space
space)
      | VName
name VName -> Names -> Bool
`nameIn` Names
used_consts =
        ( Param -> DList Param
forall a. a -> DList a
DL.singleton (Param -> DList Param) -> Param -> DList Param
forall a b. (a -> b) -> a -> b
$ VName -> Space -> Param
Imp.MemParam VName
name Space
space,
          Code op
forall a. Monoid a => a
mempty
        )
    extract (Imp.DeclareScalar VName
name Volatility
_ PrimType
t)
      | VName
name VName -> Names -> Bool
`nameIn` Names
used_consts =
        ( Param -> DList Param
forall a. a -> DList a
DL.singleton (Param -> DList Param) -> Param -> DList Param
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Param
Imp.ScalarParam VName
name PrimType
t,
          Code op
forall a. Monoid a => a
mempty
        )
    extract Code op
s =
      (DList Param
forall a. Monoid a => a
mempty, Code op
s)

compileInParam ::
  Mem lore =>
  FParam lore ->
  ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam :: forall lore r op.
Mem lore =>
FParam lore -> ImpM lore r op (Either Param ArrayDecl)
compileInParam FParam lore
fparam = case Param FParamMem -> FParamMem
forall dec. Param dec -> dec
paramDec FParam lore
Param FParamMem
fparam of
  MemPrim PrimType
bt ->
    Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl))
-> Either Param ArrayDecl
-> ImpM lore r op (Either Param ArrayDecl)
forall a b. (a -> b) -> a -> b
$ Param -> Either Param ArrayDecl
forall a b. a -> Either a b
Left (Param -> Either Param ArrayDecl)
-> Param -> Either Param ArrayDecl
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Param
Imp.ScalarParam VName
name PrimType
bt
  MemMem Space
space ->
    Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl))
-> Either Param ArrayDecl
-> ImpM lore r op (Either Param ArrayDecl)
forall a b. (a -> b) -> a -> b
$ Param -> Either Param ArrayDecl
forall a b. a -> Either a b
Left (Param -> Either Param ArrayDecl)
-> Param -> Either Param ArrayDecl
forall a b. (a -> b) -> a -> b
$ VName -> Space -> Param
Imp.MemParam VName
name Space
space
  MemArray PrimType
bt Shape
shape Uniqueness
_ (ArrayIn VName
mem IxFun
ixfun) ->
    Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl)
forall (m :: * -> *) a. Monad m => a -> m a
return (Either Param ArrayDecl -> ImpM lore r op (Either Param ArrayDecl))
-> Either Param ArrayDecl
-> ImpM lore r op (Either Param ArrayDecl)
forall a b. (a -> b) -> a -> b
$
      ArrayDecl -> Either Param ArrayDecl
forall a b. b -> Either a b
Right (ArrayDecl -> Either Param ArrayDecl)
-> ArrayDecl -> Either Param ArrayDecl
forall a b. (a -> b) -> a -> b
$
        VName -> PrimType -> MemLocation -> ArrayDecl
ArrayDecl VName
name PrimType
bt (MemLocation -> ArrayDecl) -> MemLocation -> ArrayDecl
forall a b. (a -> b) -> a -> b
$
          VName -> [SubExp] -> IxFun (TExp Int64) -> MemLocation
MemLocation VName
mem (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape) (IxFun (TExp Int64) -> MemLocation)
-> IxFun (TExp Int64) -> MemLocation
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TExp Int64) -> IxFun -> IxFun (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> ExpLeaf) -> TPrimExp Int64 VName -> TExp Int64
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> ExpLeaf
Imp.ScalarVar) IxFun
ixfun
  MemAcc {} ->
    [Char] -> ImpM lore r op (Either Param ArrayDecl)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not have accumulator parameters."
  where
    name :: VName
name = Param FParamMem -> VName
forall dec. Param dec -> VName
paramName FParam lore
Param FParamMem
fparam

data ArrayDecl = ArrayDecl VName PrimType MemLocation

compileInParams ::
  Mem lore =>
  [FParam lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams :: forall lore r op.
Mem lore =>
[FParam lore]
-> [EntryPointType]
-> ImpM lore r op ([Param], [ArrayDecl], [ExternalValue])
compileInParams [FParam lore]
params [EntryPointType]
orig_epts = do
  let ([Param FParamMem]
ctx_params, [Param FParamMem]
val_params) =
        Int -> [Param FParamMem] -> ([Param FParamMem], [Param FParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param FParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam lore]
[Param FParamMem]
params Int -> Int -> Int
forall a. Num a => a -> a -> a
- [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ((EntryPointType -> Int) -> [EntryPointType] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map EntryPointType -> Int
entryPointSize [EntryPointType]
orig_epts)) [FParam lore]
[Param FParamMem]
params
  ([Param]
inparams, [ArrayDecl]
arrayds) <- [Either Param ArrayDecl] -> ([Param], [ArrayDecl])
forall a b. [Either a b] -> ([a], [b])
partitionEithers ([Either Param ArrayDecl] -> ([Param], [ArrayDecl]))
-> ImpM lore r op [Either Param ArrayDecl]
-> ImpM lore r op ([Param], [ArrayDecl])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Param FParamMem -> ImpM lore r op (Either Param ArrayDecl))
-> [Param FParamMem] -> ImpM lore r op [Either Param ArrayDecl]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Param FParamMem -> ImpM lore r op (Either Param ArrayDecl)
forall lore r op.
Mem lore =>
FParam lore -> ImpM lore r op (Either Param ArrayDecl)
compileInParam ([Param FParamMem]
ctx_params [Param FParamMem] -> [Param FParamMem] -> [Param FParamMem]
forall a. [a] -> [a] -> [a]
++ [Param FParamMem]
val_params)
  let findArray :: VName -> Maybe ArrayDecl
findArray VName
x = (ArrayDecl -> Bool) -> [ArrayDecl] -> Maybe ArrayDecl
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (VName -> ArrayDecl -> Bool
isArrayDecl VName
x) [ArrayDecl]
arrayds

      summaries :: Map VName Space
summaries = [(VName, Space)] -> Map VName Space
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, Space)] -> Map VName Space)
-> [(VName, Space)] -> Map VName Space
forall a b. (a -> b) -> a -> b
$ (Param FParamMem -> Maybe (VName, Space))
-> [Param FParamMem] -> [(VName, Space)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Param FParamMem -> Maybe (VName, Space)
forall {d} {u} {ret}.
Param (MemInfo d u ret) -> Maybe (VName, Space)
memSummary [FParam lore]
[Param FParamMem]
params
        where
          memSummary :: Param (MemInfo d u ret) -> Maybe (VName, Space)
memSummary Param (MemInfo d u ret)
param
            | MemMem Space
space <- Param (MemInfo d u ret) -> MemInfo d u ret
forall dec. Param dec -> dec
paramDec Param (MemInfo d u ret)
param =
              (VName, Space) -> Maybe (VName, Space)
forall a. a -> Maybe a
Just (Param (MemInfo d u ret) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo d u ret)
param, Space
space)
            | Bool
otherwise =
              Maybe (VName, Space)
forall a. Maybe a
Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo :: VName -> Maybe Space
findMemInfo = (VName -> Map VName Space -> Maybe Space)
-> Map VName Space -> VName -> Maybe Space
forall a b c. (a -> b -> c) -> b -> a -> c
flip VName -> Map VName Space -> Maybe Space
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Map VName Space
summaries

      mkValueDesc :: Param FParamMem -> Signedness -> Maybe ValueDesc
mkValueDesc Param FParamMem
fparam Signedness
signedness =
        case (VName -> Maybe ArrayDecl
findArray (VName -> Maybe ArrayDecl) -> VName -> Maybe ArrayDecl
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
fparam, Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
fparam) of
          (Just (ArrayDecl VName
_ PrimType
bt (MemLocation VName
mem [SubExp]
shape IxFun (TExp Int64)
_)), Type
_) -> do
            Space
memspace <- VName -> Maybe Space
findMemInfo VName
mem
            ValueDesc -> Maybe ValueDesc
forall a. a -> Maybe a
Just (ValueDesc -> Maybe ValueDesc) -> ValueDesc -> Maybe ValueDesc
forall a b. (a -> b) -> a -> b
$ VName -> Space -> PrimType -> Signedness -> [SubExp] -> ValueDesc
Imp.ArrayValue VName
mem Space
memspace PrimType
bt Signedness
signedness [SubExp]
shape
          (Maybe ArrayDecl
_, Prim PrimType
bt) ->
            ValueDesc -> Maybe ValueDesc
forall a. a -> Maybe a
Just (ValueDesc -> Maybe ValueDesc) -> ValueDesc -> Maybe ValueDesc
forall a b. (a -> b) -> a -> b
$ PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
bt Signedness
signedness (VName -> ValueDesc) -> VName -> ValueDesc
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
fparam
          (Maybe ArrayDecl, Type)
_ ->
            Maybe ValueDesc
forall a. Maybe a
Nothing

      mkExts :: [EntryPointType] -> [Param FParamMem] -> [ExternalValue]
mkExts (TypeOpaque [Char]
desc Int
n : [EntryPointType]
epts) [Param FParamMem]
fparams =
        let ([Param FParamMem]
fparams', [Param FParamMem]
rest) = Int -> [Param FParamMem] -> ([Param FParamMem], [Param FParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [Param FParamMem]
fparams
         in [Char] -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue
              [Char]
desc
              ((Param FParamMem -> Maybe ValueDesc)
-> [Param FParamMem] -> [ValueDesc]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Param FParamMem -> Signedness -> Maybe ValueDesc
`mkValueDesc` Signedness
Imp.TypeDirect) [Param FParamMem]
fparams') ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
:
            [EntryPointType] -> [Param FParamMem] -> [ExternalValue]
mkExts [EntryPointType]
epts [Param FParamMem]
rest
      mkExts (EntryPointType
TypeUnsigned : [EntryPointType]
epts) (Param FParamMem
fparam : [Param FParamMem]
fparams) =
        Maybe ExternalValue -> [ExternalValue]
forall a. Maybe a -> [a]
maybeToList (ValueDesc -> ExternalValue
Imp.TransparentValue (ValueDesc -> ExternalValue)
-> Maybe ValueDesc -> Maybe ExternalValue
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Param FParamMem -> Signedness -> Maybe ValueDesc
mkValueDesc Param FParamMem
fparam Signedness
Imp.TypeUnsigned)
          [ExternalValue] -> [ExternalValue] -> [ExternalValue]
forall a. [a] -> [a] -> [a]
++ [EntryPointType] -> [Param FParamMem] -> [ExternalValue]
mkExts [EntryPointType]
epts [Param FParamMem]
fparams
      mkExts (EntryPointType
TypeDirect : [EntryPointType]
epts) (Param FParamMem
fparam : [Param FParamMem]
fparams) =
        Maybe ExternalValue -> [ExternalValue]
forall a. Maybe a -> [a]
maybeToList (ValueDesc -> ExternalValue
Imp.TransparentValue (ValueDesc -> ExternalValue)
-> Maybe ValueDesc -> Maybe ExternalValue
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Param FParamMem -> Signedness -> Maybe ValueDesc
mkValueDesc Param FParamMem
fparam Signedness
Imp.TypeDirect)
          [ExternalValue] -> [ExternalValue] -> [ExternalValue]
forall a. [a] -> [a] -> [a]
++ [EntryPointType] -> [Param FParamMem] -> [ExternalValue]
mkExts [EntryPointType]
epts [Param FParamMem]
fparams
      mkExts [EntryPointType]
_ [Param FParamMem]
_ = []

  ([Param], [ArrayDecl], [ExternalValue])
-> ImpM lore r op ([Param], [ArrayDecl], [ExternalValue])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Param]
inparams, [ArrayDecl]
arrayds, [EntryPointType] -> [Param FParamMem] -> [ExternalValue]
mkExts [EntryPointType]
orig_epts [Param FParamMem]
val_params)
  where
    isArrayDecl :: VName -> ArrayDecl -> Bool
isArrayDecl VName
x (ArrayDecl VName
y PrimType
_ MemLocation
_) = VName
x VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
y

compileOutParams ::
  Mem lore =>
  [RetType lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: forall lore r op.
Mem lore =>
[RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
orig_rts [EntryPointType]
orig_epts = do
  (([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams, Map Int ValueDestination
ctx_dests)) <-
    WriterT
  ([Param], Map Int ValueDestination)
  (ImpM lore r op)
  ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param], Map Int ValueDestination)
   (ImpM lore r op)
   ([ExternalValue], [ValueDestination])
 -> ImpM
      lore
      r
      op
      (([ExternalValue], [ValueDestination]),
       ([Param], Map Int ValueDestination)))
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  ([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType lore]
[RetTypeMem]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
  let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
  ([ExternalValue], [Param], Destination)
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
  where
    imp :: ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      a)
-> (ImpM lore r op a
    -> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a)
-> ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift

    mkExts :: [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts (TypeOpaque [Char]
desc Int
n : [EntryPointType]
epts) [RetTypeMem]
rts = do
      let ([RetTypeMem]
rts', [RetTypeMem]
rest) = Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [RetTypeMem]
rts
      ([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
 -> ([ValueDesc], [ValueDestination]))
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (RetTypeMem
 -> Signedness
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      (ValueDesc, ValueDestination))
-> [RetTypeMem]
-> [Signedness]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam [RetTypeMem]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rest
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( [Char] -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue [Char]
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          [ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests
        )
    mkExts (EntryPointType
TypeUnsigned : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeUnsigned
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts (EntryPointType
TypeDirect : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeDirect
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts [EntryPointType]
_ [RetTypeMem]
_ = ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])

    mkParam :: RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam MemMem {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not explicitly return memory blocks."
    mkParam MemAcc {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not return accumulators."
    mkParam (MemPrim PrimType
t) Signedness
ept = do
      VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"scalar_out"
      ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
    mkParam (MemArray PrimType
t ShapeBase (Ext SubExp)
shape Uniqueness
_ MemReturn
dec) Signedness
ept = do
      Space
space <- (Env lore r op -> Space)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Space
forall lore r op. Env lore r op -> Space
envDefaultSpace
      VName
memout <- case MemReturn
dec of
        ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
          VName
memout <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_mem"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout
            )
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
        ReturnsInBlock VName
memout ExtIxFun
_ ->
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
      [SubExp]
resultshape <- (Ext SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
inspectExtSize ([Ext SubExp]
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      [SubExp])
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext SubExp) -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
shape
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( VName -> Space -> PrimType -> Signedness -> [SubExp] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [SubExp]
resultshape,
          Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
        )

    inspectExtSize :: Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
inspectExtSize (Ext Int
x) = do
      (Map Any Any
memseen, Map Int VName
arrseen) <- StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  (Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
      case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
        Maybe VName
Nothing -> do
          VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_arrsize"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int64],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out
            )
          (Map Any Any, Map Int VName)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
        Just VName
out ->
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
    inspectExtSize (Free SubExp
se) =
      SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

compileFunDef ::
  Mem lore =>
  FunDef lore ->
  ImpM lore r op ()
compileFunDef :: forall lore r op. Mem lore => FunDef lore -> ImpM lore r op ()
compileFunDef (FunDef Maybe EntryPoint
entry Attrs
_ Name
fname [RetType lore]
rettype [FParam lore]
params BodyT lore
body) =
  (Env lore r op -> Env lore r op)
-> ImpM lore r op () -> ImpM lore r op ()
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (\Env lore r op
env -> Env lore r op
env {envFunction :: Maybe Name
envFunction = Maybe Name
name_entry Maybe Name -> Maybe Name -> Maybe Name
forall (m :: * -> *) a. MonadPlus m => m a -> m a -> m a
`mplus` Name -> Maybe Name
forall a. a -> Maybe a
Just Name
fname}) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ do
    (([Param]
outparams, [Param]
inparams, [ExternalValue]
results, [ExternalValue]
args), Code op
body') <- ImpM lore r op ([Param], [Param], [ExternalValue], [ExternalValue])
-> ImpM
     lore
     r
     op
     (([Param], [Param], [ExternalValue], [ExternalValue]), Code op)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' ImpM lore r op ([Param], [Param], [ExternalValue], [ExternalValue])
compile
    Name -> Function op -> ImpM lore r op ()
forall op lore r. Name -> Function op -> ImpM lore r op ()
emitFunction Name
fname (Function op -> ImpM lore r op ())
-> Function op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe Name
-> [Param]
-> [Param]
-> Code op
-> [ExternalValue]
-> [ExternalValue]
-> Function op
forall a.
Maybe Name
-> [Param]
-> [Param]
-> Code a
-> [ExternalValue]
-> [ExternalValue]
-> FunctionT a
Imp.Function Maybe Name
name_entry [Param]
outparams [Param]
inparams Code op
body' [ExternalValue]
results [ExternalValue]
args
  where
    (Maybe Name
name_entry, [EntryPointType]
params_entry, [EntryPointType]
ret_entry) = case Maybe EntryPoint
entry of
      Maybe EntryPoint
Nothing ->
        ( Maybe Name
forall a. Maybe a
Nothing,
          Int -> EntryPointType -> [EntryPointType]
forall a. Int -> a -> [a]
replicate ([Param FParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [FParam lore]
[Param FParamMem]
params) EntryPointType
TypeDirect,
          Int -> EntryPointType -> [EntryPointType]
forall a. Int -> a -> [a]
replicate ([RetTypeMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [RetType lore]
[RetTypeMem]
rettype) EntryPointType
TypeDirect
        )
      Just (Name
x, [EntryPointType]
y, [EntryPointType]
z) -> (Name -> Maybe Name
forall a. a -> Maybe a
Just Name
x, [EntryPointType]
y, [EntryPointType]
z)
    compile :: ImpM lore r op ([Param], [Param], [ExternalValue], [ExternalValue])
compile = do
      ([Param]
inparams, [ArrayDecl]
arrayds, [ExternalValue]
args) <- [FParam lore]
-> [EntryPointType]
-> ImpM lore r op ([Param], [ArrayDecl], [ExternalValue])
forall lore r op.
Mem lore =>
[FParam lore]
-> [EntryPointType]
-> ImpM lore r op ([Param], [ArrayDecl], [ExternalValue])
compileInParams [FParam lore]
params [EntryPointType]
params_entry
      ([ExternalValue]
results, [Param]
outparams, Destination Maybe Int
_ [ValueDestination]
dests) <- [RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall lore r op.
Mem lore =>
[RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
rettype [EntryPointType]
ret_entry
      [FParam lore] -> ImpM lore r op ()
forall lore r op. Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams [FParam lore]
params
      [ArrayDecl] -> ImpM lore r op ()
forall lore r op. [ArrayDecl] -> ImpM lore r op ()
addArrays [ArrayDecl]
arrayds

      let Body BodyDec lore
_ Stms lore
stms [SubExp]
ses = BodyT lore
body
      Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms ([SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
ses) Stms lore
stms (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        [(ValueDestination, SubExp)]
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([ValueDestination] -> [SubExp] -> [(ValueDestination, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [ValueDestination]
dests [SubExp]
ses) (((ValueDestination, SubExp) -> ImpM lore r op ())
 -> ImpM lore r op ())
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(ValueDestination
d, SubExp
se) -> ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
d [] SubExp
se []

      ([Param], [Param], [ExternalValue], [ExternalValue])
-> ImpM
     lore r op ([Param], [Param], [ExternalValue], [ExternalValue])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Param]
outparams, [Param]
inparams, [ExternalValue]
results, [ExternalValue]
args)

compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody :: forall lore r op.
Mem lore =>
Pattern lore -> Body lore -> ImpM lore r op ()
compileBody Pattern lore
pat (Body BodyDec lore
_ Stms lore
bnds [SubExp]
ses) = do
  Destination Maybe Int
_ [ValueDestination]
dests <- Pattern lore -> ImpM lore r op Destination
forall lore r op.
Mem lore =>
Pattern lore -> ImpM lore r op Destination
destinationFromPattern Pattern lore
pat
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms ([SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
ses) Stms lore
bnds (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    [(ValueDestination, SubExp)]
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([ValueDestination] -> [SubExp] -> [(ValueDestination, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [ValueDestination]
dests [SubExp]
ses) (((ValueDestination, SubExp) -> ImpM lore r op ())
 -> ImpM lore r op ())
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(ValueDestination
d, SubExp
se) -> ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
d [] SubExp
se []

compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' :: forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param dec]
params (Body BodyDec lore
_ Stms lore
bnds [SubExp]
ses) =
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms ([SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
ses) Stms lore
bnds (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    [(Param dec, SubExp)]
-> ((Param dec, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param dec] -> [SubExp] -> [(Param dec, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param dec]
params [SubExp]
ses) (((Param dec, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((Param dec, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(Param dec
param, SubExp
se) -> VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
param) [] SubExp
se []

compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody :: forall dec lore r op.
Typed dec =>
[Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody [Param dec]
mergeparams (Body BodyDec lore
_ Stms lore
bnds [SubExp]
ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- operations are all scalar operations.
  [VName]
tmpnames <- (Param dec -> ImpM lore r op VName)
-> [Param dec] -> ImpM lore r op [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName ([Char] -> ImpM lore r op VName)
-> (Param dec -> [Char]) -> Param dec -> ImpM lore r op VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_tmp") ShowS -> (Param dec -> [Char]) -> Param dec -> [Char]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> [Char]
baseString (VName -> [Char]) -> (Param dec -> VName) -> Param dec -> [Char]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param dec -> VName
forall dec. Param dec -> VName
paramName) [Param dec]
mergeparams
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms ([SubExp] -> Names
forall a. FreeIn a => a -> Names
freeIn [SubExp]
ses) Stms lore
bnds (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ do
    [ImpM lore r op ()]
copy_to_merge_params <- [(Param dec, VName, SubExp)]
-> ((Param dec, VName, SubExp)
    -> ImpM lore r op (ImpM lore r op ()))
-> ImpM lore r op [ImpM lore r op ()]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([Param dec] -> [VName] -> [SubExp] -> [(Param dec, VName, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param dec]
mergeparams [VName]
tmpnames [SubExp]
ses) (((Param dec, VName, SubExp) -> ImpM lore r op (ImpM lore r op ()))
 -> ImpM lore r op [ImpM lore r op ()])
-> ((Param dec, VName, SubExp)
    -> ImpM lore r op (ImpM lore r op ()))
-> ImpM lore r op [ImpM lore r op ()]
forall a b. (a -> b) -> a -> b
$ \(Param dec
p, VName
tmp, SubExp
se) ->
      case Param dec -> Type
forall t. Typed t => t -> Type
typeOf Param dec
p of
        Prim PrimType
pt -> do
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Volatility -> PrimType -> Code op
forall a. VName -> Volatility -> PrimType -> Code a
Imp.DeclareScalar VName
tmp Volatility
Imp.Nonvolatile PrimType
pt
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
tmp (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
pt SubExp
se
          ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall (m :: * -> *) a. Monad m => a -> m a
return (ImpM lore r op () -> ImpM lore r op (ImpM lore r op ()))
-> ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall a b. (a -> b) -> a -> b
$ Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p) (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
tmp PrimType
pt
        Mem Space
space | Var VName
v <- SubExp
se -> do
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.DeclareMem VName
tmp Space
space
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
tmp VName
v Space
space
          ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall (m :: * -> *) a. Monad m => a -> m a
return (ImpM lore r op () -> ImpM lore r op (ImpM lore r op ()))
-> ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall a b. (a -> b) -> a -> b
$ Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem (Param dec -> VName
forall dec. Param dec -> VName
paramName Param dec
p) VName
tmp Space
space
        Type
_ -> ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall (m :: * -> *) a. Monad m => a -> m a
return (ImpM lore r op () -> ImpM lore r op (ImpM lore r op ()))
-> ImpM lore r op () -> ImpM lore r op (ImpM lore r op ())
forall a b. (a -> b) -> a -> b
$ () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
    [ImpM lore r op ()] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_ [ImpM lore r op ()]
copy_to_merge_params

compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms :: forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
alive_after_stms Stms lore
all_stms ImpM lore r op ()
m = do
  StmsCompiler lore r op
cb <- (Env lore r op -> StmsCompiler lore r op)
-> ImpM lore r op (StmsCompiler lore r op)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> StmsCompiler lore r op
forall lore r op. Env lore r op -> StmsCompiler lore r op
envStmsCompiler
  StmsCompiler lore r op
cb Names
alive_after_stms Stms lore
all_stms ImpM lore r op ()
m

defCompileStms ::
  (Mem lore, FreeIn op) =>
  Names ->
  Stms lore ->
  ImpM lore r op () ->
  ImpM lore r op ()
defCompileStms :: forall lore op r.
(Mem lore, FreeIn op) =>
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms Names
alive_after_stms Stms lore
all_stms ImpM lore r op ()
m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  ImpM lore r op Names -> ImpM lore r op ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ImpM lore r op Names -> ImpM lore r op ())
-> ImpM lore r op Names -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
forall a. Monoid a => a
mempty ([Stm lore] -> ImpM lore r op Names)
-> [Stm lore] -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
all_stms
  where
    compileStms' :: Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
allocs (Let Pattern lore
pat StmAux (ExpDec lore)
aux Exp lore
e : [Stm lore]
bs) = do
      Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars (Exp lore -> Maybe (Exp lore)
forall a. a -> Maybe a
Just Exp lore
e) (PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
PatternT LParamMem
pat)

      Code op
e_code <-
        Attrs -> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall lore r op a. Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs (StmAux (ExpDec lore) -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec lore)
aux) (ImpM lore r op (Code op) -> ImpM lore r op (Code op))
-> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
          ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ Pattern lore -> Exp lore -> ImpM lore r op ()
forall lore r op. Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp Pattern lore
pat Exp lore
e
      (Names
live_after, Code op
bs_code) <- ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' (ImpM lore r op Names -> ImpM lore r op (Names, Code op))
-> ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' (PatternT LParamMem -> Set (VName, Space)
patternAllocs Pattern lore
PatternT LParamMem
pat Set (VName, Space) -> Set (VName, Space) -> Set (VName, Space)
forall a. Semigroup a => a -> a -> a
<> Set (VName, Space)
allocs) [Stm lore]
bs
      let dies_here :: VName -> Bool
dies_here VName
v =
            Bool -> Bool
not (VName
v VName -> Names -> Bool
`nameIn` Names
live_after)
              Bool -> Bool -> Bool
&& VName
v VName -> Names -> Bool
`nameIn` Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code
          to_free :: Set (VName, Space)
to_free = ((VName, Space) -> Bool)
-> Set (VName, Space) -> Set (VName, Space)
forall a. (a -> Bool) -> Set a -> Set a
S.filter (VName -> Bool
dies_here (VName -> Bool)
-> ((VName, Space) -> VName) -> (VName, Space) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Space) -> VName
forall a b. (a, b) -> a
fst) Set (VName, Space)
allocs

      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
e_code
      ((VName, Space) -> ImpM lore r op ())
-> Set (VName, Space) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> ((VName, Space) -> Code op)
-> (VName, Space)
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Space -> Code op) -> (VName, Space) -> Code op
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.Free) Set (VName, Space)
to_free
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
bs_code

      Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
live_after
    compileStms' Set (VName, Space)
_ [] = do
      Code op
code <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
m
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
code
      Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
alive_after_stms

    patternAllocs :: PatternT LParamMem -> Set (VName, Space)
patternAllocs = [(VName, Space)] -> Set (VName, Space)
forall a. Ord a => [a] -> Set a
S.fromList ([(VName, Space)] -> Set (VName, Space))
-> (PatternT LParamMem -> [(VName, Space)])
-> PatternT LParamMem
-> Set (VName, Space)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT LParamMem -> Maybe (VName, Space))
-> [PatElemT LParamMem] -> [(VName, Space)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe PatElemT LParamMem -> Maybe (VName, Space)
forall {dec}. Typed dec => PatElemT dec -> Maybe (VName, Space)
isMemPatElem ([PatElemT LParamMem] -> [(VName, Space)])
-> (PatternT LParamMem -> [PatElemT LParamMem])
-> PatternT LParamMem
-> [(VName, Space)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements
    isMemPatElem :: PatElemT dec -> Maybe (VName, Space)
isMemPatElem PatElemT dec
pe = case PatElemT dec -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT dec
pe of
      Mem Space
space -> (VName, Space) -> Maybe (VName, Space)
forall a. a -> Maybe a
Just (PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe, Space
space)
      Type
_ -> Maybe (VName, Space)
forall a. Maybe a
Nothing

compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp :: forall lore r op. Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp Pattern lore
pat Exp lore
e = do
  Pattern lore -> Exp lore -> ImpM lore r op ()
ec <- (Env lore r op -> Pattern lore -> Exp lore -> ImpM lore r op ())
-> ImpM lore r op (Pattern lore -> Exp lore -> ImpM lore r op ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Pattern lore -> Exp lore -> ImpM lore r op ()
forall lore r op. Env lore r op -> ExpCompiler lore r op
envExpCompiler
  Pattern lore -> Exp lore -> ImpM lore r op ()
ec Pattern lore
pat Exp lore
e

defCompileExp ::
  (Mem lore) =>
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()
defCompileExp :: forall lore r op.
Mem lore =>
Pattern lore -> Exp lore -> ImpM lore r op ()
defCompileExp Pattern lore
pat (If SubExp
cond BodyT lore
tbranch BodyT lore
fbranch IfDec (BranchType lore)
_) =
  TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf (SubExp -> TExp Bool
forall a. ToExp a => a -> TExp Bool
toBoolExp SubExp
cond) (Pattern lore -> BodyT lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Pattern lore -> Body lore -> ImpM lore r op ()
compileBody Pattern lore
pat BodyT lore
tbranch) (Pattern lore -> BodyT lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Pattern lore -> Body lore -> ImpM lore r op ()
compileBody Pattern lore
pat BodyT lore
fbranch)
defCompileExp Pattern lore
pat (Apply Name
fname [(SubExp, Diet)]
args [RetType lore]
_ (Safety, SrcLoc, [SrcLoc])
_) = do
  Destination
dest <- Pattern lore -> ImpM lore r op Destination
forall lore r op.
Mem lore =>
Pattern lore -> ImpM lore r op Destination
destinationFromPattern Pattern lore
pat
  [VName]
targets <- Destination -> ImpM lore r op [VName]
forall lore r op. Destination -> ImpM lore r op [VName]
funcallTargets Destination
dest
  [Arg]
args' <- [Maybe Arg] -> [Arg]
forall a. [Maybe a] -> [a]
catMaybes ([Maybe Arg] -> [Arg])
-> ImpM lore r op [Maybe Arg] -> ImpM lore r op [Arg]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((SubExp, Diet) -> ImpM lore r op (Maybe Arg))
-> [(SubExp, Diet)] -> ImpM lore r op [Maybe Arg]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp, Diet) -> ImpM lore r op (Maybe Arg)
forall {m :: * -> *} {t} {b}.
(Monad m, HasScope t m) =>
(SubExp, b) -> m (Maybe Arg)
compileArg [(SubExp, Diet)]
args
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Name -> [Arg] -> Code op
forall a. [VName] -> Name -> [Arg] -> Code a
Imp.Call [VName]
targets Name
fname [Arg]
args'
  where
    compileArg :: (SubExp, b) -> m (Maybe Arg)
compileArg (SubExp
se, b
_) = do
      Type
t <- SubExp -> m Type
forall t (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
      case (SubExp
se, Type
t) of
        (SubExp
_, Prim PrimType
pt) -> Maybe Arg -> m (Maybe Arg)
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Arg -> m (Maybe Arg)) -> Maybe Arg -> m (Maybe Arg)
forall a b. (a -> b) -> a -> b
$ Arg -> Maybe Arg
forall a. a -> Maybe a
Just (Arg -> Maybe Arg) -> Arg -> Maybe Arg
forall a b. (a -> b) -> a -> b
$ Exp -> Arg
Imp.ExpArg (Exp -> Arg) -> Exp -> Arg
forall a b. (a -> b) -> a -> b
$ PrimType -> SubExp -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
pt SubExp
se
        (Var VName
v, Mem {}) -> Maybe Arg -> m (Maybe Arg)
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe Arg -> m (Maybe Arg)) -> Maybe Arg -> m (Maybe Arg)
forall a b. (a -> b) -> a -> b
$ Arg -> Maybe Arg
forall a. a -> Maybe a
Just (Arg -> Maybe Arg) -> Arg -> Maybe Arg
forall a b. (a -> b) -> a -> b
$ VName -> Arg
Imp.MemArg VName
v
        (SubExp, Type)
_ -> Maybe Arg -> m (Maybe Arg)
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe Arg
forall a. Maybe a
Nothing
defCompileExp Pattern lore
pat (BasicOp BasicOp
op) = Pattern lore -> BasicOp -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp Pattern lore
pat BasicOp
op
defCompileExp Pattern lore
pat (DoLoop [(FParam lore, SubExp)]
ctx [(FParam lore, SubExp)]
val LoopForm lore
form BodyT lore
body) = do
  Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
  Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Attr
"unroll" Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM lore r op ()
warn (SrcLoc
forall a. IsLocation a => a
noLoc :: SrcLoc) [] [Char]
"#[unroll] on loop with unknown number of iterations." -- FIXME: no location.
  [FParam lore] -> ImpM lore r op ()
forall lore r op. Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams [FParam lore]
[Param FParamMem]
mergepat
  [(Param FParamMem, SubExp)]
-> ((Param FParamMem, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(Param FParamMem, SubExp)]
merge (((Param FParamMem, SubExp) -> ImpM lore r op ())
 -> ImpM lore r op ())
-> ((Param FParamMem, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(Param FParamMem
p, SubExp
se) ->
    Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ((Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
0) (Int -> Bool) -> Int -> Bool
forall a b. (a -> b) -> a -> b
$ Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Type -> Int) -> Type -> Int
forall a b. (a -> b) -> a -> b
$ Param FParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param FParamMem
p) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
      VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param FParamMem -> VName
forall dec. Param dec -> VName
paramName Param FParamMem
p) [] SubExp
se []

  let doBody :: ImpM lore r op ()
doBody = [Param FParamMem] -> BodyT lore -> ImpM lore r op ()
forall dec lore r op.
Typed dec =>
[Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody [Param FParamMem]
mergepat BodyT lore
body

  case LoopForm lore
form of
    ForLoop VName
i IntType
_ SubExp
bound [(LParam lore, VName)]
loopvars -> do
      let setLoopParam :: (Param LParamMem, VName) -> ImpM lore r op ()
setLoopParam (Param LParamMem
p, VName
a)
            | Prim PrimType
_ <- Param LParamMem -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param LParamMem
p =
              VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
a) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int64
Imp.vi64 VName
i]
            | Bool
otherwise =
              () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

      Exp
bound' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
bound

      [LParam lore] -> ImpM lore r op ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam lore] -> ImpM lore r op ())
-> [LParam lore] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ((Param LParamMem, VName) -> Param LParamMem)
-> [(Param LParamMem, VName)] -> [Param LParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param LParamMem, VName) -> Param LParamMem
forall a b. (a, b) -> a
fst [(LParam lore, VName)]
[(Param LParamMem, VName)]
loopvars
      VName -> Exp -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
VName -> Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' VName
i Exp
bound' (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        ((Param LParamMem, VName) -> ImpM lore r op ())
-> [(Param LParamMem, VName)] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Param LParamMem, VName) -> ImpM lore r op ()
setLoopParam [(LParam lore, VName)]
[(Param LParamMem, VName)]
loopvars ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> ImpM lore r op ()
doBody
    WhileLoop VName
cond ->
      TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile (Exp -> TExp Bool
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Bool) -> Exp -> TExp Bool
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
cond PrimType
Bool) ImpM lore r op ()
doBody

  Destination Maybe Int
_ [ValueDestination]
pat_dests <- Pattern lore -> ImpM lore r op Destination
forall lore r op.
Mem lore =>
Pattern lore -> ImpM lore r op Destination
destinationFromPattern Pattern lore
pat
  [(ValueDestination, SubExp)]
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([ValueDestination] -> [SubExp] -> [(ValueDestination, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [ValueDestination]
pat_dests ([SubExp] -> [(ValueDestination, SubExp)])
-> [SubExp] -> [(ValueDestination, SubExp)]
forall a b. (a -> b) -> a -> b
$ ((Param FParamMem, SubExp) -> SubExp)
-> [(Param FParamMem, SubExp)] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> ((Param FParamMem, SubExp) -> VName)
-> (Param FParamMem, SubExp)
-> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param FParamMem -> VName
forall dec. Param dec -> VName
paramName (Param FParamMem -> VName)
-> ((Param FParamMem, SubExp) -> Param FParamMem)
-> (Param FParamMem, SubExp)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst) [(Param FParamMem, SubExp)]
merge) (((ValueDestination, SubExp) -> ImpM lore r op ())
 -> ImpM lore r op ())
-> ((ValueDestination, SubExp) -> ImpM lore r op ())
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(ValueDestination
d, SubExp
r) ->
    ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
d [] SubExp
r []
  where
    merge :: [(Param FParamMem, SubExp)]
merge = [(FParam lore, SubExp)]
[(Param FParamMem, SubExp)]
ctx [(Param FParamMem, SubExp)]
-> [(Param FParamMem, SubExp)] -> [(Param FParamMem, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(FParam lore, SubExp)]
[(Param FParamMem, SubExp)]
val
    mergepat :: [Param FParamMem]
mergepat = ((Param FParamMem, SubExp) -> Param FParamMem)
-> [(Param FParamMem, SubExp)] -> [Param FParamMem]
forall a b. (a -> b) -> [a] -> [b]
map (Param FParamMem, SubExp) -> Param FParamMem
forall a b. (a, b) -> a
fst [(Param FParamMem, SubExp)]
merge
defCompileExp Pattern lore
pat (WithAcc [(Shape, [VName], Maybe (Lambda lore, [SubExp]))]
inputs Lambda lore
lam) = do
  [LParam lore] -> ImpM lore r op ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam lore] -> ImpM lore r op ())
-> [LParam lore] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam
  [((Shape, [VName], Maybe (Lambda lore, [SubExp])),
  Param LParamMem)]
-> (((Shape, [VName], Maybe (Lambda lore, [SubExp])),
     Param LParamMem)
    -> ImpM lore r op ())
-> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([(Shape, [VName], Maybe (Lambda lore, [SubExp]))]
-> [Param LParamMem]
-> [((Shape, [VName], Maybe (Lambda lore, [SubExp])),
     Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [(Shape, [VName], Maybe (Lambda lore, [SubExp]))]
inputs ([Param LParamMem]
 -> [((Shape, [VName], Maybe (Lambda lore, [SubExp])),
      Param LParamMem)])
-> [Param LParamMem]
-> [((Shape, [VName], Maybe (Lambda lore, [SubExp])),
     Param LParamMem)]
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam) ((((Shape, [VName], Maybe (Lambda lore, [SubExp])),
   Param LParamMem)
  -> ImpM lore r op ())
 -> ImpM lore r op ())
-> (((Shape, [VName], Maybe (Lambda lore, [SubExp])),
     Param LParamMem)
    -> ImpM lore r op ())
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \((Shape
_, [VName]
arrs, Maybe (Lambda lore, [SubExp])
op), Param LParamMem
p) ->
    (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s ->
      ImpState lore r op
s {stateAccs :: Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs = VName
-> ([VName], Maybe (Lambda lore, [SubExp]))
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) ([VName]
arrs, Maybe (Lambda lore, [SubExp])
op) (Map VName ([VName], Maybe (Lambda lore, [SubExp]))
 -> Map VName ([VName], Maybe (Lambda lore, [SubExp])))
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall a b. (a -> b) -> a -> b
$ ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall lore r op.
ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs ImpState lore r op
s}
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT lore -> Stms lore) -> BodyT lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ do
    let nonacc_res :: [SubExp]
nonacc_res = Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
num_accs (BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam))
        nonacc_pat_names :: [VName]
nonacc_pat_names = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
takeLast ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nonacc_res) (PatternT LParamMem -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
PatternT LParamMem
pat)
    [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
nonacc_pat_names [SubExp]
nonacc_res) (((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
v, SubExp
se) ->
      VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
v [] SubExp
se []
  where
    num_accs :: Int
num_accs = [(Shape, [VName], Maybe (Lambda lore, [SubExp]))] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, [VName], Maybe (Lambda lore, [SubExp]))]
inputs
defCompileExp Pattern lore
pat (Op Op lore
op) = do
  PatternT LParamMem -> Op lore -> ImpM lore r op ()
opc <- (Env lore r op
 -> PatternT LParamMem -> Op lore -> ImpM lore r op ())
-> ImpM
     lore r op (PatternT LParamMem -> Op lore -> ImpM lore r op ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> PatternT LParamMem -> Op lore -> ImpM lore r op ()
forall lore r op. Env lore r op -> OpCompiler lore r op
envOpCompiler
  PatternT LParamMem -> Op lore -> ImpM lore r op ()
opc Pattern lore
PatternT LParamMem
pat Op lore
op

defCompileBasicOp ::
  Mem lore =>
  Pattern lore ->
  BasicOp ->
  ImpM lore r op ()
defCompileBasicOp :: forall lore r op.
Mem lore =>
Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (SubExp SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Opaque SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (UnOp UnOp
op SubExp
e) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp PatternT (LetDec lore)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  ErrorMsg Exp
msg' <- (SubExp -> ImpM lore r op Exp)
-> ErrorMsg SubExp -> ImpM lore r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ErrorMsg SubExp
msg
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc

  Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
  Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    (SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ())
-> (SrcLoc, [SrcLoc]) -> [Char] -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM lore r op ()
warn (SrcLoc, [SrcLoc])
loc [Char]
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Index VName
src Slice SubExp
slice)
  | Just [SubExp]
idxs <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
    VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM lore r op ())
-> [DimIndex (TExp Int64)] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex (TExp Int64))
-> [SubExp] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (SubExp -> TExp Int64) -> SubExp -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
idxs
defCompileBasicOp PatternT (LetDec lore)
_ Index {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Update VName
_ Slice SubExp
slice SubExp
se) =
  VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
forall lore r op.
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
sUpdate (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) ((DimIndex SubExp -> DimIndex (TExp Int64))
-> Slice SubExp -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TExp Int64) -> DimIndex SubExp -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice SubExp
slice) SubExp
se
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Replicate (Shape [SubExp]
ds) SubExp
se) = do
  [Exp]
ds' <- (SubExp -> ImpM lore r op Exp) -> [SubExp] -> ImpM lore r op [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [SubExp]
ds
  [VName]
is <- Int -> ImpM lore r op VName -> ImpM lore r op [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) ([Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i")
  Code op
copy_elem <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) ((VName -> DimIndex (TExp Int64))
-> [VName] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (VName -> TExp Int64) -> VName -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> TExp Int64
Imp.vi64) [VName]
is) SubExp
se []
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ((Code op -> Code op)
 -> (Code op -> Code op) -> Code op -> Code op)
-> (Code op -> Code op)
-> [Code op -> Code op]
-> Code op
-> Code op
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Code op -> Code op) -> (Code op -> Code op) -> Code op -> Code op
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) Code op -> Code op
forall a. a -> a
id ((VName -> Exp -> Code op -> Code op)
-> [VName] -> [Exp] -> [Code op -> Code op]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is [Exp]
ds') Code op
copy_elem
defCompileBasicOp PatternT (LetDec lore)
_ Scratch {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [] [PatElemT (LetDec lore)
pe]) (Iota SubExp
n SubExp
e SubExp
s IntType
it) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  Exp
s' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
s
  [Char]
-> TExp Int64
-> (TExp Int64 -> ImpM lore r op ())
-> ImpM lore r op ()
forall t lore r op.
[Char]
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor [Char]
"i" (SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
n) ((TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ())
-> (TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
    let i' :: Exp
i' = IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
it (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
i
    TV Any
x <-
      [Char] -> TExp Any -> ImpM lore r op (TV Any)
forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TV t)
dPrimV [Char]
"x" (TExp Any -> ImpM lore r op (TV Any))
-> TExp Any -> ImpM lore r op (TV Any)
forall a b. (a -> b) -> a -> b
$
        Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$
          BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
            BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) Exp
i' Exp
s'
    VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
i] (VName -> SubExp
Var (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
x)) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Copy VName
src) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Manifest [Int]
_ VName
src) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Concat Int
i VName
x [VName]
ys SubExp
_) = do
  TV Int64
offs_glb <- [Char] -> TExp Int64 -> ImpM lore r op (TV Int64)
forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TV t)
dPrimV [Char]
"tmp_offs" TExp Int64
0

  [VName] -> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
ys) ((VName -> ImpM lore r op ()) -> ImpM lore r op ())
-> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \VName
y -> do
    [SubExp]
y_dims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp])
-> ImpM lore r op Type -> ImpM lore r op [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
y
    let rows :: TExp Int64
rows = case Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
i [SubExp]
y_dims of
          [] -> [Char] -> TExp Int64
forall a. HasCallStack => [Char] -> a
error ([Char] -> TExp Int64) -> [Char] -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [Char]
"defCompileBasicOp Concat: empty array shape for " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
y
          SubExp
r : [SubExp]
_ -> SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
r
        skip_dims :: [SubExp]
skip_dims = Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
i [SubExp]
y_dims
        sliceAllDim :: d -> DimIndex d
sliceAllDim d
d = d -> d -> d -> DimIndex d
forall d. d -> d -> d -> DimIndex d
DimSlice d
0 d
d d
1
        skip_slices :: [DimIndex (TExp Int64)]
skip_slices = (SubExp -> DimIndex (TExp Int64))
-> [SubExp] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall {d}. Num d => d -> DimIndex d
sliceAllDim (TExp Int64 -> DimIndex (TExp Int64))
-> (SubExp -> TExp Int64) -> SubExp -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
skip_dims
        destslice :: [DimIndex (TExp Int64)]
destslice = [DimIndex (TExp Int64)]
skip_slices [DimIndex (TExp Int64)]
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall a. [a] -> [a] -> [a]
++ [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice (TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offs_glb) TExp Int64
rows TExp Int64
1]
    VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [DimIndex (TExp Int64)]
destslice (VName -> SubExp
Var VName
y) []
    TV Int64
offs_glb TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
offs_glb TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
rows
defCompileBasicOp (Pattern [] [PatElemT (LetDec lore)
pe]) (ArrayLit [SubExp]
es Type
_)
  | Just vs :: [PrimValue]
vs@(PrimValue
v : [PrimValue]
_) <- (SubExp -> Maybe PrimValue) -> [SubExp] -> Maybe [PrimValue]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe PrimValue
isLiteral [SubExp]
es = do
    MemLocation
dest_mem <- ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM lore r op ArrayEntry -> ImpM lore r op MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe)
    Space
dest_space <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory (MemLocation -> VName
memLocationName MemLocation
dest_mem)
    let t :: PrimType
t = PrimValue -> PrimType
primValueType PrimValue
v
    VName
static_array <- [Char] -> ImpM lore r op VName
forall lore r op. [Char] -> ImpM lore r op VName
newVNameForFun [Char]
"static_array"
    Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> PrimType -> ArrayContents -> Code op
forall a. VName -> Space -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
static_array Space
dest_space PrimType
t (ArrayContents -> Code op) -> ArrayContents -> Code op
forall a b. (a -> b) -> a -> b
$ [PrimValue] -> ArrayContents
Imp.ArrayValues [PrimValue]
vs
    let static_src :: MemLocation
static_src =
          VName -> [SubExp] -> IxFun (TExp Int64) -> MemLocation
MemLocation VName
static_array [IntType -> Integer -> SubExp
intConst IntType
Int64 (Integer -> SubExp) -> Integer -> SubExp
forall a b. (a -> b) -> a -> b
$ Int -> Integer
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Integer) -> Int -> Integer
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es] (IxFun (TExp Int64) -> MemLocation)
-> IxFun (TExp Int64) -> MemLocation
forall a b. (a -> b) -> a -> b
$
            Shape (TExp Int64) -> IxFun (TExp Int64)
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [Int -> TExp Int64
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> TExp Int64) -> Int -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es]
        entry :: VarEntry lore
entry = Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
dest_space
    VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
static_array VarEntry lore
entry
    let slice :: [DimIndex (TExp Int64)]
slice = [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice TExp Int64
0 ([SubExp] -> TExp Int64
forall i a. Num i => [a] -> i
genericLength [SubExp]
es) TExp Int64
1]
    CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
t MemLocation
dest_mem [DimIndex (TExp Int64)]
slice MemLocation
static_src [DimIndex (TExp Int64)]
slice
  | Bool
otherwise =
    [(Integer, SubExp)]
-> ((Integer, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Integer] -> [SubExp] -> [(Integer, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Integer
0 ..] [SubExp]
es) (((Integer, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((Integer, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(Integer
i, SubExp
e) ->
      VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Integer -> TExp Int64
forall a. Num a => Integer -> a
fromInteger Integer
i] SubExp
e []
  where
    isLiteral :: SubExp -> Maybe PrimValue
isLiteral (Constant PrimValue
v) = PrimValue -> Maybe PrimValue
forall a. a -> Maybe a
Just PrimValue
v
    isLiteral SubExp
_ = Maybe PrimValue
forall a. Maybe a
Nothing
defCompileBasicOp PatternT (LetDec lore)
_ Rearrange {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp PatternT (LetDec lore)
_ Rotate {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp PatternT (LetDec lore)
_ Reshape {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp PatternT (LetDec lore)
_ (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs) = [Char] -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op. [Char] -> ImpM lore r op () -> ImpM lore r op ()
sComment [Char]
"UpdateAcc" (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ do
  -- We are abusing the comment mechanism to wrap the operator in
  -- braces when we end up generating code.  This is necessary because
  -- we might otherwise end up declaring lambda parameters (if any)
  -- multiple times, as they are duplicated every time we do an
  -- UpdateAcc for the same accumulator.
  let is' :: Shape (TExp Int64)
is' = (SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
is

  -- We need to figure out whether we are updating a scatter-like
  -- accumulator or a generalised reduction.  This also binds the
  -- index parameters.
  (VName
_, Space
_, [VName]
arrs, Shape (TExp Int64)
dims, Maybe (Lambda lore)
op) <- VName
-> Shape (TExp Int64)
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall lore r op.
VName
-> Shape (TExp Int64)
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
lookupAcc VName
acc Shape (TExp Int64)
is'

  TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen ([DimIndex (TExp Int64)] -> Shape (TExp Int64) -> TExp Bool
inBounds ((TExp Int64 -> DimIndex (TExp Int64))
-> Shape (TExp Int64) -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix Shape (TExp Int64)
is') Shape (TExp Int64)
dims) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    case Maybe (Lambda lore)
op of
      Maybe (Lambda lore)
Nothing ->
        -- Scatter-like.
        [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [SubExp]
vs) (((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
forall lore r op.
VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
copyDWIMFix VName
arr Shape (TExp Int64)
is' SubExp
v []
      Just Lambda lore
lam -> do
        -- Generalised reduction.
        [LParam lore] -> ImpM lore r op ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams ([LParam lore] -> ImpM lore r op ())
-> [LParam lore] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam
        let ([VName]
x_params, [VName]
y_params) =
              Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs) ([VName] -> ([VName], [VName])) -> [VName] -> ([VName], [VName])
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> VName) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param LParamMem -> VName
forall dec. Param dec -> VName
paramName ([Param LParamMem] -> [VName]) -> [Param LParamMem] -> [VName]
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
lam

        [(VName, VName)]
-> ((VName, VName) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
x_params [VName]
arrs) (((VName, VName) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((VName, VName) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
xp, VName
arr) ->
          VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
forall lore r op.
VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
copyDWIMFix VName
xp [] (VName -> SubExp
Var VName
arr) Shape (TExp Int64)
is'

        [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
y_params [SubExp]
vs) (((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
yp, SubExp
v) ->
          VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
yp [] SubExp
v []

        Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (BodyT lore -> Stms lore
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT lore -> Stms lore) -> BodyT lore -> Stms lore
forall a b. (a -> b) -> a -> b
$ Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          [(VName, SubExp)]
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs (BodyT lore -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Lambda lore -> BodyT lore
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda lore
lam))) (((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((VName, SubExp) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
se) ->
            VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
forall lore r op.
VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
copyDWIMFix VName
arr Shape (TExp Int64)
is' SubExp
se []
defCompileBasicOp PatternT (LetDec lore)
pat BasicOp
e =
  [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    [Char]
"ImpGen.defCompileBasicOp: Invalid pattern\n  "
      [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ PatternT LParamMem -> [Char]
forall a. Pretty a => a -> [Char]
pretty PatternT (LetDec lore)
PatternT LParamMem
pat
      [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"\nfor expression\n  "
      [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ BasicOp -> [Char]
forall a. Pretty a => a -> [Char]
pretty BasicOp
e

-- | Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays :: forall lore r op. [ArrayDecl] -> ImpM lore r op ()
addArrays = (ArrayDecl -> ImpM lore r op ())
-> [ArrayDecl] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ArrayDecl -> ImpM lore r op ()
forall {lore} {r} {op}. ArrayDecl -> ImpM lore r op ()
addArray
  where
    addArray :: ArrayDecl -> ImpM lore r op ()
addArray (ArrayDecl VName
name PrimType
bt MemLocation
location) =
      VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        Maybe (Exp lore) -> ArrayEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ArrayEntry -> VarEntry lore
ArrayVar
          Maybe (Exp lore)
forall a. Maybe a
Nothing
          ArrayEntry :: MemLocation -> PrimType -> ArrayEntry
ArrayEntry
            { entryArrayLocation :: MemLocation
entryArrayLocation = MemLocation
location,
              entryArrayElemType :: PrimType
entryArrayElemType = PrimType
bt
            }

-- | Like 'dFParams', but does not create new declarations.
-- Note: a hack to be used only for functions.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams :: forall lore r op. Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams = (Param FParamMem -> ImpM lore r op ())
-> [Param FParamMem] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Param FParamMem -> ImpM lore r op ()
forall {u} {lore} {r} {op}.
Param (MemInfo SubExp u MemBind) -> ImpM lore r op ()
addFParam
  where
    addFParam :: Param (MemInfo SubExp u MemBind) -> ImpM lore r op ()
addFParam Param (MemInfo SubExp u MemBind)
fparam =
      VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar (Param (MemInfo SubExp u MemBind) -> VName
forall dec. Param dec -> VName
paramName Param (MemInfo SubExp u MemBind)
fparam) (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        Maybe (Exp lore) -> LParamMem -> VarEntry lore
forall lore. Maybe (Exp lore) -> LParamMem -> VarEntry lore
memBoundToVarEntry Maybe (Exp lore)
forall a. Maybe a
Nothing (LParamMem -> VarEntry lore) -> LParamMem -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ MemInfo SubExp u MemBind -> LParamMem
forall d u r. MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns (MemInfo SubExp u MemBind -> LParamMem)
-> MemInfo SubExp u MemBind -> LParamMem
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp u MemBind) -> MemInfo SubExp u MemBind
forall dec. Param dec -> dec
paramDec Param (MemInfo SubExp u MemBind)
fparam

-- | Another hack.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar :: forall lore r op. VName -> IntType -> ImpM lore r op ()
addLoopVar VName
i IntType
it = VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
i (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
ScalarVar Maybe (Exp lore)
forall a. Maybe a
Nothing (ScalarEntry -> VarEntry lore) -> ScalarEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarEntry
ScalarEntry (PrimType -> ScalarEntry) -> PrimType -> ScalarEntry
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it

dVars ::
  Mem lore =>
  Maybe (Exp lore) ->
  [PatElem lore] ->
  ImpM lore r op ()
dVars :: forall lore r op.
Mem lore =>
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars Maybe (Exp lore)
e = (PatElemT LParamMem -> ImpM lore r op ())
-> [PatElemT LParamMem] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ PatElemT LParamMem -> ImpM lore r op ()
dVar
  where
    dVar :: PatElemT LParamMem -> ImpM lore r op ()
dVar = Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp lore)
e (Scope lore -> ImpM lore r op ())
-> (PatElemT LParamMem -> Scope lore)
-> PatElemT LParamMem
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElemT LParamMem -> Scope lore
forall lore dec. (LetDec lore ~ dec) => PatElemT dec -> Scope lore
scopeOfPatElem

dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams :: forall lore r op. Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams = Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp lore)
forall a. Maybe a
Nothing (Scope lore -> ImpM lore r op ())
-> ([Param FParamMem] -> Scope lore)
-> [Param FParamMem]
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Param FParamMem] -> Scope lore
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams

dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams :: forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams = Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp lore)
forall a. Maybe a
Nothing (Scope lore -> ImpM lore r op ())
-> ([Param LParamMem] -> Scope lore)
-> [Param LParamMem]
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Param LParamMem] -> Scope lore
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams

dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimVol :: forall t lore r op.
[Char] -> PrimType -> TExp t -> ImpM lore r op (TV t)
dPrimVol [Char]
name PrimType
t TExp t
e = do
  VName
name' <- [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
name
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Volatility -> PrimType -> Code op
forall a. VName -> Volatility -> PrimType -> Code a
Imp.DeclareScalar VName
name' Volatility
Imp.Volatile PrimType
t
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name' (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
ScalarVar Maybe (Exp lore)
forall a. Maybe a
Nothing (ScalarEntry -> VarEntry lore) -> ScalarEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarEntry
ScalarEntry PrimType
t
  VName
name' VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
e
  TV t -> ImpM lore r op (TV t)
forall (m :: * -> *) a. Monad m => a -> m a
return (TV t -> ImpM lore r op (TV t)) -> TV t -> ImpM lore r op (TV t)
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> TV t
forall t. VName -> PrimType -> TV t
TV VName
name' PrimType
t

dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ :: forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
name PrimType
t = do
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Volatility -> PrimType -> Code op
forall a. VName -> Volatility -> PrimType -> Code a
Imp.DeclareScalar VName
name Volatility
Imp.Nonvolatile PrimType
t
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
ScalarVar Maybe (Exp lore)
forall a. Maybe a
Nothing (ScalarEntry -> VarEntry lore) -> ScalarEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarEntry
ScalarEntry PrimType
t

-- | The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM lore r op (TV t)
dPrim :: forall lore r op t. [Char] -> PrimType -> ImpM lore r op (TV t)
dPrim [Char]
name PrimType
t = do
  VName
name' <- [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
name
  VName -> PrimType -> ImpM lore r op ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
name' PrimType
t
  TV t -> ImpM lore r op (TV t)
forall (m :: * -> *) a. Monad m => a -> m a
return (TV t -> ImpM lore r op (TV t)) -> TV t -> ImpM lore r op (TV t)
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> TV t
forall t. VName -> PrimType -> TV t
TV VName
name' PrimType
t

dPrimV_ :: VName -> Imp.TExp t -> ImpM lore r op ()
dPrimV_ :: forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
name TExp t
e = do
  VName -> PrimType -> ImpM lore r op ()
forall lore r op. VName -> PrimType -> ImpM lore r op ()
dPrim_ VName
name PrimType
t
  VName -> PrimType -> TV t
forall t. VName -> PrimType -> TV t
TV VName
name PrimType
t TV t -> TExp t -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp t
e
  where
    t :: PrimType
t = Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType) -> Exp -> PrimType
forall a b. (a -> b) -> a -> b
$ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
e

dPrimV :: String -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimV :: forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TV t)
dPrimV [Char]
name TExp t
e = do
  TV t
name' <- [Char] -> PrimType -> ImpM lore r op (TV t)
forall lore r op t. [Char] -> PrimType -> ImpM lore r op (TV t)
dPrim [Char]
name (PrimType -> ImpM lore r op (TV t))
-> PrimType -> ImpM lore r op (TV t)
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType) -> Exp -> PrimType
forall a b. (a -> b) -> a -> b
$ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
e
  TV t
name' TV t -> TExp t -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp t
e
  TV t -> ImpM lore r op (TV t)
forall (m :: * -> *) a. Monad m => a -> m a
return TV t
name'

dPrimVE :: String -> Imp.TExp t -> ImpM lore r op (Imp.TExp t)
dPrimVE :: forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TExp t)
dPrimVE [Char]
name TExp t
e = do
  TV t
name' <- [Char] -> PrimType -> ImpM lore r op (TV t)
forall lore r op t. [Char] -> PrimType -> ImpM lore r op (TV t)
dPrim [Char]
name (PrimType -> ImpM lore r op (TV t))
-> PrimType -> ImpM lore r op (TV t)
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType) -> Exp -> PrimType
forall a b. (a -> b) -> a -> b
$ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
e
  TV t
name' TV t -> TExp t -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp t
e
  TExp t -> ImpM lore r op (TExp t)
forall (m :: * -> *) a. Monad m => a -> m a
return (TExp t -> ImpM lore r op (TExp t))
-> TExp t -> ImpM lore r op (TExp t)
forall a b. (a -> b) -> a -> b
$ TV t -> TExp t
forall t. TV t -> TExp t
tvExp TV t
name'

memBoundToVarEntry ::
  Maybe (Exp lore) ->
  MemBound NoUniqueness ->
  VarEntry lore
memBoundToVarEntry :: forall lore. Maybe (Exp lore) -> LParamMem -> VarEntry lore
memBoundToVarEntry Maybe (Exp lore)
e (MemPrim PrimType
bt) =
  Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
ScalarVar Maybe (Exp lore)
e ScalarEntry :: PrimType -> ScalarEntry
ScalarEntry {entryScalarType :: PrimType
entryScalarType = PrimType
bt}
memBoundToVarEntry Maybe (Exp lore)
e (MemMem Space
space) =
  Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
e (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
space
memBoundToVarEntry Maybe (Exp lore)
e (MemAcc VName
acc Shape
ispace [Type]
ts NoUniqueness
_) =
  Maybe (Exp lore) -> (VName, Shape, [Type]) -> VarEntry lore
forall lore.
Maybe (Exp lore) -> (VName, Shape, [Type]) -> VarEntry lore
AccVar Maybe (Exp lore)
e (VName
acc, Shape
ispace, [Type]
ts)
memBoundToVarEntry Maybe (Exp lore)
e (MemArray PrimType
bt Shape
shape NoUniqueness
_ (ArrayIn VName
mem IxFun
ixfun)) =
  let location :: MemLocation
location = VName -> [SubExp] -> IxFun (TExp Int64) -> MemLocation
MemLocation VName
mem (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape) (IxFun (TExp Int64) -> MemLocation)
-> IxFun (TExp Int64) -> MemLocation
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TExp Int64) -> IxFun -> IxFun (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((VName -> ExpLeaf) -> TPrimExp Int64 VName -> TExp Int64
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> ExpLeaf
Imp.ScalarVar) IxFun
ixfun
   in Maybe (Exp lore) -> ArrayEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ArrayEntry -> VarEntry lore
ArrayVar
        Maybe (Exp lore)
e
        ArrayEntry :: MemLocation -> PrimType -> ArrayEntry
ArrayEntry
          { entryArrayLocation :: MemLocation
entryArrayLocation = MemLocation
location,
            entryArrayElemType :: PrimType
entryArrayElemType = PrimType
bt
          }

infoDec ::
  Mem lore =>
  NameInfo lore ->
  MemInfo SubExp NoUniqueness MemBind
infoDec :: forall lore. Mem lore => NameInfo lore -> LParamMem
infoDec (LetName LetDec lore
dec) = LetDec lore
LParamMem
dec
infoDec (FParamName FParamInfo lore
dec) = FParamMem -> LParamMem
forall d u r. MemInfo d u r -> MemInfo d NoUniqueness r
noUniquenessReturns FParamInfo lore
FParamMem
dec
infoDec (LParamName LParamInfo lore
dec) = LParamInfo lore
LParamMem
dec
infoDec (IndexName IntType
it) = PrimType -> LParamMem
forall d u ret. PrimType -> MemInfo d u ret
MemPrim (PrimType -> LParamMem) -> PrimType -> LParamMem
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it

dInfo ::
  Mem lore =>
  Maybe (Exp lore) ->
  VName ->
  NameInfo lore ->
  ImpM lore r op ()
dInfo :: forall lore r op.
Mem lore =>
Maybe (Exp lore) -> VName -> NameInfo lore -> ImpM lore r op ()
dInfo Maybe (Exp lore)
e VName
name NameInfo lore
info = do
  let entry :: VarEntry lore
entry = Maybe (Exp lore) -> LParamMem -> VarEntry lore
forall lore. Maybe (Exp lore) -> LParamMem -> VarEntry lore
memBoundToVarEntry Maybe (Exp lore)
e (LParamMem -> VarEntry lore) -> LParamMem -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ NameInfo lore -> LParamMem
forall lore. Mem lore => NameInfo lore -> LParamMem
infoDec NameInfo lore
info
  case VarEntry lore
entry of
    MemVar Maybe (Exp lore)
_ MemEntry
entry' ->
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.DeclareMem VName
name (Space -> Code op) -> Space -> Code op
forall a b. (a -> b) -> a -> b
$ MemEntry -> Space
entryMemSpace MemEntry
entry'
    ScalarVar Maybe (Exp lore)
_ ScalarEntry
entry' ->
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Volatility -> PrimType -> Code op
forall a. VName -> Volatility -> PrimType -> Code a
Imp.DeclareScalar VName
name Volatility
Imp.Nonvolatile (PrimType -> Code op) -> PrimType -> Code op
forall a b. (a -> b) -> a -> b
$ ScalarEntry -> PrimType
entryScalarType ScalarEntry
entry'
    ArrayVar Maybe (Exp lore)
_ ArrayEntry
_ ->
      () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
    AccVar {} ->
      () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name VarEntry lore
entry

dScope ::
  Mem lore =>
  Maybe (Exp lore) ->
  Scope lore ->
  ImpM lore r op ()
dScope :: forall lore r op.
Mem lore =>
Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope Maybe (Exp lore)
e = ((VName, NameInfo lore) -> ImpM lore r op ())
-> [(VName, NameInfo lore)] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ((VName -> NameInfo lore -> ImpM lore r op ())
-> (VName, NameInfo lore) -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry ((VName -> NameInfo lore -> ImpM lore r op ())
 -> (VName, NameInfo lore) -> ImpM lore r op ())
-> (VName -> NameInfo lore -> ImpM lore r op ())
-> (VName, NameInfo lore)
-> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> VName -> NameInfo lore -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> VName -> NameInfo lore -> ImpM lore r op ()
dInfo Maybe (Exp lore)
e) ([(VName, NameInfo lore)] -> ImpM lore r op ())
-> (Map VName (NameInfo lore) -> [(VName, NameInfo lore)])
-> Map VName (NameInfo lore)
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map VName (NameInfo lore) -> [(VName, NameInfo lore)]
forall k a. Map k a -> [(k, a)]
M.toList

dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray :: forall lore r op.
VName -> PrimType -> Shape -> MemBind -> ImpM lore r op ()
dArray VName
name PrimType
bt Shape
shape MemBind
membind =
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    Maybe (Exp lore) -> LParamMem -> VarEntry lore
forall lore. Maybe (Exp lore) -> LParamMem -> VarEntry lore
memBoundToVarEntry Maybe (Exp lore)
forall a. Maybe a
Nothing (LParamMem -> VarEntry lore) -> LParamMem -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ PrimType -> Shape -> NoUniqueness -> MemBind -> LParamMem
forall d u ret.
PrimType -> ShapeBase d -> u -> ret -> MemInfo d u ret
MemArray PrimType
bt Shape
shape NoUniqueness
NoUniqueness MemBind
membind

everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile :: forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile = (Env lore r op -> Env lore r op)
-> ImpM lore r op a -> ImpM lore r op a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local ((Env lore r op -> Env lore r op)
 -> ImpM lore r op a -> ImpM lore r op a)
-> (Env lore r op -> Env lore r op)
-> ImpM lore r op a
-> ImpM lore r op a
forall a b. (a -> b) -> a -> b
$ \Env lore r op
env -> Env lore r op
env {envVolatility :: Volatility
envVolatility = Volatility
Imp.Volatile}

-- | Remove the array targets.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets :: forall lore r op. Destination -> ImpM lore r op [VName]
funcallTargets (Destination Maybe Int
_ [ValueDestination]
dests) =
  [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[VName]] -> [VName])
-> ImpM lore r op [[VName]] -> ImpM lore r op [VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (ValueDestination -> ImpM lore r op [VName])
-> [ValueDestination] -> ImpM lore r op [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ValueDestination -> ImpM lore r op [VName]
forall {m :: * -> *}. Monad m => ValueDestination -> m [VName]
funcallTarget [ValueDestination]
dests
  where
    funcallTarget :: ValueDestination -> m [VName]
funcallTarget (ScalarDestination VName
name) =
      [VName] -> m [VName]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName
name]
    funcallTarget (ArrayDestination Maybe MemLocation
_) =
      [VName] -> m [VName]
forall (m :: * -> *) a. Monad m => a -> m a
return []
    funcallTarget (MemoryDestination VName
name) =
      [VName] -> m [VName]
forall (m :: * -> *) a. Monad m => a -> m a
return [VName
name]

-- | A typed variable, which we can turn into a typed expression, or
-- use as the target for an assignment.  This is used to aid in type
-- safety when doing code generation, by keeping the types straight.
-- It is still easy to cheat when you need to.
data TV t = TV VName PrimType

-- | Create a typed variable from a name and a dynamic type.  Note
-- that there is no guarantee that the dynamic type corresponds to the
-- inferred static type, but the latter will at least have to be used
-- consistently.
mkTV :: VName -> PrimType -> TV t
mkTV :: forall t. VName -> PrimType -> TV t
mkTV = VName -> PrimType -> TV t
forall t. VName -> PrimType -> TV t
TV

-- | Convert a typed variable to a size (a SubExp).
tvSize :: TV t -> Imp.DimSize
tvSize :: forall t. TV t -> SubExp
tvSize = VName -> SubExp
Var (VName -> SubExp) -> (TV t -> VName) -> TV t -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TV t -> VName
forall t. TV t -> VName
tvVar

-- | Convert a typed variable to a similarly typed expression.
tvExp :: TV t -> Imp.TExp t
tvExp :: forall t. TV t -> TExp t
tvExp (TV VName
v PrimType
t) = Exp -> TPrimExp t ExpLeaf
forall t v. PrimExp v -> TPrimExp t v
Imp.TPrimExp (Exp -> TPrimExp t ExpLeaf) -> Exp -> TPrimExp t ExpLeaf
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
v PrimType
t

-- | Extract the underlying variable name from a typed variable.
tvVar :: TV t -> VName
tvVar :: forall t. TV t -> VName
tvVar (TV VName
v PrimType
_) = VName
v

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (must must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM lore r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp = Exp -> TExp Int64
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Int64) -> (a -> Exp) -> a -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> a -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int64

  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp = Exp -> TExp Bool
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Bool) -> (a -> Exp) -> a -> TExp Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> a -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
Bool

instance ToExp SubExp where
  toExp :: forall lore r op. SubExp -> ImpM lore r op Exp
toExp (Constant PrimValue
v) =
    Exp -> ImpM lore r op Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> ImpM lore r op Exp) -> Exp -> ImpM lore r op Exp
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
  toExp (Var VName
v) =
    VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
v ImpM lore r op (VarEntry lore)
-> (VarEntry lore -> ImpM lore r op Exp) -> ImpM lore r op Exp
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= \case
      ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
pt) ->
        Exp -> ImpM lore r op Exp
forall (m :: * -> *) a. Monad m => a -> m a
return (Exp -> ImpM lore r op Exp) -> Exp -> ImpM lore r op Exp
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
v PrimType
pt
      VarEntry lore
_ -> [Char] -> ImpM lore r op Exp
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op Exp) -> [Char] -> ImpM lore r op Exp
forall a b. (a -> b) -> a -> b
$ [Char]
"toExp SubExp: SubExp is not a primitive type: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
v

  toExp' :: PrimType -> SubExp -> Exp
toExp' PrimType
_ (Constant PrimValue
v) = PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
  toExp' PrimType
t (Var VName
v) = VName -> PrimType -> Exp
Imp.var VName
v PrimType
t

instance ToExp (PrimExp VName) where
  toExp :: forall lore r op. PrimExp VName -> ImpM lore r op Exp
toExp = Exp -> ImpM lore r op Exp
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp -> ImpM lore r op Exp)
-> (PrimExp VName -> Exp) -> PrimExp VName -> ImpM lore r op Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> ExpLeaf) -> PrimExp VName -> Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> ExpLeaf
Imp.ScalarVar
  toExp' :: PrimType -> PrimExp VName -> Exp
toExp' PrimType
_ = (VName -> ExpLeaf) -> PrimExp VName -> Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap VName -> ExpLeaf
Imp.ScalarVar

addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar :: forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name VarEntry lore
entry =
  (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateVTable :: VTable lore
stateVTable = VName -> VarEntry lore -> VTable lore -> VTable lore
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VName
name VarEntry lore
entry (VTable lore -> VTable lore) -> VTable lore -> VTable lore
forall a b. (a -> b) -> a -> b
$ ImpState lore r op -> VTable lore
forall lore r op. ImpState lore r op -> VTable lore
stateVTable ImpState lore r op
s}

localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace :: forall lore r op a. Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace Space
space = (Env lore r op -> Env lore r op)
-> ImpM lore r op a -> ImpM lore r op a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local (\Env lore r op
env -> Env lore r op
env {envDefaultSpace :: Space
envDefaultSpace = Space
space})

askFunction :: ImpM lore r op (Maybe Name)
askFunction :: forall lore r op. ImpM lore r op (Maybe Name)
askFunction = (Env lore r op -> Maybe Name) -> ImpM lore r op (Maybe Name)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Maybe Name
forall lore r op. Env lore r op -> Maybe Name
envFunction

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun :: forall lore r op. [Char] -> ImpM lore r op VName
newVNameForFun [Char]
s = do
  Maybe [Char]
fname <- (Name -> [Char]) -> Maybe Name -> Maybe [Char]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Name -> [Char]
nameToString (Maybe Name -> Maybe [Char])
-> ImpM lore r op (Maybe Name) -> ImpM lore r op (Maybe [Char])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM lore r op (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
  [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName ([Char] -> ImpM lore r op VName) -> [Char] -> ImpM lore r op VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ShowS -> Maybe [Char] -> [Char]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [Char]
"" ([Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
".") Maybe [Char]
fname [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
s

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.
nameForFun :: String -> ImpM lore r op Name
nameForFun :: forall lore r op. [Char] -> ImpM lore r op Name
nameForFun [Char]
s = do
  Maybe Name
fname <- ImpM lore r op (Maybe Name)
forall lore r op. ImpM lore r op (Maybe Name)
askFunction
  Name -> ImpM lore r op Name
forall (m :: * -> *) a. Monad m => a -> m a
return (Name -> ImpM lore r op Name) -> Name -> ImpM lore r op Name
forall a b. (a -> b) -> a -> b
$ Name -> (Name -> Name) -> Maybe Name -> Name
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Name
"" (Name -> Name -> Name
forall a. Semigroup a => a -> a -> a
<> Name
".") Maybe Name
fname Name -> Name -> Name
forall a. Semigroup a => a -> a -> a
<> [Char] -> Name
nameFromString [Char]
s

askEnv :: ImpM lore r op r
askEnv :: forall lore r op. ImpM lore r op r
askEnv = (Env lore r op -> r) -> ImpM lore r op r
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> r
forall lore r op. Env lore r op -> r
envEnv

localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv :: forall r lore op a.
(r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv r -> r
f = (Env lore r op -> Env lore r op)
-> ImpM lore r op a -> ImpM lore r op a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local ((Env lore r op -> Env lore r op)
 -> ImpM lore r op a -> ImpM lore r op a)
-> (Env lore r op -> Env lore r op)
-> ImpM lore r op a
-> ImpM lore r op a
forall a b. (a -> b) -> a -> b
$ \Env lore r op
env -> Env lore r op
env {envEnv :: r
envEnv = r -> r
f (r -> r) -> r -> r
forall a b. (a -> b) -> a -> b
$ Env lore r op -> r
forall lore r op. Env lore r op -> r
envEnv Env lore r op
env}

-- | The active attributes, including those for the statement
-- currently being compiled.
askAttrs :: ImpM lore r op Attrs
askAttrs :: forall lore r op. ImpM lore r op Attrs
askAttrs = (Env lore r op -> Attrs) -> ImpM lore r op Attrs
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Attrs
forall lore r op. Env lore r op -> Attrs
envAttrs

-- | Add more attributes to what is returning by 'askAttrs'.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs :: forall lore r op a. Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs Attrs
attrs = (Env lore r op -> Env lore r op)
-> ImpM lore r op a -> ImpM lore r op a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local ((Env lore r op -> Env lore r op)
 -> ImpM lore r op a -> ImpM lore r op a)
-> (Env lore r op -> Env lore r op)
-> ImpM lore r op a
-> ImpM lore r op a
forall a b. (a -> b) -> a -> b
$ \Env lore r op
env -> Env lore r op
env {envAttrs :: Attrs
envAttrs = Attrs
attrs Attrs -> Attrs -> Attrs
forall a. Semigroup a => a -> a -> a
<> Env lore r op -> Attrs
forall lore r op. Env lore r op -> Attrs
envAttrs Env lore r op
env}

localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps :: forall lore r op a.
Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps Operations lore r op
ops = (Env lore r op -> Env lore r op)
-> ImpM lore r op a -> ImpM lore r op a
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local ((Env lore r op -> Env lore r op)
 -> ImpM lore r op a -> ImpM lore r op a)
-> (Env lore r op -> Env lore r op)
-> ImpM lore r op a
-> ImpM lore r op a
forall a b. (a -> b) -> a -> b
$ \Env lore r op
env ->
  Env lore r op
env
    { envExpCompiler :: ExpCompiler lore r op
envExpCompiler = Operations lore r op -> ExpCompiler lore r op
forall lore r op. Operations lore r op -> ExpCompiler lore r op
opsExpCompiler Operations lore r op
ops,
      envStmsCompiler :: StmsCompiler lore r op
envStmsCompiler = Operations lore r op -> StmsCompiler lore r op
forall lore r op. Operations lore r op -> StmsCompiler lore r op
opsStmsCompiler Operations lore r op
ops,
      envCopyCompiler :: CopyCompiler lore r op
envCopyCompiler = Operations lore r op -> CopyCompiler lore r op
forall lore r op. Operations lore r op -> CopyCompiler lore r op
opsCopyCompiler Operations lore r op
ops,
      envOpCompiler :: OpCompiler lore r op
envOpCompiler = Operations lore r op -> OpCompiler lore r op
forall lore r op. Operations lore r op -> OpCompiler lore r op
opsOpCompiler Operations lore r op
ops,
      envAllocCompilers :: Map Space (AllocCompiler lore r op)
envAllocCompilers = Operations lore r op -> Map Space (AllocCompiler lore r op)
forall lore r op.
Operations lore r op -> Map Space (AllocCompiler lore r op)
opsAllocCompilers Operations lore r op
ops
    }

-- | Get the current symbol table.
getVTable :: ImpM lore r op (VTable lore)
getVTable :: forall lore r op. ImpM lore r op (VTable lore)
getVTable = (ImpState lore r op -> VTable lore) -> ImpM lore r op (VTable lore)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ImpState lore r op -> VTable lore
forall lore r op. ImpState lore r op -> VTable lore
stateVTable

putVTable :: VTable lore -> ImpM lore r op ()
putVTable :: forall lore r op. VTable lore -> ImpM lore r op ()
putVTable VTable lore
vtable = (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ())
-> (ImpState lore r op -> ImpState lore r op) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \ImpState lore r op
s -> ImpState lore r op
s {stateVTable :: VTable lore
stateVTable = VTable lore
vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable :: forall lore r op a.
(VTable lore -> VTable lore)
-> ImpM lore r op a -> ImpM lore r op a
localVTable VTable lore -> VTable lore
f ImpM lore r op a
m = do
  VTable lore
old_vtable <- ImpM lore r op (VTable lore)
forall lore r op. ImpM lore r op (VTable lore)
getVTable
  VTable lore -> ImpM lore r op ()
forall lore r op. VTable lore -> ImpM lore r op ()
putVTable (VTable lore -> ImpM lore r op ())
-> VTable lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VTable lore -> VTable lore
f VTable lore
old_vtable
  a
a <- ImpM lore r op a
m
  VTable lore -> ImpM lore r op ()
forall lore r op. VTable lore -> ImpM lore r op ()
putVTable VTable lore
old_vtable
  a -> ImpM lore r op a
forall (m :: * -> *) a. Monad m => a -> m a
return a
a

lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar :: forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name = do
  Maybe (VarEntry lore)
res <- (ImpState lore r op -> Maybe (VarEntry lore))
-> ImpM lore r op (Maybe (VarEntry lore))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((ImpState lore r op -> Maybe (VarEntry lore))
 -> ImpM lore r op (Maybe (VarEntry lore)))
-> (ImpState lore r op -> Maybe (VarEntry lore))
-> ImpM lore r op (Maybe (VarEntry lore))
forall a b. (a -> b) -> a -> b
$ VName -> Map VName (VarEntry lore) -> Maybe (VarEntry lore)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
name (Map VName (VarEntry lore) -> Maybe (VarEntry lore))
-> (ImpState lore r op -> Map VName (VarEntry lore))
-> ImpState lore r op
-> Maybe (VarEntry lore)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpState lore r op -> Map VName (VarEntry lore)
forall lore r op. ImpState lore r op -> VTable lore
stateVTable
  case Maybe (VarEntry lore)
res of
    Just VarEntry lore
entry -> VarEntry lore -> ImpM lore r op (VarEntry lore)
forall (m :: * -> *) a. Monad m => a -> m a
return VarEntry lore
entry
    Maybe (VarEntry lore)
_ -> [Char] -> ImpM lore r op (VarEntry lore)
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op (VarEntry lore))
-> [Char] -> ImpM lore r op (VarEntry lore)
forall a b. (a -> b) -> a -> b
$ [Char]
"Unknown variable: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name

lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray :: forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
name = do
  VarEntry lore
res <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name
  case VarEntry lore
res of
    ArrayVar Maybe (Exp lore)
_ ArrayEntry
entry -> ArrayEntry -> ImpM lore r op ArrayEntry
forall (m :: * -> *) a. Monad m => a -> m a
return ArrayEntry
entry
    VarEntry lore
_ -> [Char] -> ImpM lore r op ArrayEntry
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ArrayEntry)
-> [Char] -> ImpM lore r op ArrayEntry
forall a b. (a -> b) -> a -> b
$ [Char]
"ImpGen.lookupArray: not an array: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name

lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory :: forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
name = do
  VarEntry lore
res <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name
  case VarEntry lore
res of
    MemVar Maybe (Exp lore)
_ MemEntry
entry -> MemEntry -> ImpM lore r op MemEntry
forall (m :: * -> *) a. Monad m => a -> m a
return MemEntry
entry
    VarEntry lore
_ -> [Char] -> ImpM lore r op MemEntry
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op MemEntry)
-> [Char] -> ImpM lore r op MemEntry
forall a b. (a -> b) -> a -> b
$ [Char]
"Unknown memory block: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name

lookupArraySpace :: VName -> ImpM lore r op Space
lookupArraySpace :: forall lore r op. VName -> ImpM lore r op Space
lookupArraySpace =
  (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap MemEntry -> Space
entryMemSpace (ImpM lore r op MemEntry -> ImpM lore r op Space)
-> (VName -> ImpM lore r op MemEntry)
-> VName
-> ImpM lore r op Space
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory
    (VName -> ImpM lore r op Space)
-> (VName -> ImpM lore r op VName) -> VName -> ImpM lore r op Space
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (ArrayEntry -> VName)
-> ImpM lore r op ArrayEntry -> ImpM lore r op VName
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (MemLocation -> VName
memLocationName (MemLocation -> VName)
-> (ArrayEntry -> MemLocation) -> ArrayEntry -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ArrayEntry -> MemLocation
entryArrayLocation) (ImpM lore r op ArrayEntry -> ImpM lore r op VName)
-> (VName -> ImpM lore r op ArrayEntry)
-> VName
-> ImpM lore r op VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray

-- | In the case of a histogram-like accumulator, also sets the index
-- parameters.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda lore))
lookupAcc :: forall lore r op.
VName
-> Shape (TExp Int64)
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
lookupAcc VName
name Shape (TExp Int64)
is = do
  VarEntry lore
res <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name
  case VarEntry lore
res of
    AccVar Maybe (Exp lore)
_ (VName
acc, Shape
ispace, [Type]
_) -> do
      Maybe ([VName], Maybe (Lambda lore, [SubExp]))
acc' <- (ImpState lore r op
 -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp])))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((ImpState lore r op
  -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
 -> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp]))))
-> (ImpState lore r op
    -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp])))
forall a b. (a -> b) -> a -> b
$ VName
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> Maybe ([VName], Maybe (Lambda lore, [SubExp]))
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
acc (Map VName ([VName], Maybe (Lambda lore, [SubExp]))
 -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> (ImpState lore r op
    -> Map VName ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpState lore r op
-> Maybe ([VName], Maybe (Lambda lore, [SubExp]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall lore r op.
ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs
      case Maybe ([VName], Maybe (Lambda lore, [SubExp]))
acc' of
        Just ([], Maybe (Lambda lore, [SubExp])
_) ->
          [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"Accumulator with no arrays: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name
        Just (arrs :: [VName]
arrs@(VName
arr : [VName]
_), Just (Lambda lore
op, [SubExp]
_)) -> do
          Space
space <- VName -> ImpM lore r op Space
forall lore r op. VName -> ImpM lore r op Space
lookupArraySpace VName
arr
          let ([Param (LParamInfo lore)]
i_params, [Param (LParamInfo lore)]
ps) = Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt (Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
is) ([Param (LParamInfo lore)]
 -> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
op
          (VName -> TExp Int64 -> ImpM lore r op ())
-> [VName] -> Shape (TExp Int64) -> ImpM lore r op ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ ((Param (LParamInfo lore) -> VName)
-> [Param (LParamInfo lore)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (LParamInfo lore) -> VName
forall dec. Param dec -> VName
paramName [Param (LParamInfo lore)]
i_params) Shape (TExp Int64)
is
          (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall (m :: * -> *) a. Monad m => a -> m a
return
            ( VName
acc,
              Space
space,
              [VName]
arrs,
              (SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
ispace),
              Lambda lore -> Maybe (Lambda lore)
forall a. a -> Maybe a
Just Lambda lore
op {lambdaParams :: [Param (LParamInfo lore)]
lambdaParams = [Param (LParamInfo lore)]
ps}
            )
        Just (arrs :: [VName]
arrs@(VName
arr : [VName]
_), Maybe (Lambda lore, [SubExp])
Nothing) -> do
          Space
space <- VName -> ImpM lore r op Space
forall lore r op. VName -> ImpM lore r op Space
lookupArraySpace VName
arr
          (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
acc, Space
space, [VName]
arrs, (SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
ispace), Maybe (Lambda lore)
forall a. Maybe a
Nothing)
        Maybe ([VName], Maybe (Lambda lore, [SubExp]))
Nothing ->
          [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"ImpGen.lookupAcc: unlisted accumulator: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name
    VarEntry lore
_ -> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"ImpGen.lookupAcc: not an accumulator: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name

destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern :: forall lore r op.
Mem lore =>
Pattern lore -> ImpM lore r op Destination
destinationFromPattern Pattern lore
pat =
  ([ValueDestination] -> Destination)
-> ImpM lore r op [ValueDestination] -> ImpM lore r op Destination
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Maybe Int -> [ValueDestination] -> Destination
Destination (VName -> Int
baseTag (VName -> Int) -> Maybe VName -> Maybe Int
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [VName] -> Maybe VName
forall a. [a] -> Maybe a
maybeHead (PatternT LParamMem -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern lore
PatternT LParamMem
pat))) (ImpM lore r op [ValueDestination] -> ImpM lore r op Destination)
-> ([PatElemT LParamMem] -> ImpM lore r op [ValueDestination])
-> [PatElemT LParamMem]
-> ImpM lore r op Destination
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT LParamMem -> ImpM lore r op ValueDestination)
-> [PatElemT LParamMem] -> ImpM lore r op [ValueDestination]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM PatElemT LParamMem -> ImpM lore r op ValueDestination
forall {dec} {lore} {r} {op}.
PatElemT dec -> ImpM lore r op ValueDestination
inspect ([PatElemT LParamMem] -> ImpM lore r op Destination)
-> [PatElemT LParamMem] -> ImpM lore r op Destination
forall a b. (a -> b) -> a -> b
$
    PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
PatternT LParamMem
pat
  where
    inspect :: PatElemT dec -> ImpM lore r op ValueDestination
inspect PatElemT dec
patElem = do
      let name :: VName
name = PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
patElem
      VarEntry lore
entry <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name
      case VarEntry lore
entry of
        ArrayVar Maybe (Exp lore)
_ (ArrayEntry MemLocation {} PrimType
_) ->
          ValueDestination -> ImpM lore r op ValueDestination
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDestination -> ImpM lore r op ValueDestination)
-> ValueDestination -> ImpM lore r op ValueDestination
forall a b. (a -> b) -> a -> b
$ Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
        MemVar {} ->
          ValueDestination -> ImpM lore r op ValueDestination
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDestination -> ImpM lore r op ValueDestination)
-> ValueDestination -> ImpM lore r op ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
name
        ScalarVar {} ->
          ValueDestination -> ImpM lore r op ValueDestination
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDestination -> ImpM lore r op ValueDestination)
-> ValueDestination -> ImpM lore r op ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
name
        AccVar {} ->
          ValueDestination -> ImpM lore r op ValueDestination
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDestination -> ImpM lore r op ValueDestination)
-> ValueDestination -> ImpM lore r op ValueDestination
forall a b. (a -> b) -> a -> b
$ Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing

fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray :: forall lore r op.
VName
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
name Shape (TExp Int64)
indices = do
  ArrayEntry
arr <- VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
name
  MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
arr) Shape (TExp Int64)
indices

fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' :: forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (MemLocation VName
mem [SubExp]
_ IxFun (TExp Int64)
ixfun) Shape (TExp Int64)
indices = do
  Space
space <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
mem
  let indices' :: Shape (TExp Int64)
indices' = case Space
space of
        ScalarSpace [SubExp]
ds PrimType
_ ->
          let (Shape (TExp Int64)
zero_is, Shape (TExp Int64)
is) = Int
-> Shape (TExp Int64) -> (Shape (TExp Int64), Shape (TExp Int64))
forall a. Int -> [a] -> ([a], [a])
splitFromEnd ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Shape (TExp Int64)
indices
           in (TExp Int64 -> TExp Int64)
-> Shape (TExp Int64) -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Int64 -> TExp Int64
forall a b. a -> b -> a
const TExp Int64
0) Shape (TExp Int64)
zero_is Shape (TExp Int64) -> Shape (TExp Int64) -> Shape (TExp Int64)
forall a. [a] -> [a] -> [a]
++ Shape (TExp Int64)
is
        Space
_ -> Shape (TExp Int64)
indices
  (VName, Space, Count Elements (TExp Int64))
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall (m :: * -> *) a. Monad m => a -> m a
return
    ( VName
mem,
      Space
space,
      TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ IxFun (TExp Int64) -> Shape (TExp Int64) -> TExp Int64
forall num.
(IntegralExp num, Eq num) =>
IxFun num -> Indices num -> num
IxFun.index IxFun (TExp Int64)
ixfun Shape (TExp Int64)
indices'
    )

-- More complicated read/write operations that use index functions.

copy :: CopyCompiler lore r op
copy :: forall lore r op. CopyCompiler lore r op
copy PrimType
bt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice = do
  CopyCompiler lore r op
cc <- (Env lore r op -> CopyCompiler lore r op)
-> ImpM lore r op (CopyCompiler lore r op)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> CopyCompiler lore r op
forall lore r op. Env lore r op -> CopyCompiler lore r op
envCopyCompiler
  CopyCompiler lore r op
cc PrimType
bt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice

-- | Is this copy really a mapping with transpose?
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy :: PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> Maybe
     (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
isMapTransposeCopy
  PrimType
bt
  (MemLocation VName
_ [SubExp]
_ IxFun (TExp Int64)
destIxFun)
  [DimIndex (TExp Int64)]
destslice
  (MemLocation VName
_ [SubExp]
_ IxFun (TExp Int64)
srcIxFun)
  [DimIndex (TExp Int64)]
srcslice
    | Just (TExp Int64
dest_offset, [(Int, TExp Int64)]
perm_and_destshape) <- IxFun (TExp Int64)
-> TExp Int64 -> Maybe (TExp Int64, [(Int, TExp Int64)])
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe (num, [(Int, num)])
IxFun.rearrangeWithOffset IxFun (TExp Int64)
destIxFun' TExp Int64
bt_size,
      ([Int]
perm, Shape (TExp Int64)
destshape) <- [(Int, TExp Int64)] -> ([Int], Shape (TExp Int64))
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, TExp Int64)]
perm_and_destshape,
      Just TExp Int64
src_offset <- IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset IxFun (TExp Int64)
srcIxFun' TExp Int64
bt_size,
      Just (Int
r1, Int
r2, Int
_) <- [Int] -> Maybe (Int, Int, Int)
isMapTranspose [Int]
perm =
      Shape (TExp Int64)
-> ((Shape (TExp Int64), Shape (TExp Int64))
    -> (Shape (TExp Int64), Shape (TExp Int64)))
-> Int
-> Int
-> TExp Int64
-> TExp Int64
-> Maybe
     (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
forall {t :: * -> *} {t :: * -> *} {c} {d} {e} {m :: * -> *} {a}
       {b}.
(Foldable t, Foldable t, Num c, Num d, Num e, Monad m) =>
[c]
-> (([c], [c]) -> (t d, t e))
-> Int
-> Int
-> a
-> b
-> m (a, b, c, d, e)
isOk Shape (TExp Int64)
destshape (Shape (TExp Int64), Shape (TExp Int64))
-> (Shape (TExp Int64), Shape (TExp Int64))
forall {b} {a}. (b, a) -> (a, b)
swap Int
r1 Int
r2 TExp Int64
dest_offset TExp Int64
src_offset
    | Just TExp Int64
dest_offset <- IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset IxFun (TExp Int64)
destIxFun' TExp Int64
bt_size,
      Just (TExp Int64
src_offset, [(Int, TExp Int64)]
perm_and_srcshape) <- IxFun (TExp Int64)
-> TExp Int64 -> Maybe (TExp Int64, [(Int, TExp Int64)])
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe (num, [(Int, num)])
IxFun.rearrangeWithOffset IxFun (TExp Int64)
srcIxFun' TExp Int64
bt_size,
      ([Int]
perm, Shape (TExp Int64)
srcshape) <- [(Int, TExp Int64)] -> ([Int], Shape (TExp Int64))
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, TExp Int64)]
perm_and_srcshape,
      Just (Int
r1, Int
r2, Int
_) <- [Int] -> Maybe (Int, Int, Int)
isMapTranspose [Int]
perm =
      Shape (TExp Int64)
-> ((Shape (TExp Int64), Shape (TExp Int64))
    -> (Shape (TExp Int64), Shape (TExp Int64)))
-> Int
-> Int
-> TExp Int64
-> TExp Int64
-> Maybe
     (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
forall {t :: * -> *} {t :: * -> *} {c} {d} {e} {m :: * -> *} {a}
       {b}.
(Foldable t, Foldable t, Num c, Num d, Num e, Monad m) =>
[c]
-> (([c], [c]) -> (t d, t e))
-> Int
-> Int
-> a
-> b
-> m (a, b, c, d, e)
isOk Shape (TExp Int64)
srcshape (Shape (TExp Int64), Shape (TExp Int64))
-> (Shape (TExp Int64), Shape (TExp Int64))
forall a. a -> a
id Int
r1 Int
r2 TExp Int64
dest_offset TExp Int64
src_offset
    | Bool
otherwise =
      Maybe (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
forall a. Maybe a
Nothing
    where
      bt_size :: TExp Int64
bt_size = PrimType -> TExp Int64
forall a. Num a => PrimType -> a
primByteSize PrimType
bt
      swap :: (b, a) -> (a, b)
swap (b
x, a
y) = (a
y, b
x)

      destIxFun' :: IxFun (TExp Int64)
destIxFun' = IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
destIxFun [DimIndex (TExp Int64)]
destslice
      srcIxFun' :: IxFun (TExp Int64)
srcIxFun' = IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
srcIxFun [DimIndex (TExp Int64)]
srcslice

      isOk :: [c]
-> (([c], [c]) -> (t d, t e))
-> Int
-> Int
-> a
-> b
-> m (a, b, c, d, e)
isOk [c]
shape ([c], [c]) -> (t d, t e)
f Int
r1 Int
r2 a
dest_offset b
src_offset = do
        let (c
num_arrays, d
size_x, e
size_y) = [c] -> (([c], [c]) -> (t d, t e)) -> Int -> Int -> (c, d, e)
forall {t :: * -> *} {t :: * -> *} {a} {b} {c}.
(Foldable t, Foldable t, Num a, Num b, Num c) =>
[a] -> (([a], [a]) -> (t b, t c)) -> Int -> Int -> (a, b, c)
getSizes [c]
shape ([c], [c]) -> (t d, t e)
f Int
r1 Int
r2
        (a, b, c, d, e) -> m (a, b, c, d, e)
forall (m :: * -> *) a. Monad m => a -> m a
return
          ( a
dest_offset,
            b
src_offset,
            c
num_arrays,
            d
size_x,
            e
size_y
          )

      getSizes :: [a] -> (([a], [a]) -> (t b, t c)) -> Int -> Int -> (a, b, c)
getSizes [a]
shape ([a], [a]) -> (t b, t c)
f Int
r1 Int
r2 =
        let ([a]
mapped, [a]
notmapped) = Int -> [a] -> ([a], [a])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
r1 [a]
shape
            (t b
pretrans, t c
posttrans) = ([a], [a]) -> (t b, t c)
f (([a], [a]) -> (t b, t c)) -> ([a], [a]) -> (t b, t c)
forall a b. (a -> b) -> a -> b
$ Int -> [a] -> ([a], [a])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
r2 [a]
notmapped
         in ([a] -> a
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [a]
mapped, t b -> b
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product t b
pretrans, t c -> c
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product t c
posttrans)

mapTransposeName :: PrimType -> String
mapTransposeName :: PrimType -> [Char]
mapTransposeName PrimType
bt = [Char]
"map_transpose_" [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ PrimType -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimType
bt

mapTransposeForType :: PrimType -> ImpM lore r op Name
mapTransposeForType :: forall lore r op. PrimType -> ImpM lore r op Name
mapTransposeForType PrimType
bt = do
  let fname :: Name
fname = [Char] -> Name
nameFromString ([Char] -> Name) -> [Char] -> Name
forall a b. (a -> b) -> a -> b
$ [Char]
"builtin#" [Char] -> ShowS
forall a. Semigroup a => a -> a -> a
<> PrimType -> [Char]
mapTransposeName PrimType
bt

  Bool
exists <- Name -> ImpM lore r op Bool
forall lore r op. Name -> ImpM lore r op Bool
hasFunction Name
fname
  Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
exists (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Name -> Function op -> ImpM lore r op ()
forall op lore r. Name -> Function op -> ImpM lore r op ()
emitFunction Name
fname (Function op -> ImpM lore r op ())
-> Function op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Name -> PrimType -> Function op
forall op. Name -> PrimType -> Function op
mapTransposeFunction Name
fname PrimType
bt

  Name -> ImpM lore r op Name
forall (m :: * -> *) a. Monad m => a -> m a
return Name
fname

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
defaultCopy :: CopyCompiler lore r op
defaultCopy :: forall lore r op. CopyCompiler lore r op
defaultCopy PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
  | Just
      ( TExp Int64
destoffset,
        TExp Int64
srcoffset,
        TExp Int64
num_arrays,
        TExp Int64
size_x,
        TExp Int64
size_y
        ) <-
      PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> Maybe
     (TExp Int64, TExp Int64, TExp Int64, TExp Int64, TExp Int64)
isMapTransposeCopy PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice = do
    Name
fname <- PrimType -> ImpM lore r op Name
forall lore r op. PrimType -> ImpM lore r op Name
mapTransposeForType PrimType
pt
    Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
      [VName] -> Name -> [Arg] -> Code op
forall a. [VName] -> Name -> [Arg] -> Code a
Imp.Call
        []
        Name
fname
        ([Arg] -> Code op) -> [Arg] -> Code op
forall a b. (a -> b) -> a -> b
$ PrimType
-> VName
-> Count Bytes (TExp Int64)
-> VName
-> Count Bytes (TExp Int64)
-> TExp Int64
-> TExp Int64
-> TExp Int64
-> [Arg]
transposeArgs
          PrimType
pt
          VName
destmem
          (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
destoffset)
          VName
srcmem
          (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
srcoffset)
          TExp Int64
num_arrays
          TExp Int64
size_x
          TExp Int64
size_y
  | Just TExp Int64
destoffset <-
      IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset (IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
dest_ixfun [DimIndex (TExp Int64)]
destslice) TExp Int64
pt_size,
    Just TExp Int64
srcoffset <-
      IxFun (TExp Int64) -> TExp Int64 -> Maybe (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> Maybe num
IxFun.linearWithOffset (IxFun (TExp Int64) -> [DimIndex (TExp Int64)] -> IxFun (TExp Int64)
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> Slice num -> IxFun num
IxFun.slice IxFun (TExp Int64)
src_ixfun [DimIndex (TExp Int64)]
srcslice) TExp Int64
pt_size = do
    Space
srcspace <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
srcmem
    Space
destspace <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory VName
destmem
    if Space -> Bool
isScalarSpace Space
srcspace Bool -> Bool -> Bool
|| Space -> Bool
isScalarSpace Space
destspace
      then CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copyElementWise PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
      else
        Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          VName
-> Count Bytes (TExp Int64)
-> Space
-> VName
-> Count Bytes (TExp Int64)
-> Space
-> Count Bytes (TExp Int64)
-> Code op
forall a.
VName
-> Count Bytes (TExp Int64)
-> Space
-> VName
-> Count Bytes (TExp Int64)
-> Space
-> Count Bytes (TExp Int64)
-> Code a
Imp.Copy
            VName
destmem
            (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
destoffset)
            Space
destspace
            VName
srcmem
            (TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
bytes TExp Int64
srcoffset)
            Space
srcspace
            (Count Bytes (TExp Int64) -> Code op)
-> Count Bytes (TExp Int64) -> Code op
forall a b. (a -> b) -> a -> b
$ Count Elements (TExp Int64)
num_elems Count Elements (TExp Int64) -> PrimType -> Count Bytes (TExp Int64)
`withElemType` PrimType
pt
  | Bool
otherwise =
    CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copyElementWise PrimType
pt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice
  where
    pt_size :: TExp Int64
pt_size = PrimType -> TExp Int64
forall a. Num a => PrimType -> a
primByteSize PrimType
pt
    num_elems :: Count Elements (TExp Int64)
num_elems = TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
Imp.elements (TExp Int64 -> Count Elements (TExp Int64))
-> TExp Int64 -> Count Elements (TExp Int64)
forall a b. (a -> b) -> a -> b
$ Shape (TExp Int64) -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product (Shape (TExp Int64) -> TExp Int64)
-> Shape (TExp Int64) -> TExp Int64
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
srcslice

    MemLocation VName
destmem [SubExp]
_ IxFun (TExp Int64)
dest_ixfun = MemLocation
dest
    MemLocation VName
srcmem [SubExp]
_ IxFun (TExp Int64)
src_ixfun = MemLocation
src

    isScalarSpace :: Space -> Bool
isScalarSpace ScalarSpace {} = Bool
True
    isScalarSpace Space
_ = Bool
False

copyElementWise :: CopyCompiler lore r op
copyElementWise :: forall lore r op. CopyCompiler lore r op
copyElementWise PrimType
bt MemLocation
dest [DimIndex (TExp Int64)]
destslice MemLocation
src [DimIndex (TExp Int64)]
srcslice = do
  let bounds :: Shape (TExp Int64)
bounds = [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
srcslice
  [VName]
is <- Int -> ImpM lore r op VName -> ImpM lore r op [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
bounds) ([Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i")
  let ivars :: Shape (TExp Int64)
ivars = (VName -> TExp Int64) -> [VName] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
is
  (VName
destmem, Space
destspace, Count Elements (TExp Int64)
destidx) <-
    MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
dest (Shape (TExp Int64)
 -> ImpM lore r op (VName, Space, Count Elements (TExp Int64)))
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64) -> Shape (TExp Int64)
forall d. Num d => Slice d -> [d] -> [d]
fixSlice [DimIndex (TExp Int64)]
destslice Shape (TExp Int64)
ivars
  (VName
srcmem, Space
srcspace, Count Elements (TExp Int64)
srcidx) <-
    MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
src (Shape (TExp Int64)
 -> ImpM lore r op (VName, Space, Count Elements (TExp Int64)))
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64) -> Shape (TExp Int64)
forall d. Num d => Slice d -> [d] -> [d]
fixSlice [DimIndex (TExp Int64)]
srcslice Shape (TExp Int64)
ivars
  Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    ((Code op -> Code op)
 -> (Code op -> Code op) -> Code op -> Code op)
-> (Code op -> Code op)
-> [Code op -> Code op]
-> Code op
-> Code op
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Code op -> Code op) -> (Code op -> Code op) -> Code op -> Code op
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) Code op -> Code op
forall a. a -> a
id ((VName -> Exp -> Code op -> Code op)
-> [VName] -> [Exp] -> [Code op -> Code op]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is ([Exp] -> [Code op -> Code op]) -> [Exp] -> [Code op -> Code op]
forall a b. (a -> b) -> a -> b
$ (TExp Int64 -> Exp) -> Shape (TExp Int64) -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped Shape (TExp Int64)
bounds) (Code op -> Code op) -> Code op -> Code op
forall a b. (a -> b) -> a -> b
$
      VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
destmem Count Elements (TExp Int64)
destidx PrimType
bt Space
destspace Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$
        VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
srcmem Count Elements (TExp Int64)
srcidx PrimType
bt Space
srcspace Volatility
vol

-- | Copy from here to there; both destination and source may be
-- indexeded.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op (Imp.Code op)
copyArrayDWIM :: forall lore r op.
PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
copyArrayDWIM
  PrimType
bt
  destlocation :: MemLocation
destlocation@(MemLocation VName
_ [SubExp]
destshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
destslice
  srclocation :: MemLocation
srclocation@(MemLocation VName
_ [SubExp]
srcshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
srcslice
    | Just Shape (TExp Int64)
destis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
destslice,
      Just Shape (TExp Int64)
srcis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
srcslice,
      Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
srcis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
srcshape,
      Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
destis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
destshape = do
      (VName
targetmem, Space
destspace, Count Elements (TExp Int64)
targetoffset) <-
        MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
destlocation Shape (TExp Int64)
destis
      (VName
srcmem, Space
srcspace, Count Elements (TExp Int64)
srcoffset) <-
        MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
srclocation Shape (TExp Int64)
srcis
      Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
      Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return (Code op -> ImpM lore r op (Code op))
-> Code op -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
        VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
targetmem Count Elements (TExp Int64)
targetoffset PrimType
bt Space
destspace Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$
          VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
srcmem Count Elements (TExp Int64)
srcoffset PrimType
bt Space
srcspace Volatility
vol
    | Bool
otherwise = do
      let destslice' :: [DimIndex (TExp Int64)]
destslice' =
            Shape (TExp Int64)
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
destshape) [DimIndex (TExp Int64)]
destslice
          srcslice' :: [DimIndex (TExp Int64)]
srcslice' =
            Shape (TExp Int64)
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
srcshape) [DimIndex (TExp Int64)]
srcslice
          destrank :: Int
destrank = Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Shape (TExp Int64) -> Int) -> Shape (TExp Int64) -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
destslice'
          srcrank :: Int
srcrank = Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Shape (TExp Int64) -> Int) -> Shape (TExp Int64) -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
srcslice'
      if Int
destrank Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
srcrank
        then
          [Char] -> ImpM lore r op (Code op)
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op (Code op))
-> [Char] -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
            [Char]
"copyArrayDWIM: cannot copy to "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (MemLocation -> VName
memLocationName MemLocation
destlocation)
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" from "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (MemLocation -> VName
memLocationName MemLocation
srclocation)
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" because ranks do not match ("
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Pretty a => a -> [Char]
pretty Int
destrank
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" vs "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Pretty a => a -> [Char]
pretty Int
srcrank
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
")"
        else
          if MemLocation
destlocation MemLocation -> MemLocation -> Bool
forall a. Eq a => a -> a -> Bool
== MemLocation
srclocation Bool -> Bool -> Bool
&& [DimIndex (TExp Int64)]
destslice' [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)] -> Bool
forall a. Eq a => a -> a -> Bool
== [DimIndex (TExp Int64)]
srcslice'
            then Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return Code op
forall a. Monoid a => a
mempty -- Copy would be no-op.
            else ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
bt MemLocation
destlocation [DimIndex (TExp Int64)]
destslice' MemLocation
srclocation [DimIndex (TExp Int64)]
srcslice'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIMDest :: forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
_ [DimIndex (TExp Int64)]
_ (Constant PrimValue
v) (DimIndex (TExp Int64)
_ : [DimIndex (TExp Int64)]
_) =
  [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"cannot be indexed."]
copyDWIMDest ValueDestination
pat [DimIndex (TExp Int64)]
dest_slice (Constant PrimValue
v) [] =
  case (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice of
    Maybe (Shape (TExp Int64))
Nothing ->
      [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"with slice destination."]
    Just Shape (TExp Int64)
dest_is ->
      case ValueDestination
pat of
        ScalarDestination VName
name ->
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        MemoryDestination {} ->
          [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
            [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: constant source", PrimValue -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimValue
v, [Char]
"cannot be written to memory destination."]
        ArrayDestination (Just MemLocation
dest_loc) -> do
          (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <-
            MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
dest_loc Shape (TExp Int64)
dest_is
          Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
          Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ PrimValue -> Exp
forall v. PrimValue -> PrimExp v
Imp.ValueExp PrimValue
v
        ArrayDestination Maybe MemLocation
Nothing ->
          [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error [Char]
"copyDWIMDest: ArrayDestination Nothing"
  where
    bt :: PrimType
bt = PrimValue -> PrimType
primValueType PrimValue
v
copyDWIMDest ValueDestination
dest [DimIndex (TExp Int64)]
dest_slice (Var VName
src) [DimIndex (TExp Int64)]
src_slice = do
  VarEntry lore
src_entry <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
src
  case (ValueDestination
dest, VarEntry lore
src_entry) of
    (MemoryDestination VName
mem, MemVar Maybe (Exp lore)
_ (MemEntry Space
space)) ->
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> VName -> Space -> Code op
forall a. VName -> VName -> Space -> Code a
Imp.SetMem VName
mem VName
src Space
space
    (MemoryDestination {}, VarEntry lore
_) ->
      [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: cannot write", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"to memory destination."]
    (ValueDestination
_, MemVar {}) ->
      [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
        [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: source", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"is a memory block."]
    (ValueDestination
_, ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
_))
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
src_slice ->
        [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed source", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
src_slice]
    (ScalarDestination VName
name, VarEntry lore
_)
      | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [DimIndex (TExp Int64)]
dest_slice ->
        [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          [[Char]] -> [Char]
unwords [[Char]
"copyDWIMDest: prim-typed target", VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name, [Char]
"with slice", [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
dest_slice]
    (ScalarDestination VName
name, ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
pt)) ->
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
src PrimType
pt
    (ScalarDestination VName
name, ArrayVar Maybe (Exp lore)
_ ArrayEntry
arr)
      | Just Shape (TExp Int64)
src_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
src_slice,
        [DimIndex (TExp Int64)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex (TExp Int64)]
src_slice Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (ArrayEntry -> [SubExp]
entryArrayShape ArrayEntry
arr) -> do
        let bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
arr
        (VName
mem, Space
space, Count Elements (TExp Int64)
i) <-
          MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' (ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
arr) Shape (TExp Int64)
src_is
        Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
        Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
name (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
mem Count Elements (TExp Int64)
i PrimType
bt Space
space Volatility
vol
      | Bool
otherwise ->
        [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          [[Char]] -> [Char]
unwords
            [ [Char]
"copyDWIMDest: prim-typed target",
              VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name,
              [Char]
"and array-typed source",
              VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src,
              [Char]
"with slice",
              [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
src_slice
            ]
    (ArrayDestination (Just MemLocation
dest_loc), ArrayVar Maybe (Exp lore)
_ ArrayEntry
src_arr) -> do
      let src_loc :: MemLocation
src_loc = ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
src_arr
          bt :: PrimType
bt = ArrayEntry -> PrimType
entryArrayElemType ArrayEntry
src_arr
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> ImpM lore r op (Code op) -> ImpM lore r op ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
forall lore r op.
PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
copyArrayDWIM PrimType
bt MemLocation
dest_loc [DimIndex (TExp Int64)]
dest_slice MemLocation
src_loc [DimIndex (TExp Int64)]
src_slice
    (ArrayDestination (Just MemLocation
dest_loc), ScalarVar Maybe (Exp lore)
_ (ScalarEntry PrimType
bt))
      | Just Shape (TExp Int64)
dest_is <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
dest_slice -> do
        (VName
dest_mem, Space
dest_space, Count Elements (TExp Int64)
dest_i) <- MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
dest_loc Shape (TExp Int64)
dest_is
        Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
        Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
dest_mem Count Elements (TExp Int64)
dest_i PrimType
bt Space
dest_space Volatility
vol (VName -> PrimType -> Exp
Imp.var VName
src PrimType
bt)
      | Bool
otherwise ->
        [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
          [[Char]] -> [Char]
unwords
            [ [Char]
"copyDWIMDest: array-typed target and prim-typed source",
              VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
src,
              [Char]
"with slice",
              [DimIndex (TExp Int64)] -> [Char]
forall a. Pretty a => a -> [Char]
pretty [DimIndex (TExp Int64)]
dest_slice
            ]
    (ArrayDestination Maybe MemLocation
Nothing, VarEntry lore
_) ->
      () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return () -- Nothing to do; something else set some memory
      -- somewhere.
    (ValueDestination
_, AccVar {}) ->
      () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return () -- Nothing to do; accumulators are phantoms.

-- | Copy from here to there; both destination and source be
-- indexeded.  If so, they better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIM :: forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
dest [DimIndex (TExp Int64)]
dest_slice SubExp
src [DimIndex (TExp Int64)]
src_slice = do
  VarEntry lore
dest_entry <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
dest
  let dest_target :: ValueDestination
dest_target =
        case VarEntry lore
dest_entry of
          ScalarVar Maybe (Exp lore)
_ ScalarEntry
_ ->
            VName -> ValueDestination
ScalarDestination VName
dest
          ArrayVar Maybe (Exp lore)
_ (ArrayEntry (MemLocation VName
mem [SubExp]
shape IxFun (TExp Int64)
ixfun) PrimType
_) ->
            Maybe MemLocation -> ValueDestination
ArrayDestination (Maybe MemLocation -> ValueDestination)
-> Maybe MemLocation -> ValueDestination
forall a b. (a -> b) -> a -> b
$ MemLocation -> Maybe MemLocation
forall a. a -> Maybe a
Just (MemLocation -> Maybe MemLocation)
-> MemLocation -> Maybe MemLocation
forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> IxFun (TExp Int64) -> MemLocation
MemLocation VName
mem [SubExp]
shape IxFun (TExp Int64)
ixfun
          MemVar Maybe (Exp lore)
_ MemEntry
_ ->
            VName -> ValueDestination
MemoryDestination VName
dest
          AccVar {} ->
            -- Does not matter; accumulators are phantoms.
            Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
  ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
ValueDestination
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIMDest ValueDestination
dest_target [DimIndex (TExp Int64)]
dest_slice SubExp
src [DimIndex (TExp Int64)]
src_slice

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM lore r op ()
copyDWIMFix :: forall lore r op.
VName
-> Shape (TExp Int64)
-> SubExp
-> Shape (TExp Int64)
-> ImpM lore r op ()
copyDWIMFix VName
dest Shape (TExp Int64)
dest_is SubExp
src Shape (TExp Int64)
src_is =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
dest ((TExp Int64 -> DimIndex (TExp Int64))
-> Shape (TExp Int64) -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix Shape (TExp Int64)
dest_is) SubExp
src ((TExp Int64 -> DimIndex (TExp Int64))
-> Shape (TExp Int64) -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix Shape (TExp Int64)
src_is)

-- | @compileAlloc pat size space@ allocates @n@ bytes of memory in @space@,
-- writing the result to @dest@, which must be a single
-- 'MemoryDestination',
compileAlloc ::
  Mem lore =>
  Pattern lore ->
  SubExp ->
  Space ->
  ImpM lore r op ()
compileAlloc :: forall lore r op.
Mem lore =>
Pattern lore -> SubExp -> Space -> ImpM lore r op ()
compileAlloc (Pattern [] [PatElemT (LetDec lore)
mem]) SubExp
e Space
space = do
  let e' :: Count Bytes (TExp Int64)
e' = TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
Imp.bytes (TExp Int64 -> Count Bytes (TExp Int64))
-> TExp Int64 -> Count Bytes (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
e
  Maybe (AllocCompiler lore r op)
allocator <- (Env lore r op -> Maybe (AllocCompiler lore r op))
-> ImpM lore r op (Maybe (AllocCompiler lore r op))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((Env lore r op -> Maybe (AllocCompiler lore r op))
 -> ImpM lore r op (Maybe (AllocCompiler lore r op)))
-> (Env lore r op -> Maybe (AllocCompiler lore r op))
-> ImpM lore r op (Maybe (AllocCompiler lore r op))
forall a b. (a -> b) -> a -> b
$ Space
-> Map Space (AllocCompiler lore r op)
-> Maybe (AllocCompiler lore r op)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Space
space (Map Space (AllocCompiler lore r op)
 -> Maybe (AllocCompiler lore r op))
-> (Env lore r op -> Map Space (AllocCompiler lore r op))
-> Env lore r op
-> Maybe (AllocCompiler lore r op)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore r op -> Map Space (AllocCompiler lore r op)
forall lore r op.
Env lore r op -> Map Space (AllocCompiler lore r op)
envAllocCompilers
  case Maybe (AllocCompiler lore r op)
allocator of
    Maybe (AllocCompiler lore r op)
Nothing -> Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Count Bytes (TExp Int64) -> Space -> Code op
forall a. VName -> Count Bytes (TExp Int64) -> Space -> Code a
Imp.Allocate (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
mem) Count Bytes (TExp Int64)
e' Space
space
    Just AllocCompiler lore r op
allocator' -> AllocCompiler lore r op
allocator' (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
mem) Count Bytes (TExp Int64)
e'
compileAlloc PatternT (LetDec lore)
pat SubExp
_ Space
_ =
  [Char] -> ImpM lore r op ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op ()) -> [Char] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ [Char]
"compileAlloc: Invalid pattern: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ PatternT LParamMem -> [Char]
forall a. Pretty a => a -> [Char]
pretty PatternT (LetDec lore)
PatternT LParamMem
pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize :: Type -> Count Bytes (TExp Int64)
typeSize Type
t =
  TExp Int64 -> Count Bytes (TExp Int64)
forall a. a -> Count Bytes a
Imp.bytes (TExp Int64 -> Count Bytes (TExp Int64))
-> TExp Int64 -> Count Bytes (TExp Int64)
forall a b. (a -> b) -> a -> b
$ PrimType -> TExp Int64
forall a. Num a => PrimType -> a
primByteSize (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Shape (TExp Int64) -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ((SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t))

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds :: [DimIndex (TExp Int64)] -> Shape (TExp Int64) -> TExp Bool
inBounds [DimIndex (TExp Int64)]
slice Shape (TExp Int64)
dims =
  let condInBounds :: DimIndex (TPrimExp t v) -> TPrimExp t v -> TPrimExp Bool v
condInBounds (DimFix TPrimExp t v
i) TPrimExp t v
d =
        TPrimExp t v
0 TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp t v
i TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp t v
i TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp t v
d
      condInBounds (DimSlice TPrimExp t v
i TPrimExp t v
n TPrimExp t v
s) TPrimExp t v
d =
        TPrimExp t v
0 TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TPrimExp t v
i TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TPrimExp t v
i TPrimExp t v -> TPrimExp t v -> TPrimExp t v
forall a. Num a => a -> a -> a
+ TPrimExp t v
n TPrimExp t v -> TPrimExp t v -> TPrimExp t v
forall a. Num a => a -> a -> a
* TPrimExp t v
s TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp t v
d
   in (TExp Bool -> TExp Bool -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall (t :: * -> *) a. Foldable t => (a -> a -> a) -> t a -> a
foldl1 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
(.&&.) ([TExp Bool] -> TExp Bool) -> [TExp Bool] -> TExp Bool
forall a b. (a -> b) -> a -> b
$ (DimIndex (TExp Int64) -> TExp Int64 -> TExp Bool)
-> [DimIndex (TExp Int64)] -> Shape (TExp Int64) -> [TExp Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith DimIndex (TExp Int64) -> TExp Int64 -> TExp Bool
forall {t} {v}.
(NumExp t, Pretty v) =>
DimIndex (TPrimExp t v) -> TPrimExp t v -> TPrimExp Bool v
condInBounds [DimIndex (TExp Int64)]
slice Shape (TExp Int64)
dims

--- Building blocks for constructing code.

sFor' :: VName -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' :: forall lore r op.
VName -> Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' VName
i Exp
bound ImpM lore r op ()
body = do
  let it :: IntType
it = case Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
bound of
        IntType IntType
bound_t -> IntType
bound_t
        PrimType
t -> [Char] -> IntType
forall a. HasCallStack => [Char] -> a
error ([Char] -> IntType) -> [Char] -> IntType
forall a b. (a -> b) -> a -> b
$ [Char]
"sFor': bound " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Exp -> [Char]
forall a. Pretty a => a -> [Char]
pretty Exp
bound [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" is of type " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ PrimType -> [Char]
forall a. Pretty a => a -> [Char]
pretty PrimType
t
  VName -> IntType -> ImpM lore r op ()
forall lore r op. VName -> IntType -> ImpM lore r op ()
addLoopVar VName
i IntType
it
  Code op
body' <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
body
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For VName
i Exp
bound Code op
body'

sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor :: forall t lore r op.
[Char]
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor [Char]
i TExp t
bound TExp t -> ImpM lore r op ()
body = do
  VName
i' <- [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
i
  VName -> Exp -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
VName -> Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' VName
i' (TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
bound) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    TExp t -> ImpM lore r op ()
body (TExp t -> ImpM lore r op ()) -> TExp t -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> TExp t
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp t) -> Exp -> TExp t
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
i' (PrimType -> Exp) -> PrimType -> Exp
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType (Exp -> PrimType) -> Exp -> PrimType
forall a b. (a -> b) -> a -> b
$ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
bound

sWhile :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile :: forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile TExp Bool
cond ImpM lore r op ()
body = do
  Code op
body' <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
body
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Code op -> Code op
forall a. TExp Bool -> Code a -> Code a
Imp.While TExp Bool
cond Code op
body'

sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment :: forall lore r op. [Char] -> ImpM lore r op () -> ImpM lore r op ()
sComment [Char]
s ImpM lore r op ()
code = do
  Code op
code' <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
code
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ [Char] -> Code op -> Code op
forall a. [Char] -> Code a -> Code a
Imp.Comment [Char]
s Code op
code'

sIf :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf :: forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf TExp Bool
cond ImpM lore r op ()
tbranch ImpM lore r op ()
fbranch = do
  Code op
tbranch' <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
tbranch
  Code op
fbranch' <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
fbranch
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Code op -> Code op -> Code op
forall a. TExp Bool -> Code a -> Code a -> Code a
Imp.If TExp Bool
cond Code op
tbranch' Code op
fbranch'

sWhen :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen :: forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
cond ImpM lore r op ()
tbranch = TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf TExp Bool
cond ImpM lore r op ()
tbranch (() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ())

sUnless :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless :: forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless TExp Bool
cond = TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf TExp Bool
cond (() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ())

sOp :: op -> ImpM lore r op ()
sOp :: forall op lore r. op -> ImpM lore r op ()
sOp = Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> (op -> Code op) -> op -> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. op -> Code op
forall a. a -> Code a
Imp.Op

sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem :: forall lore r op. [Char] -> Space -> ImpM lore r op VName
sDeclareMem [Char]
name Space
space = do
  VName
name' <- [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
name
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.DeclareMem VName
name' Space
space
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name' (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
space
  VName -> ImpM lore r op VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
name'

sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ :: forall lore r op.
VName -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ VName
name' Count Bytes (TExp Int64)
size' Space
space = do
  Maybe (AllocCompiler lore r op)
allocator <- (Env lore r op -> Maybe (AllocCompiler lore r op))
-> ImpM lore r op (Maybe (AllocCompiler lore r op))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((Env lore r op -> Maybe (AllocCompiler lore r op))
 -> ImpM lore r op (Maybe (AllocCompiler lore r op)))
-> (Env lore r op -> Maybe (AllocCompiler lore r op))
-> ImpM lore r op (Maybe (AllocCompiler lore r op))
forall a b. (a -> b) -> a -> b
$ Space
-> Map Space (AllocCompiler lore r op)
-> Maybe (AllocCompiler lore r op)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Space
space (Map Space (AllocCompiler lore r op)
 -> Maybe (AllocCompiler lore r op))
-> (Env lore r op -> Map Space (AllocCompiler lore r op))
-> Env lore r op
-> Maybe (AllocCompiler lore r op)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Env lore r op -> Map Space (AllocCompiler lore r op)
forall lore r op.
Env lore r op -> Map Space (AllocCompiler lore r op)
envAllocCompilers
  case Maybe (AllocCompiler lore r op)
allocator of
    Maybe (AllocCompiler lore r op)
Nothing -> Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Count Bytes (TExp Int64) -> Space -> Code op
forall a. VName -> Count Bytes (TExp Int64) -> Space -> Code a
Imp.Allocate VName
name' Count Bytes (TExp Int64)
size' Space
space
    Just AllocCompiler lore r op
allocator' -> AllocCompiler lore r op
allocator' VName
name' Count Bytes (TExp Int64)
size'

sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op VName
sAlloc :: forall lore r op.
[Char] -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op VName
sAlloc [Char]
name Count Bytes (TExp Int64)
size Space
space = do
  VName
name' <- [Char] -> Space -> ImpM lore r op VName
forall lore r op. [Char] -> Space -> ImpM lore r op VName
sDeclareMem [Char]
name Space
space
  VName -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op ()
forall lore r op.
VName -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ VName
name' Count Bytes (TExp Int64)
size Space
space
  VName -> ImpM lore r op VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
name'

sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray :: forall lore r op.
[Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
sArray [Char]
name PrimType
bt Shape
shape MemBind
membind = do
  VName
name' <- [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
name
  VName -> PrimType -> Shape -> MemBind -> ImpM lore r op ()
forall lore r op.
VName -> PrimType -> Shape -> MemBind -> ImpM lore r op ()
dArray VName
name' PrimType
bt Shape
shape MemBind
membind
  VName -> ImpM lore r op VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
name'

-- | Declare an array in row-major order in the given memory block.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem :: forall lore r op.
[Char] -> PrimType -> Shape -> VName -> ImpM lore r op VName
sArrayInMem [Char]
name PrimType
pt Shape
shape VName
mem =
  [Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
forall lore r op.
[Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
sArray [Char]
name PrimType
pt Shape
shape (MemBind -> ImpM lore r op VName)
-> MemBind -> ImpM lore r op VName
forall a b. (a -> b) -> a -> b
$
    VName -> IxFun -> MemBind
ArrayIn VName
mem (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$
      Shape (TPrimExp Int64 VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (TPrimExp Int64 VName) -> IxFun)
-> Shape (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp VName -> TPrimExp Int64 VName)
-> (SubExp -> PrimExp VName) -> SubExp -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int64) ([SubExp] -> Shape (TPrimExp Int64 VName))
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape

-- | Like 'sAllocArray', but permute the in-memory representation of the indices as specified.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm :: forall lore r op.
[Char]
-> PrimType -> Shape -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm [Char]
name PrimType
pt Shape
shape Space
space [Int]
perm = do
  let permuted_dims :: [SubExp]
permuted_dims = [Int] -> [SubExp] -> [SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
shape
  VName
mem <- [Char] -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op VName
forall lore r op.
[Char] -> Count Bytes (TExp Int64) -> Space -> ImpM lore r op VName
sAlloc ([Char]
name [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_mem") (Type -> Count Bytes (TExp Int64)
typeSize (PrimType -> Shape -> NoUniqueness -> Type
forall shape u. PrimType -> shape -> u -> TypeBase shape u
Array PrimType
pt Shape
shape NoUniqueness
NoUniqueness)) Space
space
  let iota_ixfun :: IxFun
iota_ixfun = Shape (TPrimExp Int64 VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (TPrimExp Int64 VName) -> IxFun)
-> Shape (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map (PrimExp VName -> TPrimExp Int64 VName
forall v. PrimExp v -> TPrimExp Int64 v
isInt64 (PrimExp VName -> TPrimExp Int64 VName)
-> (SubExp -> PrimExp VName) -> SubExp -> TPrimExp Int64 VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
int64) [SubExp]
permuted_dims
  [Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
forall lore r op.
[Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
sArray [Char]
name PrimType
pt Shape
shape (MemBind -> ImpM lore r op VName)
-> MemBind -> ImpM lore r op VName
forall a b. (a -> b) -> a -> b
$
    VName -> IxFun -> MemBind
ArrayIn VName
mem (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$ IxFun -> [Int] -> IxFun
forall num. IntegralExp num => IxFun num -> [Int] -> IxFun num
IxFun.permute IxFun
iota_ixfun ([Int] -> IxFun) -> [Int] -> IxFun
forall a b. (a -> b) -> a -> b
$ [Int] -> [Int]
rearrangeInverse [Int]
perm

-- | Uses linear/iota index function.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray :: forall lore r op.
[Char] -> PrimType -> Shape -> Space -> ImpM lore r op VName
sAllocArray [Char]
name PrimType
pt Shape
shape Space
space =
  [Char]
-> PrimType -> Shape -> Space -> [Int] -> ImpM lore r op VName
forall lore r op.
[Char]
-> PrimType -> Shape -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm [Char]
name PrimType
pt Shape
shape Space
space [Int
0 .. Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
shape Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]

-- | Uses linear/iota index function.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray :: forall lore r op.
[Char]
-> Space -> PrimType -> ArrayContents -> ImpM lore r op VName
sStaticArray [Char]
name Space
space PrimType
pt ArrayContents
vs = do
  let num_elems :: Int
num_elems = case ArrayContents
vs of
        Imp.ArrayValues [PrimValue]
vs' -> [PrimValue] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [PrimValue]
vs'
        Imp.ArrayZeros Int
n -> Int -> Int
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
n
      shape :: Shape
shape = [SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [IntType -> Integer -> SubExp
intConst IntType
Int64 (Integer -> SubExp) -> Integer -> SubExp
forall a b. (a -> b) -> a -> b
$ Int -> Integer
forall a. Integral a => a -> Integer
toInteger Int
num_elems]
  VName
mem <- [Char] -> ImpM lore r op VName
forall lore r op. [Char] -> ImpM lore r op VName
newVNameForFun ([Char] -> ImpM lore r op VName) -> [Char] -> ImpM lore r op VName
forall a b. (a -> b) -> a -> b
$ [Char]
name [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
"_mem"
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> PrimType -> ArrayContents -> Code op
forall a. VName -> Space -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
mem Space
space PrimType
pt ArrayContents
vs
  VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
mem (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
space
  [Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
forall lore r op.
[Char] -> PrimType -> Shape -> MemBind -> ImpM lore r op VName
sArray [Char]
name PrimType
pt Shape
shape (MemBind -> ImpM lore r op VName)
-> MemBind -> ImpM lore r op VName
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$ Shape (TPrimExp Int64 VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [Int -> TPrimExp Int64 VName
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
num_elems]

sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM lore r op ()
sWrite :: forall lore r op.
VName -> Shape (TExp Int64) -> Exp -> ImpM lore r op ()
sWrite VName
arr Shape (TExp Int64)
is Exp
v = do
  (VName
mem, Space
space, Count Elements (TExp Int64)
offset) <- VName
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
arr Shape (TExp Int64)
is
  Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
mem Count Elements (TExp Int64)
offset (Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
v) Space
space Volatility
vol Exp
v

sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM lore r op ()
sUpdate :: forall lore r op.
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
sUpdate VName
arr [DimIndex (TExp Int64)]
slice SubExp
v = VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [DimIndex (TExp Int64)]
slice SubExp
v []

sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM lore r op ()) ->
  ImpM lore r op ()
sLoopNest :: forall lore r op.
Shape
-> (Shape (TExp Int64) -> ImpM lore r op ()) -> ImpM lore r op ()
sLoopNest = Shape (TExp Int64)
-> [SubExp]
-> (Shape (TExp Int64) -> ImpM lore r op ())
-> ImpM lore r op ()
forall {a} {lore} {r} {op}.
ToExp a =>
Shape (TExp Int64)
-> [a]
-> (Shape (TExp Int64) -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest' [] ([SubExp]
 -> (Shape (TExp Int64) -> ImpM lore r op ()) -> ImpM lore r op ())
-> (Shape -> [SubExp])
-> Shape
-> (Shape (TExp Int64) -> ImpM lore r op ())
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims
  where
    sLoopNest' :: Shape (TExp Int64)
-> [a]
-> (Shape (TExp Int64) -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest' Shape (TExp Int64)
is [] Shape (TExp Int64) -> ImpM lore r op ()
f = Shape (TExp Int64) -> ImpM lore r op ()
f (Shape (TExp Int64) -> ImpM lore r op ())
-> Shape (TExp Int64) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Shape (TExp Int64) -> Shape (TExp Int64)
forall a. [a] -> [a]
reverse Shape (TExp Int64)
is
    sLoopNest' Shape (TExp Int64)
is (a
d : [a]
ds) Shape (TExp Int64) -> ImpM lore r op ()
f =
      [Char]
-> TExp Int64
-> (TExp Int64 -> ImpM lore r op ())
-> ImpM lore r op ()
forall t lore r op.
[Char]
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor [Char]
"nest_i" (a -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp a
d) ((TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ())
-> (TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> Shape (TExp Int64)
-> [a]
-> (Shape (TExp Int64) -> ImpM lore r op ())
-> ImpM lore r op ()
sLoopNest' (TExp Int64
i TExp Int64 -> Shape (TExp Int64) -> Shape (TExp Int64)
forall a. a -> [a] -> [a]
: Shape (TExp Int64)
is) [a]
ds Shape (TExp Int64) -> ImpM lore r op ()
f

-- | Untyped assignment.
(<~~) :: VName -> Imp.Exp -> ImpM lore r op ()
VName
x <~~ :: forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ Exp
e = Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
x Exp
e

infixl 3 <~~

-- | Typed assignment.
(<--) :: TV t -> Imp.TExp t -> ImpM lore r op ()
TV VName
x PrimType
_ <-- :: forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp t
e = Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
x (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ TExp t -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp t
e

infixl 3 <--

-- | Constructing an ad-hoc function that does not
-- correspond to any of the IR functions in the input program.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM lore r op () ->
  ImpM lore r op ()
function :: forall lore r op.
Name
-> [Param] -> [Param] -> ImpM lore r op () -> ImpM lore r op ()
function Name
fname [Param]
outputs [Param]
inputs ImpM lore r op ()
m = (Env lore r op -> Env lore r op)
-> ImpM lore r op () -> ImpM lore r op ()
forall r (m :: * -> *) a. MonadReader r m => (r -> r) -> m a -> m a
local Env lore r op -> Env lore r op
newFunction (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ do
  Code op
body <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ do
    (Param -> ImpM lore r op ()) -> [Param] -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Param -> ImpM lore r op ()
forall {lore} {r} {op}. Param -> ImpM lore r op ()
addParam ([Param] -> ImpM lore r op ()) -> [Param] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ [Param]
outputs [Param] -> [Param] -> [Param]
forall a. [a] -> [a] -> [a]
++ [Param]
inputs
    ImpM lore r op ()
m
  Name -> Function op -> ImpM lore r op ()
forall op lore r. Name -> Function op -> ImpM lore r op ()
emitFunction Name
fname (Function op -> ImpM lore r op ())
-> Function op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe Name
-> [Param]
-> [Param]
-> Code op
-> [ExternalValue]
-> [ExternalValue]
-> Function op
forall a.
Maybe Name
-> [Param]
-> [Param]
-> Code a
-> [ExternalValue]
-> [ExternalValue]
-> FunctionT a
Imp.Function Maybe Name
forall a. Maybe a
Nothing [Param]
outputs [Param]
inputs Code op
body [] []
  where
    addParam :: Param -> ImpM lore r op ()
addParam (Imp.MemParam VName
name Space
space) =
      VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
space
    addParam (Imp.ScalarParam VName
name PrimType
bt) =
      VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
name (VarEntry lore -> ImpM lore r op ())
-> VarEntry lore -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> ScalarEntry -> VarEntry lore
ScalarVar Maybe (Exp lore)
forall a. Maybe a
Nothing (ScalarEntry -> VarEntry lore) -> ScalarEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ PrimType -> ScalarEntry
ScalarEntry PrimType
bt
    newFunction :: Env lore r op -> Env lore r op
newFunction Env lore r op
env = Env lore r op
env {envFunction :: Maybe Name
envFunction = Name -> Maybe Name
forall a. a -> Maybe a
Just Name
fname}

dSlices :: [Imp.TExp Int64] -> ImpM lore r op [Imp.TExp Int64]
dSlices :: forall lore r op.
Shape (TExp Int64) -> ImpM lore r op (Shape (TExp Int64))
dSlices = ((TExp Int64, Shape (TExp Int64)) -> Shape (TExp Int64))
-> ImpM lore r op (TExp Int64, Shape (TExp Int64))
-> ImpM lore r op (Shape (TExp Int64))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Shape (TExp Int64) -> Shape (TExp Int64)
forall a. Int -> [a] -> [a]
drop Int
1 (Shape (TExp Int64) -> Shape (TExp Int64))
-> ((TExp Int64, Shape (TExp Int64)) -> Shape (TExp Int64))
-> (TExp Int64, Shape (TExp Int64))
-> Shape (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (TExp Int64, Shape (TExp Int64)) -> Shape (TExp Int64)
forall a b. (a, b) -> b
snd) (ImpM lore r op (TExp Int64, Shape (TExp Int64))
 -> ImpM lore r op (Shape (TExp Int64)))
-> (Shape (TExp Int64)
    -> ImpM lore r op (TExp Int64, Shape (TExp Int64)))
-> Shape (TExp Int64)
-> ImpM lore r op (Shape (TExp Int64))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape (TExp Int64)
-> ImpM lore r op (TExp Int64, Shape (TExp Int64))
forall {t} {lore} {r} {op}.
NumExp t =>
[TExp t] -> ImpM lore r op (TExp t, [TExp t])
dSlices'
  where
    dSlices' :: [TExp t] -> ImpM lore r op (TExp t, [TExp t])
dSlices' [] = (TExp t, [TExp t]) -> ImpM lore r op (TExp t, [TExp t])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp t
1, [TExp t
1])
    dSlices' (TExp t
n : [TExp t]
ns) = do
      (TExp t
prod, [TExp t]
ns') <- [TExp t] -> ImpM lore r op (TExp t, [TExp t])
dSlices' [TExp t]
ns
      TExp t
n' <- [Char] -> TExp t -> ImpM lore r op (TExp t)
forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TExp t)
dPrimVE [Char]
"slice" (TExp t -> ImpM lore r op (TExp t))
-> TExp t -> ImpM lore r op (TExp t)
forall a b. (a -> b) -> a -> b
$ TExp t
n TExp t -> TExp t -> TExp t
forall a. Num a => a -> a -> a
* TExp t
prod
      (TExp t, [TExp t]) -> ImpM lore r op (TExp t, [TExp t])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (TExp t
n', TExp t
n' TExp t -> [TExp t] -> [TExp t]
forall a. a -> [a] -> [a]
: [TExp t]
ns')

-- | @dIndexSpace f dims i@ computes a list of indices into an
-- array with dimension @dims@ given the flat index @i@.  The
-- resulting list will have the same size as @dims@.  Intermediate
-- results are passed to @f@.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM lore r op ()
dIndexSpace :: forall lore r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM lore r op ()
dIndexSpace [(VName, TExp Int64)]
vs_ds TExp Int64
j = do
  Shape (TExp Int64)
slices <- Shape (TExp Int64) -> ImpM lore r op (Shape (TExp Int64))
forall lore r op.
Shape (TExp Int64) -> ImpM lore r op (Shape (TExp Int64))
dSlices (((VName, TExp Int64) -> TExp Int64)
-> [(VName, TExp Int64)] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map (VName, TExp Int64) -> TExp Int64
forall a b. (a, b) -> b
snd [(VName, TExp Int64)]
vs_ds)
  [(VName, TExp Int64)] -> TExp Int64 -> ImpM lore r op ()
forall lore r op.
[(VName, TExp Int64)] -> TExp Int64 -> ImpM lore r op ()
loop ([VName] -> Shape (TExp Int64) -> [(VName, TExp Int64)]
forall a b. [a] -> [b] -> [(a, b)]
zip (((VName, TExp Int64) -> VName) -> [(VName, TExp Int64)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (VName, TExp Int64) -> VName
forall a b. (a, b) -> a
fst [(VName, TExp Int64)]
vs_ds) Shape (TExp Int64)
slices) TExp Int64
j
  where
    loop :: [(VName, TExp Int64)] -> TExp Int64 -> ImpM lore r op ()
loop ((VName
v, TExp Int64
size) : [(VName, TExp Int64)]
rest) TExp Int64
i = do
      VName -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ VName
v (TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int64
size)
      TExp Int64
i' <- [Char] -> TExp Int64 -> ImpM lore r op (TExp Int64)
forall t lore r op. [Char] -> TExp t -> ImpM lore r op (TExp t)
dPrimVE [Char]
"remnant" (TExp Int64 -> ImpM lore r op (TExp Int64))
-> TExp Int64 -> ImpM lore r op (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
i TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- VName -> TExp Int64
Imp.vi64 VName
v TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int64
size
      [(VName, TExp Int64)] -> TExp Int64 -> ImpM lore r op ()
loop [(VName, TExp Int64)]
rest TExp Int64
i'
    loop [(VName, TExp Int64)]
_ TExp Int64
_ = () -> ImpM lore r op ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()