{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg,

    -- * Pluggable Compiler
    OpCompiler,
    ExpCompiler,
    CopyCompiler,
    StmsCompiler,
    AllocCompiler,
    Operations (..),
    defaultOperations,
    MemLocation (..),
    MemEntry (..),
    ScalarEntry (..),

    -- * Monadic Compiler Interface
    ImpM,
    localDefaultSpace,
    askFunction,
    newVNameForFun,
    nameForFun,
    askEnv,
    localEnv,
    localOps,
    VTable,
    getVTable,
    localVTable,
    subImpM,
    subImpM_,
    emit,
    emitFunction,
    hasFunction,
    collect,
    collect',
    comment,
    VarEntry (..),
    ArrayEntry (..),

    -- * Lookups
    lookupVar,
    lookupArray,
    lookupMemory,
    lookupAcc,

    -- * Building Blocks
    TV,
    mkTV,
    tvSize,
    tvExp,
    tvVar,
    ToExp (..),
    compileAlloc,
    everythingVolatile,
    compileBody,
    compileBody',
    compileLoopBody,
    defCompileStms,
    compileStms,
    compileExp,
    defCompileExp,
    fullyIndexArray,
    fullyIndexArray',
    copy,
    copyDWIM,
    copyDWIMFix,
    copyElementWise,
    typeSize,
    inBounds,
    isMapTransposeCopy,

    -- * Constructing code.
    dLParams,
    dFParams,
    dScope,
    dArray,
    dPrim,
    dPrimVol,
    dPrim_,
    dPrimV_,
    dPrimV,
    dPrimVE,
    sFor,
    sWhile,
    sComment,
    sIf,
    sWhen,
    sUnless,
    sOp,
    sDeclareMem,
    sAlloc,
    sAlloc_,
    sArray,
    sArrayInMem,
    sAllocArray,
    sAllocArrayPerm,
    sStaticArray,
    sWrite,
    sUpdate,
    sLoopNest,
    (<--),
    (<~~),
    function,
    warn,
    module Language.Futhark.Warnings,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Data.String
import Futhark.CodeGen.ImpCode
  ( Bytes,
    Count,
    Elements,
    bytes,
    elements,
    withElemType,
  )
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings

-- | How to compile an t'Op'.  Called with the pattern the operation
-- is bound to and the operation itself.
type OpCompiler lore r op = Pattern lore -> Op lore -> ImpM lore r op ()

-- | How to compile some 'Stms'.  The 'Names' argument and the trailing
-- action are presumably the names used after the statements and an
-- action to run in their scope, respectively — confirm against
-- 'defCompileStms'.
type StmsCompiler lore r op = Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | How to compile an 'Exp'.  Called with the pattern the expression
-- is bound to and the expression itself.
type ExpCompiler lore r op = Pattern lore -> Exp lore -> ImpM lore r op ()

-- | How to compile a copy of elements of the given primitive type
-- between two sliced memory locations.
--
-- NOTE(review): the argument order is presumably destination
-- location and slice first, then source location and slice — confirm
-- against 'defaultCopy' and its callers.
type CopyCompiler lore r op =
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | An alternate way of compiling an allocation.  Called with a name
-- (presumably the memory block being allocated — verify against
-- callers) and the allocation size in bytes.
type AllocCompiler lore r op = VName -> Count Bytes (Imp.TExp Int64) -> ImpM lore r op ()

-- | A bundle of pluggable compilers that determine how the
-- constructs of a program are translated to imperative code.  An
-- 'Operations' value is passed to 'runImpM' or 'subImpM' to configure
-- the computation they run.
data Operations lore r op = Operations
  { -- | How to compile expressions.
    opsExpCompiler :: ExpCompiler lore r op,
    -- | How to compile t'Op's.
    opsOpCompiler :: OpCompiler lore r op,
    -- | How to compile sequences of statements.
    opsStmsCompiler :: StmsCompiler lore r op,
    -- | How to compile copies between memory locations.
    opsCopyCompiler :: CopyCompiler lore r op,
    -- | Alternate allocation compilers, keyed by the memory space
    -- they apply to.
    opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  }

-- | An operations set that uses the default compilers: expressions
-- are compiled with 'defCompileExp', statements with
-- 'defCompileStms', copies with 'defaultCopy', and no alternate
-- allocation compilers are installed.  Only the t'Op' compiler must
-- be supplied.
defaultOperations ::
  (Mem lore, FreeIn op) =>
  OpCompiler lore r op ->
  Operations lore r op
defaultOperations opc =
  Operations
    { opsExpCompiler = defCompileExp,
      opsOpCompiler = opc,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }

-- | When an array is declared, this is where it is stored.
data MemLocation = MemLocation
  { -- | Name of the memory block the array is stored in.
    memLocationName :: VName,
    -- | The shape of the array.
    memLocationShape :: [Imp.DimSize],
    -- | The index function describing the array's layout within the
    -- memory block.
    memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Eq, Show)

-- | Symbol-table information about an array variable: where it is
-- stored and the primitive type of its elements.
data ArrayEntry = ArrayEntry
  { entryArrayLocation :: MemLocation,
    entryArrayElemType :: PrimType
  }
  deriving (Show)

-- | The shape of an array, as recorded in its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocationShape . entryArrayLocation

-- | Symbol-table information about a memory block: the memory space
-- it resides in.
newtype MemEntry = MemEntry {entryMemSpace :: Imp.Space}
  deriving (Show)

-- | Symbol-table information about a scalar variable: its primitive
-- type.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
  }
  deriving (Show)

-- | The symbol-table entry for a variable, describing how it is
-- represented.
--
-- NOTE(review): the @Maybe (Exp lore)@ carried by every constructor
-- is presumably the expression the variable was bound to, when known
-- — confirm against the code that populates the 'VTable'.
data VarEntry lore
  = -- | An array stored in memory.
    ArrayVar (Maybe (Exp lore)) ArrayEntry
  | -- | A scalar of primitive type.
    ScalarVar (Maybe (Exp lore)) ScalarEntry
  | -- | A memory block.
    MemVar (Maybe (Exp lore)) MemEntry
  | -- | An accumulator, described by the components of its 'Acc' type.
    AccVar (Maybe (Exp lore)) (VName, Shape, [Type])
  deriving (Show)

-- | When compiling an expression, this is a description of where the
-- result should end up.  The integer is a reference to the construct
-- that gave rise to this destination (for patterns, this will be the
-- tag of the first name in the pattern).  This can be used to make
-- the generated code easier to relate to the original code.
data Destination = Destination
  { -- | Tag of the originating construct, if any.
    destinationTag :: Maybe Int,
    -- | One destination per result value.
    valueDestinations :: [ValueDestination]
  }
  deriving (Show)

-- | Where a single value of a 'Destination' should be written.
data ValueDestination
  = -- | A scalar variable.
    ScalarDestination VName
  | -- | A memory block.
    MemoryDestination VName
  | -- | The 'MemLocation' is 'Just' if a copy is
    -- required.  If it is 'Nothing', then a
    -- copy/assignment of a memory block somewhere
    -- takes care of this array.
    ArrayDestination (Maybe MemLocation)
  deriving (Show)

-- | The read-only environment of an 'ImpM' computation: the
-- pluggable compilers plus contextual settings.
data Env lore r op = Env
  { -- | How to compile expressions.
    envExpCompiler :: ExpCompiler lore r op,
    -- | How to compile sequences of statements.
    envStmsCompiler :: StmsCompiler lore r op,
    -- | How to compile t'Op's.
    envOpCompiler :: OpCompiler lore r op,
    -- | How to compile copies between memory locations.
    envCopyCompiler :: CopyCompiler lore r op,
    -- | Alternate allocation compilers, keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler lore r op),
    -- | The default memory space (presumably used when an allocation
    -- does not name one explicitly).
    envDefaultSpace :: Imp.Space,
    -- | The volatility currently in effect (presumably applied to
    -- newly declared variables — confirm against the declaration
    -- helpers).
    envVolatility :: Imp.Volatility,
    -- | User-extensible environment.
    envEnv :: r,
    -- | Name of the function we are compiling, if any.
    envFunction :: Maybe Name,
    -- | The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
    envAttrs :: Attrs
  }

-- | Construct the initial environment for a top-level 'ImpM' run.
--
-- All compilers, including the allocation compilers, are taken from
-- @ops@, so allocation overrides behave the same here as when they
-- are passed to 'subImpM'.  The default memory space is @ds@, the
-- volatility starts as 'Imp.Nonvolatile', and there is neither an
-- enclosing function nor any active attribute.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Install the caller's allocation compilers instead of
      -- discarding them ('subImpM' does the same).
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }

-- | The symbol table used during compilation: maps each variable name
-- to the 'VarEntry' describing its representation.
type VTable lore = M.Map VName (VarEntry lore)

-- | The mutable state of an 'ImpM' computation.
data ImpState lore r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable lore,
    -- | Functions emitted so far.
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far in the current scope (see 'emit' and
    -- 'collect'').
    stateCode :: Imp.Code op,
    -- | Warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Maps the arrays backing each accumulator to their
    -- update function and neutral elements.  This works
    -- because an array name can only become part of a single
    -- accumulator throughout its lifetime.  If the arrays
    -- backing an accumulator are not in this mapping, the
    -- accumulator is scatter-like.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda lore, [SubExp])),
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }

-- | An initial state with an empty symbol table, no functions, no
-- code, no warnings and no accumulators, drawing fresh names from the
-- given name source.
newState :: VNameSource -> ImpState lore r op
newState src =
  -- Record syntax rather than positional arguments, so that reordering
  -- the fields of 'ImpState' cannot silently misassign them.
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }

-- | The code generation monad.  A read-only 'Env' (pluggable
-- compilers and contextual settings) is layered over a mutable
-- 'ImpState' (symbol table, emitted code and functions, warnings,
-- accumulators and name source).  All instances listed here are
-- obtained by newtype deriving from the underlying transformer stack.
newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState lore r op),
      MonadReader (Env lore r op)
    )

-- | Fresh names are drawn from, and written back to, the
-- 'stateNameSource' field of the state.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}

-- Cannot be a KernelsMem scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).
--
-- The scope is derived from the symbol table by keeping only the type
-- of each entry, bound as a 'LetName'.
instance HasScope SOACS (ImpM lore r op) where
  askScope = gets $ M.map (LetName . entryType) . stateVTable
    where
      entryType :: VarEntry lore' -> Type
      entryType (MemVar _ memEntry) =
        Mem (entryMemSpace memEntry)
      entryType (ArrayVar _ arrayEntry) =
        Array
          (entryArrayElemType arrayEntry)
          (Shape $ entryArrayShape arrayEntry)
          NoUniqueness
      entryType (ScalarVar _ scalarEntry) =
        Prim $ entryScalarType scalarEntry
      entryType (AccVar _ (acc, ispace, ts)) =
        Acc acc ispace ts NoUniqueness

-- | Run an 'ImpM' computation from the given initial state.  The
-- environment is built by 'newEnv' from the user environment @r@, the
-- operations @ops@ and the default memory space @space@.  Returns the
-- result of the computation together with the final state.
runImpM ::
  ImpM lore r op a ->
  r ->
  Operations lore r op ->
  Imp.Space ->
  ImpState lore r op ->
  (a, ImpState lore r op)
runImpM (ImpM m) r ops space = runState (runReaderT m $ newEnv r ops space)

-- | Like 'subImpM', but discards the result of the nested action and
-- returns only the code it emitted.
subImpM_ ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run a nested code generation action with a different user
-- environment and set of operations (and possibly a different op
-- type).  Returns the action's result and the code it emitted; that
-- code is /not/ added to the enclosing code.
--
-- The nested action inherits the enclosing symbol table, name source
-- and accumulator map, as well as the default space, volatility,
-- function name and attributes of the enclosing environment.
-- Afterwards, only the name source and the warnings are propagated
-- back.  Changes the nested action makes to the symbol table or the
-- accumulator map, and any functions it emits, are discarded.
subImpM ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get

  let -- Every field mentioning @r@ or @op@ is replaced, which is what
      -- allows this record update to change the environment's type.
      env' =
        env
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      s' =
        ImpState
          { stateVTable = stateVTable s,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = stateNameSource s,
            stateWarnings = mempty,
            stateAccs = stateAccs s
          }
      (x, s'') = runState (runReaderT m env') s'

  putNameSource $ stateNameSource s''
  warnings $ stateWarnings s''
  return (x, stateCode s'')

-- | Execute a code generation action, returning the code that was
-- emitted.  That code is not added to the enclosing code; see
-- 'collect''.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect = fmap snd . collect'

-- | Execute a code generation action starting from an empty code
-- buffer, returning its result together with the code it emitted.
-- The previously emitted code is restored afterwards, so the action's
-- code is not appended to it; the caller decides where it goes.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' m = do
  prev_code <- gets stateCode
  modify $ \s -> s {stateCode = mempty}
  x <- m
  new_code <- gets stateCode
  modify $ \s -> s {stateCode = prev_code}
  return (x, new_code)

-- | Execute a code generation action, wrapping the generated code
-- within an 'Imp.Comment' with the given description, and emit the
-- result.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = do
  code <- collect m
  emit $ Imp.Comment desc code

-- | Emit some generated imperative code, appending it after the code
-- emitted so far.
emit :: Imp.Code op -> ImpM lore r op ()
emit code = modify $ \s -> s {stateCode = stateCode s <> code}

-- | Record a batch of warnings.  The new warnings are combined in front
-- of those accumulated so far.
warnings :: Warnings -> ImpM lore r op ()
warnings ws = modify $ \s -> s {stateWarnings = ws <> stateWarnings s}

-- | Emit a warning about something the user should be aware of.  The
-- warning is attached to the location of @loc@, together with the
-- locations of @locs@, and carries @problem@ as its message.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn loc locs problem =
  warnings $
    singleWarning' (srclocOf loc) (map srclocOf locs) (fromString problem)

-- | Emit a function in the generated code.  The function is prepended
-- to the functions emitted so far; no check for an existing function
-- of the same name is made here (use 'hasFunction' for that).
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun = do
  Imp.Functions fs <- gets stateFunctions
  modify $ \s -> s {stateFunctions = Imp.Functions $ (fname, fun) : fs}

-- | Check if a function of a given name has been emitted into the
-- current state.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
   in isJust $ lookup fname fs

-- | Build a symbol table for the top-level constants.  Every name bound
-- by the pattern of a statement in the given 'Stms' is mapped to the
-- 'VarEntry' that 'memBoundToVarEntry' produces from the binding
-- expression and the name's memory annotation.  The per-statement maps
-- are combined with 'Map''s left-biased 'Monoid' instance.
constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable = foldMap stmVtable
  where
    stmVtable (Let pat _ e) =
      foldMap (peVtable e) $ patternElements pat
    peVtable e (PatElem name dec) =
      M.singleton name $ memBoundToVarEntry (Just e) dec

-- | Compile a program to imperative code.
--
-- Every function is compiled independently, in parallel via
-- 'parMap' 'rpar'.  Each compilation starts from the same name source
-- and from a symbol table populated with the program's top-level
-- constants.  The resulting states are then merged, and the constants
-- are compiled in the merged state.  Only those constant declarations
-- that are free in the compiled functions become parameters of the
-- resulting 'Imp.Constants'.
compileProg ::
  (Mem lore, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations lore r op ->
  Imp.Space ->
  Prog lore ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        (consts', s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates ss
     in ( ( stateWarnings s',
            Imp.Definitions consts' (stateFunctions s')
          ),
          stateNameSource s'
        )
  where
    -- Compile a single function from a fresh state that uses the given
    -- name source and knows about the top-level constants.
    compileFunDef' src fdef =
      runImpM
        (compileFunDef fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the states of the separately compiled functions.  Name
    -- sources are combined with their 'Monoid' instance.  Functions are
    -- deduplicated by name; for duplicates, 'M.fromList' keeps the last
    -- occurrence.  All warnings are kept.
    combineStates ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src = mconcat (map stateNameSource ss)
       in (newState src)
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }

-- | Compile the top-level constant statements into 'Imp.Constants'.
--
-- Top-level memory and scalar declarations whose names are in the
-- given 'Names' are lifted out of the initialisation code and returned
-- as the constants' parameters.  All other code stays in the
-- initialisation code.
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  pure $ uncurry Imp.Constants $ first DL.toList $ extract code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions.
    extract (x Imp.:>>: y) =
      extract x <> extract y
    extract (Imp.DeclareMem name space)
      | name `nameIn` used_consts =
        ( DL.singleton $ Imp.MemParam name space,
          mempty
        )
    extract (Imp.DeclareScalar name _ t)
      | name `nameIn` used_consts =
        ( DL.singleton $ Imp.ScalarParam name t,
          mempty
        )
    extract s =
      (mempty, s)

-- | Translate a function parameter to its imperative form.
--
-- * Scalar and memory parameters become an 'Imp.Param' ('Left').
-- * Array parameters become an 'ArrayDecl' ('Right'), which records
--   the memory block, shape and index function of the array.
-- * Accumulator parameters are rejected with 'error'.
compileInParam ::
  Mem lore =>
  FParam lore ->
  ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam = case paramDec fparam of
  MemPrim bt ->
    return $ Left $ Imp.ScalarParam name bt
  MemMem space ->
    return $ Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    return $
      Right $
        ArrayDecl name bt $
          MemLocation mem (shapeDims shape) $ fmap (fmap Imp.ScalarVar) ixfun
  MemAcc {} ->
    error "Functions may not have accumulator parameters."
  where
    name = paramName fparam

-- | An array-typed function parameter, as produced by 'compileInParam':
-- the parameter name, its element type, and the memory location
-- holding the array.
data ArrayDecl = ArrayDecl VName PrimType MemLocation

-- | Compile the parameters of an entry point.
--
-- Returns three things:
--
-- * the imperative scalar/memory parameters;
-- * the array declarations;
-- * the external (user-visible) values described by the entry point
--   types.
--
-- The trailing parameters, as many as the entry point types account
-- for (via 'entryPointSize'), are the value parameters.  The external
-- values are built from those value parameters only.
compileInParams ::
  Mem lore =>
  [FParam lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let (_ctx_params, val_params) =
        splitAt (length params - sum (map entryPointSize orig_epts)) params
  -- Context and value parameters together are exactly 'params'.
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds

      -- Memory space of every memory-block parameter, by name.
      summaries = M.fromList $ mapMaybe memSummary params
        where
          memSummary param
            | MemMem space <- paramDec param =
              Just (paramName param, space)
            | otherwise =
              Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      -- Describe a value parameter.  Arrays need both their declaration
      -- and the space of their memory block; otherwise 'Nothing'.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt) ->
            Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- An opaque entry point type consumes @n@ parameters; the other
      -- entry point types consume one parameter each.
      mkExts (TypeOpaque desc n : epts) fparams =
        let (fparams', rest) = splitAt n fparams
         in Imp.OpaqueValue
              desc
              (mapMaybe (`mkValueDesc` Imp.TypeDirect) fparams') :
            mkExts epts rest
      mkExts (TypeUnsigned : epts) (fparam : fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeUnsigned)
          ++ mkExts epts fparams
      mkExts (TypeDirect : epts) (fparam : fparams) =
        maybeToList (Imp.TransparentValue <$> mkValueDesc fparam Imp.TypeDirect)
          ++ mkExts epts fparams
      mkExts _ _ = []

  return (inparams, arrayds, mkExts orig_epts val_params)
  where
    isArrayDecl x (ArrayDecl y _ _) = x == y

compileOutParams ::
  Mem lore =>
  [RetType lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: forall lore r op.
Mem lore =>
[RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
orig_rts [EntryPointType]
orig_epts = do
  (([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams, Map Int ValueDestination
ctx_dests)) <-
    WriterT
  ([Param], Map Int ValueDestination)
  (ImpM lore r op)
  ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param], Map Int ValueDestination)
   (ImpM lore r op)
   ([ExternalValue], [ValueDestination])
 -> ImpM
      lore
      r
      op
      (([ExternalValue], [ValueDestination]),
       ([Param], Map Int ValueDestination)))
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  ([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType lore]
[RetTypeMem]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
  let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
  ([ExternalValue], [Param], Destination)
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
  where
    imp :: ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      a)
-> (ImpM lore r op a
    -> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a)
-> ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift

    mkExts :: [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts (TypeOpaque [Char]
desc Int
n : [EntryPointType]
epts) [RetTypeMem]
rts = do
      let ([RetTypeMem]
rts', [RetTypeMem]
rest) = Int -> [RetTypeMem] -> ([RetTypeMem], [RetTypeMem])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [RetTypeMem]
rts
      ([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
 -> ([ValueDesc], [ValueDestination]))
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (RetTypeMem
 -> Signedness
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      (ValueDesc, ValueDestination))
-> [RetTypeMem]
-> [Signedness]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam [RetTypeMem]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rest
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( [Char] -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue [Char]
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          [ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests
        )
    mkExts (EntryPointType
TypeUnsigned : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeUnsigned
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts (EntryPointType
TypeDirect : [EntryPointType]
epts) (RetTypeMem
rt : [RetTypeMem]
rts) = do
      (ValueDesc
ev, ValueDestination
dest) <- RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam RetTypeMem
rt Signedness
Imp.TypeDirect
      ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [RetTypeMem]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [RetTypeMem]
rts
      ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
          ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
        )
    mkExts [EntryPointType]
_ [RetTypeMem]
_ = ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])

    mkParam :: RetTypeMem
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam MemMem {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not explicitly return memory blocks."
    mkParam MemAcc {} Signedness
_ =
      [Char]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => [Char] -> a
error [Char]
"Functions may not return accumulators."
    mkParam (MemPrim PrimType
t) Signedness
ept = do
      VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"scalar_out"
      ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
    mkParam (MemArray PrimType
t ShapeBase (Ext SubExp)
shape Uniqueness
_ MemReturn
dec) Signedness
ept = do
      Space
space <- (Env lore r op -> Space)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Space
forall lore r op. Env lore r op -> Space
envDefaultSpace
      VName
memout <- case MemReturn
dec of
        ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
          VName
memout <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_mem"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout
            )
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
        ReturnsInBlock VName
memout ExtIxFun
_ ->
          VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
      [SubExp]
resultshape <- (Ext SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
inspectExtSize ([Ext SubExp]
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      [SubExp])
-> [Ext SubExp]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext SubExp) -> [Ext SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext SubExp)
shape
      (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( VName -> Space -> PrimType -> Signedness -> [SubExp] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [SubExp]
resultshape,
          Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
        )

    inspectExtSize :: Ext SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
inspectExtSize (Ext Int
x) = do
      (Map Any Any
memseen, Map Int VName
arrseen) <- StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  (Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
      case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
        Maybe VName
Nothing -> do
          VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall {a}.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ [Char] -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"out_arrsize"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
            ( [VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int64],
              Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out
            )
          (Map Any Any, Map Int VName)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
        Just VName
out ->
          SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      SubExp)
-> SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
out
    inspectExtSize (Free SubExp
se) =
      SubExp
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     SubExp
forall (m :: * -> *) a. Monad m => a -> m a
return SubExp
se

-- | Compile a function definition and register the result with
-- 'emitFunction'.
--
-- The body is compiled with 'envFunction' set to the function's name.
-- Parameter and return entry-point types come from the 'EntryPoint'
-- annotation when there is one. Otherwise every parameter and every
-- result is treated as 'TypeDirect'. The 'Bool' passed to
-- 'Imp.Function' is @isJust entry@, i.e. whether such an annotation is
-- present.
compileFunDef ::
  Mem lore =>
  FunDef lore ->
  ImpM lore r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local (\env -> env {envFunction = Just fname}) $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    emitFunction fname $
      Imp.Function (isJust entry) outparams inparams body' results args
  where
    params_entry = maybe (replicate (length params) TypeDirect) fst entry
    ret_entry = maybe (replicate (length rettype) TypeDirect) snd entry

    -- Declare the input and output parameters, compile the body
    -- statements, then copy each body result into its output
    -- destination.  Returns the Imp-level parameters and external values
    -- that make up the function signature.
    compile = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <-
        compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds

      let Body _ stms ses = body
      -- The body results are the only names that must stay alive after
      -- the statements have been compiled.
      compileStms (freeIn ses) stms $
        forM_ (zip dests ses) $ \(d, se) -> copyDWIMDest d [] se []

      return (outparams, inparams, results, args)

-- | Compile a body whose results are bound by the given pattern.
--
-- The body's statements are compiled first. Then each result is copied
-- into the destination that 'destinationFromPattern' derives from the
-- corresponding pattern element.
compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ bnds ses) = do
  Destination _ dests <- destinationFromPattern pat
  -- Only the body results need to remain alive past the statements.
  compileStms (freeIn ses) bnds $
    forM_ (zip dests ses) $ \(d, se) -> copyDWIMDest d [] se []

-- | Compile a body and copy its results, positionally, into the
-- variables named by the given parameters.
--
-- Results are written with 'copyDWIM' to each parameter's 'paramName'.
-- If the two lists differ in length, the extra elements are ignored
-- (the lists are combined with 'zip').
compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ bnds ses) =
  compileStms (freeIn ses) bnds $
    forM_ (zip params ses) $ \(param, se) ->
      copyDWIM (paramName param) [] se []

-- | Compile a loop body and write its results back to the loop's merge
-- parameters.
--
-- Only two kinds of merge parameter are updated here:
--
-- * parameters of primitive type, and
-- * memory-block parameters whose result is a variable.
--
-- For any other parameter no code is generated by this function.
compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ bnds ses) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is cheap: each copy is
  -- either a scalar assignment ('Imp.SetScalar') or a memory-block
  -- assignment ('Imp.SetMem').  No array contents are copied.
  tmpnames <- mapM (newVName . (++ "_tmp") . baseString . paramName) mergeparams
  compileStms (freeIn ses) bnds $ do
    -- Phase 1 is emitted immediately: declare each temporary and store
    -- the result in it.  Phase 2 is returned as a deferred action: it
    -- assigns the temporary to its merge parameter.
    copy_to_merge_params <- forM (zip3 mergeparams tmpnames ses) $ \(p, tmp, se) ->
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          return $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space | Var v <- se -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          return $ emit $ Imp.SetMem (paramName p) tmp space
        -- NOTE(review): other types (e.g. arrays) get no write-back here.
        -- Presumably they are kept up to date through their memory-block
        -- merge parameter; confirm.
        _ -> return $ return ()
    -- Run the phase-2 assignments only after every temporary has been
    -- written, so that no result is read after being overwritten.
    sequence_ copy_to_merge_params

-- | Compile statements using whichever 'StmsCompiler' is installed in the
-- environment ('envStmsCompiler').
--
-- This passes three things to that compiler unchanged:
--
-- * the names that must remain alive after the statements,
-- * the statements themselves, and
-- * the action @m@.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m = do
  cb <- asks envStmsCompiler
  cb alive_after_stms all_stms m

-- | The default statement compiler.
--
-- Statements are compiled front to back.  Every memory block bound by a
-- statement's pattern is remembered.  A later statement frees such a
-- block with an 'Imp.Free' emitted right after that statement's code
-- when both of the following hold:
--
-- * the block is free in that statement's code, and
-- * the block is not used by anything emitted after it, including
--   the trailing action and @alive_after_stms@.
--
-- This is very conservative, but can cut down on lifetimes in some
-- cases.  For example, a block that is never mentioned after its
-- defining statement is not freed here.  The trailing action @m@ is
-- compiled after all the statements.
defCompileStms ::
  (Mem lore, FreeIn op) =>
  Names ->
  Stms lore ->
  ImpM lore r op () ->
  ImpM lore r op ()
defCompileStms alive_after_stms all_stms m =
  void $ go mempty $ stmsToList all_stms
  where
    -- Compile the remaining statements, given the memory blocks bound
    -- by earlier statements.  The result is the set of names used by
    -- all code emitted from this point onwards.
    go _ [] = do
      tail_code <- collect m
      emit tail_code
      return $ freeIn tail_code <> alive_after_stms
    go known_mem (Let pat aux e : rest) = do
      dVars (Just e) (patternElements pat)

      -- This statement is compiled before the remaining ones.  Only its
      -- emission is deferred until we know which blocks die here.
      stm_code <- localAttrs (stmAuxAttrs aux) $ collect $ compileExp pat e
      (used_later, rest_code) <-
        collect' $ go (memBoundBy pat <> known_mem) rest

      let lastUsedHere mem =
            mem `nameIn` freeIn stm_code && not (mem `nameIn` used_later)
          freeable = S.filter (lastUsedHere . fst) known_mem

      emit stm_code
      mapM_ (\(mem, space) -> emit $ Imp.Free mem space) freeable
      emit rest_code

      return $ freeIn stm_code <> used_later

    -- Memory blocks (with their spaces) bound directly by a pattern.
    memBoundBy = S.fromList . mapMaybe memElem . patternElements
    memElem pe
      | Mem space <- patElemType pe = Just (patElemName pe, space)
      | otherwise = Nothing

-- | Compile an expression whose result is bound to the given pattern,
-- delegating to the expression compiler stored in the environment
-- ('envExpCompiler').
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e =
  asks envExpCompiler >>= \expCompiler -> expCompiler pat e

-- | The default expression compiler.  Handles each 'Exp' constructor:
--
-- * 'If': an 'sIf' over the two compiled branch bodies.
-- * 'Apply': an 'Imp.Call'.  Arguments of primitive type are passed as
--   expressions, memory variables are passed by name, and all other
--   arguments are left out of the call.
-- * 'BasicOp': delegated to 'defCompileBasicOp'.
-- * 'DoLoop': the steps, in order, are:
--   1. Warn if an @unroll@ attribute is in effect.
--   2. Declare the merge parameters.
--   3. Copy initial values into the rank-0 merge parameters.
--   4. Emit a for- or while-loop.
--   5. Copy the merge parameters into the pattern's destinations.
-- * 'WithAcc': record each lambda parameter in 'stateAccs', compile the
--   lambda body, and copy its non-accumulator results into the trailing
--   pattern names.
-- * 'Op': delegated to the operation compiler in the environment
--   ('envOpCompiler').
defCompileExp ::
  (Mem lore) =>
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (branch tbranch) (branch fbranch)
  where
    branch = compileBody pat
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  call_args <- fmap catMaybes $ mapM argToImp args
  emit $ Imp.Call targets fname call_args
  where
    -- Primitive values are passed as expressions and memory variables by
    -- name.  Anything else yields no Imp-level argument.
    argToImp (se, _) = do
      se_t <- subExpType se
      case se_t of
        Prim pt -> return $ Just $ Imp.ExpArg $ toExp' pt se
        Mem {} | Var mem <- se -> return $ Just $ Imp.MemArg mem
        _ -> return Nothing
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op
defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.

  -- Declare every merge parameter, then copy the initial value into
  -- those of rank 0.
  dFParams loop_params
  forM_ loop_inits $ \(param, initial) ->
    when (arrayRank (paramType param) == 0) $
      copyDWIM (paramName param) [] initial []

  let loopBody = compileLoopBody loop_params body

  case form of
    ForLoop loop_i _ bound loopvars -> do
      -- Only loop parameters of primitive type get an element copy here.
      let bindLoopVar (lparam, arr) = case paramType lparam of
            Prim _ ->
              copyDWIM (paramName lparam) [] (Var arr) [DimFix $ Imp.vi64 loop_i]
            _ -> return ()
      bound' <- toExp bound
      dLParams $ map fst loopvars
      sFor' loop_i bound' $ do
        mapM_ bindLoopVar loopvars
        loopBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) loopBody

  -- The pattern's value destinations receive the merge parameters.
  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests [Var (paramName param) | param <- loop_params]) $
    \(d, r) -> copyDWIMDest d [] r []
  where
    loop_inits = ctx ++ val
    loop_params = map fst loop_inits
defCompileExp pat (WithAcc inputs lam) = do
  let lam_params = lambdaParams lam
      lam_body = lambdaBody lam
  dLParams lam_params
  -- Associate each lambda parameter with the arrays and the optional
  -- operator of the corresponding input.
  forM_ (zip inputs lam_params) $ \((_, acc_arrs, acc_op), acc_p) ->
    modify $ \st ->
      st {stateAccs = M.insert (paramName acc_p) (acc_arrs, acc_op) (stateAccs st)}
  compileStms mempty (bodyStms lam_body) $ do
    -- The first (length inputs) results are skipped; the remaining ones
    -- are copied into the last pattern names.
    let plain_res = drop (length inputs) (bodyResult lam_body)
        plain_dests = takeLast (length plain_res) (patternNames pat)
    forM_ (zip plain_dests plain_res) $ \(dest, res) ->
      copyDWIM dest [] res []
defCompileExp pat (Op op) =
  asks envOpCompiler >>= \opCompiler -> opCompiler pat op

defCompileBasicOp ::
  Mem lore =>
  Pattern lore ->
  BasicOp ->
  ImpM lore r op ()
defCompileBasicOp :: forall lore r op.
Mem lore =>
Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (SubExp SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Opaque SubExp
se) =
  VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] SubExp
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (UnOp UnOp
op SubExp
e) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
  Exp
x' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
x
  Exp
y' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
y
  PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp PatternT (LetDec lore)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
  Exp
e' <- SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp SubExp
e
  ErrorMsg Exp
msg' <- (SubExp -> ImpM lore r op Exp)
-> ErrorMsg SubExp -> ImpM lore r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse SubExp -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ErrorMsg SubExp
msg
  Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc

  Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
  Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
    (SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ())
-> (SrcLoc, [SrcLoc]) -> [Char] -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> [Char] -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> [Char] -> ImpM lore r op ()
warn (SrcLoc, [SrcLoc])
loc [Char]
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Index VName
src Slice SubExp
slice)
  | Just [SubExp]
idxs <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
    VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM lore r op ())
-> [DimIndex (TExp Int64)] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex (TExp Int64))
-> [SubExp] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (SubExp -> TExp Int64) -> SubExp -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [SubExp]
idxs
defCompileBasicOp PatternT (LetDec lore)
_ Index {} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Update VName
_ Slice SubExp
slice SubExp
se) =
  VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
forall lore r op.
VName -> [DimIndex (TExp Int64)] -> SubExp -> ImpM lore r op ()
sUpdate (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT LParamMem
pe) ((DimIndex SubExp -> DimIndex (TExp Int64))
-> Slice SubExp -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((SubExp -> TExp Int64) -> DimIndex SubExp -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice SubExp
slice) SubExp
se
-- Replicate: build a loop nest with one 'Imp.For' per dimension of the
-- shape, with the first dimension outermost.  The innermost body copies
-- @se@ into the destination at the loop indices.
defCompileBasicOp (Pattern _ [pe]) (Replicate (Shape ds) se) = do
  ds' <- mapM toExp ds
  is <- replicateM (length ds) (newVName "i")
  copy_elem <- collect $ copyDWIM (patElemName pe) (map (DimFix . Imp.vi64) is) se []
  -- 'foldr ($)' applies the first loop constructor outermost:
  -- For i0 d0 (For i1 d1 (... copy_elem)).  A right fold builds this
  -- nesting directly and avoids a lazy left fold.
  emit $ foldr ($) copy_elem $ zipWith Imp.For is ds'
-- Scratch: this clause emits no code; the destination is left as-is.
-- NOTE(review): presumably its memory was already allocated when the
-- pattern was declared -- confirm.
defCompileBasicOp _ Scratch {} = pure ()
-- Iota: for every i in [0, n), compute @e + i * s@ in the integer type
-- @it@ (with 'OverflowUndef' arithmetic), store it in a scalar temporary
-- named "x", and copy that scalar to position i of the destination.
defCompileBasicOp (Pattern [] [pe]) (Iota n e s it) = do
  e_start <- toExp e
  e_step <- toExp s
  sFor "i" (toInt64Exp n) $ \i -> do
    let i_it = sExt it (untyped i)
        i_scaled = BinOpExp (Mul it OverflowUndef) i_it e_step
        elem_val = BinOpExp (Add it OverflowUndef) e_start i_scaled
    x <- dPrimV "x" $ TPrimExp elem_val
    copyDWIM (patElemName pe) [DimFix i] (Var (tvVar x)) []
-- Copy: copy the whole source array into the destination.
defCompileBasicOp (Pattern _ [pe]) (Copy src) =
  let dest = patElemName pe
   in copyDWIM dest [] (Var src) []
-- Manifest: copy the whole source array into the destination; the
-- permutation argument is not consulted here.  NOTE(review): presumably
-- the layout is already encoded in the destination's index function.
defCompileBasicOp (Pattern _ [pe]) (Manifest _ src) =
  let dest = patElemName pe
   in copyDWIM dest [] (Var src) []
-- Concat: copy each operand of @x : ys@ into the destination along
-- dimension @i@, at a running offset kept in the scalar "tmp_offs".
-- Every dimension before @i@ is copied in full; the offset grows by the
-- operand's extent in dimension @i@ after each copy.
defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  concat_offs <- dPrimV "tmp_offs" 0
  forM_ (x : ys) $ \y -> do
    y_shape <- arrayDims <$> lookupType y
    let (outer_dims, concat_dims) = splitAt i y_shape
        rows = case concat_dims of
          r : _ -> toInt64Exp r
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ pretty y
        full_slices = map (\d -> DimSlice 0 (toInt64Exp d) 1) outer_dims
        destslice = full_slices ++ [DimSlice (tvExp concat_offs) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    concat_offs <-- tvExp concat_offs + rows
-- ArrayLit.  When every element is a constant, the values are declared
-- once as a static array (named via 'newVNameForFun', in the
-- destination's memory space), registered as a memory variable, and
-- copied into the destination with a single 'copy'.  Otherwise each
-- element is written individually with 'copyDWIM'.
defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM isLiteral es = do
    dest_mem <- entryArrayLocation <$> lookupArray (patElemName pe)
    dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
    let t = primValueType v
        -- The element count is computed once and used for the static
        -- array's shape, its index function, and the copied slice.
        -- Using 'fromIntegral' produces a literal directly, instead of
        -- the @1 + (1 + ...)@ expression chain 'genericLength' builds at
        -- 'Imp.TExp Int64'.
        num_elems = length es
        num_elems_e = fromIntegral num_elems :: Imp.TExp Int64
    static_array <- newVNameForFun "static_array"
    emit $ Imp.DeclareArray static_array dest_space t $ Imp.ArrayValues vs
    let static_src =
          MemLocation static_array [intConst Int64 $ fromIntegral num_elems] $
            IxFun.iota [num_elems_e]
        entry = MemVar Nothing $ MemEntry dest_space
    addVar static_array entry
    let slice = [DimSlice 0 num_elems_e 1]
    copy t dest_mem slice static_src slice
  | otherwise =
    forM_ (zip [0 ..] es) $ \(i, e) ->
      copyDWIM (patElemName pe) [DimFix $ fromInteger i] e []
  where
    -- Only 'Constant' subexpressions qualify for the static-array path.
    isLiteral (Constant v) = Just v
    isLiteral _ = Nothing
-- Rearrange: this clause emits no code.  NOTE(review): presumably the
-- permutation lives in the result's index function -- confirm.
defCompileBasicOp _ Rearrange {} = pure ()
-- Rotate: this clause emits no code.  NOTE(review): presumably the
-- rotation lives in the result's index function -- confirm.
defCompileBasicOp _ Rotate {} = pure ()
-- Reshape: this clause emits no code.  NOTE(review): presumably the new
-- shape lives in the result's index function -- confirm.
defCompileBasicOp _ Reshape {} = pure ()
-- UpdateAcc: update accumulator @acc@ at index @is@ with values @vs@,
-- but only when the index is in bounds.  If the accumulator has no
-- operator, the values are written directly (scatter-like).  Otherwise
-- the operator lambda combines the current array contents with @vs@ and
-- the results are written back.
defCompileBasicOp _ (UpdateAcc acc is vs) = sComment "UpdateAcc" $ do
  -- The comment wrapper is deliberate: it makes the generated code wrap
  -- the operator in braces.  Without it, lambda parameters (if any)
  -- could be declared more than once, because they are duplicated every
  -- time an UpdateAcc targets the same accumulator.
  let idxs = map toInt64Exp is

  -- 'lookupAcc' also binds the index parameters.  Its operator component
  -- tells us whether this is scatter-like or a generalised reduction.
  (_, _, arrs, dims, maybe_lam) <- lookupAcc acc idxs

  let in_bounds = inBounds (map DimFix idxs) dims
  sWhen in_bounds $ case maybe_lam of
    Just lam -> do
      -- Generalised reduction: x-parameters get the current array
      -- elements, y-parameters get the new values.
      let lam_params = lambdaParams lam
          lam_body = lambdaBody lam
          (x_params, y_params) = splitAt (length vs) (map paramName lam_params)
      dLParams lam_params
      forM_ (zip x_params arrs) $ \(xp, arr) ->
        copyDWIMFix xp [] (Var arr) idxs
      forM_ (zip y_params vs) $ \(yp, v) ->
        copyDWIM yp [] v []
      compileStms mempty (bodyStms lam_body) $
        forM_ (zip arrs (bodyResult lam_body)) $ \(arr, res) ->
          copyDWIMFix arr idxs res []
    Nothing ->
      -- Scatter-like: overwrite the element at the index.
      forM_ (zip arrs vs) $ \(arr, v) ->
        copyDWIMFix arr idxs v []
-- Catch-all: any pattern/operation combination not matched above is an
-- internal error.  The message shows both the pattern and the expression.
defCompileBasicOp pat e =
  error $
    concat
      [ "ImpGen.defCompileBasicOp: Invalid pattern\n  ",
        pretty pat,
        "\nfor expression\n  ",
        pretty e
      ]

-- | Record each array declaration in the variable table with 'addVar',
-- as an 'ArrayVar' with no defining expression.  No code is emitted here.
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name . ArrayVar Nothing $
      ArrayEntry
        { entryArrayLocation = location,
          entryArrayElemType = bt
        }

-- | Like 'dFParams', but does not create new declarations.  Each
-- parameter is recorded with 'addVar' only, with its uniqueness
-- information stripped via 'noUniquenessReturns'.
-- Note: a hack to be used only for functions.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams fparams =
  forM_ fparams $ \fparam ->
    let entry = memBoundToVarEntry Nothing . noUniquenessReturns $ paramDec fparam
     in addVar (paramName fparam) entry

-- | Another hack.  Record @i@ in the variable table as a scalar of
-- integer type @it@.  Only 'addVar' is called; no declaration is
-- emitted.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it =
  let entry = ScalarEntry (IntType it)
   in addVar i (ScalarVar Nothing entry)

-- | Declare every pattern element with 'dScope', passing along the
-- optional expression that binds them.
dVars ::
  Mem lore =>
  Maybe (Exp lore) ->
  [PatElem lore] ->
  ImpM lore r op ()
dVars e pes = forM_ pes $ dScope e . scopeOfPatElem

-- | Declare function parameters by declaring their scope with 'dScope'
-- (with no binding expression).
dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams params = dScope Nothing (scopeOfFParams params)

-- | Declare lambda parameters by declaring their scope with 'dScope'
-- (with no binding expression).
dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams params = dScope Nothing (scopeOfLParams params)

-- | Declare a fresh /volatile/ scalar of type @t@ whose name is derived
-- from @name@, register it in the variable table, assign @e@ to it, and
-- return it as a typed variable.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimVol name t e = do
  var <- newVName name
  let entry = ScalarVar Nothing (ScalarEntry t)
  emit $ Imp.DeclareScalar var Imp.Volatile t
  addVar var entry
  var <~~ untyped e
  pure $ TV var t

-- | Declare @name@ (used as-is; no fresh name is generated here) as a
-- non-volatile scalar of type @t@, both in the emitted code and in the
-- variable table.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t = do
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
  addVar name (ScalarVar Nothing (ScalarEntry t))

-- | Declare a fresh non-volatile scalar of the given type, with a name
-- derived from the given string.
--
-- The return type is polymorphic, so there is no guarantee it
-- actually matches the 'PrimType', but at least we have to use it
-- consistently.
dPrim :: String -> PrimType -> ImpM lore r op (TV t)
dPrim name t = do
  fresh <- newVName name
  dPrim_ fresh t
  pure (TV fresh t)

-- | Declare @name@ as a scalar whose type is taken from @e@, then assign
-- @e@ to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM lore r op ()
dPrimV_ name e = dPrim_ name pt >> (TV name pt <-- e)
  where
    pt = primExpType (untyped e)

-- | Declare a fresh scalar whose type is taken from @e@, initialise it
-- with @e@, and return it as a typed variable.
dPrimV :: String -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimV name e = do
  var <- dPrim name (primExpType (untyped e))
  var <-- e
  pure var

-- | Like 'dPrimV', but return the freshly bound variable as a typed
-- expression instead of as a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM lore r op (Imp.TExp t)
dPrimVE desc e = tvExp <$> dPrimV desc e

-- | Translate the memory annotation of a name into the corresponding
-- symbol table entry.  The optional expression is stored in the entry
-- unchanged.  For arrays, the variables of the index function are
-- turned into 'Imp.ScalarVar' leaves.
memBoundToVarEntry ::
  Maybe (Exp lore) ->
  MemBound NoUniqueness ->
  VarEntry lore
memBoundToVarEntry e = \case
  MemPrim bt ->
    ScalarVar e (ScalarEntry bt)
  MemMem space ->
    MemVar e (MemEntry space)
  MemAcc acc ispace ts _ ->
    AccVar e (acc, ispace, ts)
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    ArrayVar e $
      ArrayEntry
        { entryArrayLocation =
            MemLocation mem (shapeDims shape) (fmap (fmap Imp.ScalarVar) ixfun),
          entryArrayElemType = bt
        }

-- | Extract the memory information recorded for a name in scope.
-- Function parameters have their uniqueness information dropped (via
-- 'noUniquenessReturns'), and 'IndexName's become primitive values of
-- their integer type.
infoDec ::
  Mem lore =>
  NameInfo lore ->
  MemInfo SubExp NoUniqueness MemBind
infoDec info =
  case info of
    LetName dec -> dec
    LParamName dec -> dec
    FParamName dec -> noUniquenessReturns dec
    IndexName it -> MemPrim (IntType it)

-- | Add a name from the scope to the symbol table.  Memory blocks get
-- an 'Imp.DeclareMem' emitted and scalars an 'Imp.DeclareScalar'
-- (non-volatile); arrays and accumulators are only recorded in the
-- symbol table.
dInfo ::
  Mem lore =>
  Maybe (Exp lore) ->
  VName ->
  NameInfo lore ->
  ImpM lore r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e (infoDec info)
  emitDecl entry
  addVar name entry
  where
    emitDecl (MemVar _ mem) =
      emit (Imp.DeclareMem name (entryMemSpace mem))
    emitDecl (ScalarVar _ scalar) =
      emit (Imp.DeclareScalar name Imp.Nonvolatile (entryScalarType scalar))
    emitDecl ArrayVar {} = pure ()
    emitDecl AccVar {} = pure ()

-- | Apply 'dInfo' to every name in the scope, in ascending key order,
-- using the same optional expression for each.
dScope ::
  Mem lore =>
  Maybe (Exp lore) ->
  Scope lore ->
  ImpM lore r op ()
dScope e scope =
  mapM_ (\(name, info) -> dInfo e name info) (M.toList scope)

-- | Record an array with the given element type, shape and memory
-- binding in the symbol table.  No code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind = do
  let entry = memBoundToVarEntry Nothing (MemArray bt shape NoUniqueness membind)
  addVar name entry

-- | Run an action in an environment whose 'envVolatility' is set to
-- 'Imp.Volatile'.
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local makeVolatile
  where
    makeVolatile env = env {envVolatility = Imp.Volatile}

-- | Remove the array targets: scalar and memory destinations each
-- contribute their name, array destinations contribute nothing, and
-- the results are concatenated in destination order.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) = do
  perDest <- mapM targetsOf dests
  pure (concat perDest)
  where
    targetsOf (ScalarDestination name) = pure [name]
    targetsOf (ArrayDestination _) = pure []
    targetsOf (MemoryDestination name) = pure [name]

-- | A typed variable: a variable name together with its dynamic
-- 'PrimType', tagged with a phantom static type @t@.  It can be turned
-- into a typed expression, or used as the target of an assignment.
-- This helps keep types straight during code generation, although it
-- is still easy to cheat when you need to.
data TV t
  = TV VName PrimType

-- | Build a typed variable from a name and a dynamic type.  Nothing
-- checks that the dynamic type agrees with the static type inferred
-- for the result, but the latter must at least be used consistently.
mkTV :: VName -> PrimType -> TV t
mkTV name pt = TV name pt

-- | Use a typed variable as a size, i.e. as a 'Var' subexpression.
tvSize :: TV t -> Imp.DimSize
tvSize (TV v _) = Var v

-- | Turn a typed variable into an expression with the same static
-- type, reading the variable at its dynamic 'PrimType'.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v pt) = Imp.TPrimExp (Imp.var v pt)

-- | The variable name underlying a typed variable.
tvVar :: TV t -> VName
tvVar tv = case tv of
  TV v _ -> v

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM lore r op Imp.Exp

  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile to a typed 64-bit integer expression.  The default
  -- calls 'toExp'' with 'int64' as the known type.
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp = TPrimExp . toExp' int64

  -- | Compile to a typed boolean expression.  The default calls
  -- 'toExp'' with 'Bool' as the known type.
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp = TPrimExp . toExp' Bool

instance ToExp SubExp where
  -- Constants compile directly.  Variables must be bound to a scalar
  -- in the symbol table, whose recorded type is then used.
  toExp se = case se of
    Constant v -> pure (Imp.ValueExp v)
    Var v -> do
      entry <- lookupVar v
      case entry of
        ScalarVar _ (ScalarEntry pt) -> pure (Imp.var v pt)
        _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  -- The supplied type is only consulted for variables.
  toExp' _ (Constant v) = Imp.ValueExp v
  toExp' pt (Var v) = Imp.var v pt

instance ToExp (PrimExp VName) where
  -- The expression already carries its own type, so compiling it only
  -- turns every variable into an 'Imp.ScalarVar' leaf.
  toExp e = pure (Imp.ScalarVar <$> e)
  toExp' _ e = Imp.ScalarVar <$> e

-- | Bind a name to an entry in the symbol table, replacing any entry
-- the name previously had.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry = modify insertEntry
  where
    insertEntry st = st {stateVTable = M.insert name entry $ stateVTable st}

-- | Run an action with 'envDefaultSpace' set to the given space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local withSpace
  where
    withSpace env = env {envDefaultSpace = space}

-- | The function name recorded in the environment ('envFunction'), if
-- any.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = envFunction <$> ask

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.  The
-- function name and the given string are separated by a dot.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun s = do
  fname <- askFunction
  let prefix = case fname of
        Nothing -> ""
        Just f -> nameToString f ++ "."
  newVName (prefix ++ s)

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.  The
-- function name and the given string are separated by a dot.
nameForFun :: String -> ImpM lore r op Name
nameForFun s = do
  fname <- askFunction
  let prefix = maybe "" (<> ".") fname
  pure (prefix <> nameFromString s)

-- | Retrieve the @r@ component ('envEnv') of the environment.
askEnv :: ImpM lore r op r
askEnv = envEnv <$> ask

-- | Run an action with the @r@ component ('envEnv') of the
-- environment transformed by the given function.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local adjust
  where
    adjust env = env {envEnv = f (envEnv env)}

-- | The active attributes ('envAttrs'), including those for the
-- statement currently being compiled.
askAttrs :: ImpM lore r op Attrs
askAttrs = envAttrs <$> ask

-- | Run an action with more attributes added to those returned by
-- 'askAttrs'.  The new attributes are combined in front of the
-- existing ones with '<>'.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs attrs = local $ \env -> env {envAttrs = attrs <> envAttrs env}

-- | Run an action with the expression, statement, copy, op and
-- allocation compilers of the environment all taken from the given
-- 'Operations' record.
localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps ops = local withOps
  where
    withOps env =
      env
        { envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envOpCompiler = opsOpCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }

-- | Get the current symbol table ('stateVTable').
getVTable :: ImpM lore r op (VTable lore)
getVTable = gets stateVTable

-- | Replace the entire symbol table with the given one.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable vtable = modify setTable
  where
    setTable st = st {stateVTable = vtable}

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  -- Bind the result before restoring, matching the original sequencing.
  result <- m
  putVTable saved
  pure result

-- | Look up a name in the symbol table, failing with an internal
-- error if the name is not bound.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name = do
  vtable <- getVTable
  case M.lookup name vtable of
    Just entry -> pure entry
    Nothing -> error $ "Unknown variable: " ++ pretty name

-- | Look up a name that must be bound to an array, failing with an
-- internal error otherwise.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Look up a name that must be bound to a memory block.  Unbound
-- names are already reported by 'lookupVar'; this function itself
-- only fails when the name is bound to something other than a memory
-- block, and the message says so rather than claiming the name is
-- unknown.
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name =
  lookupVar name >>= \case
    MemVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name

-- | The memory space of the memory block that the given array is
-- located in.
lookupArraySpace :: VName -> ImpM lore r op Space
lookupArraySpace name = do
  arr <- lookupArray name
  mem <- lookupMemory (memLocationName (entryArrayLocation arr))
  pure (entryMemSpace mem)

-- | In the case of a histogram-like accumulator, also sets the index
-- parameters.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda lore))
lookupAcc :: forall lore r op.
VName
-> Shape (TExp Int64)
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
lookupAcc VName
name Shape (TExp Int64)
is = do
  VarEntry lore
res <- VName -> ImpM lore r op (VarEntry lore)
forall lore r op. VName -> ImpM lore r op (VarEntry lore)
lookupVar VName
name
  case VarEntry lore
res of
    AccVar Maybe (Exp lore)
_ (VName
acc, Shape
ispace, [Type]
_) -> do
      Maybe ([VName], Maybe (Lambda lore, [SubExp]))
acc' <- (ImpState lore r op
 -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp])))
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ((ImpState lore r op
  -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
 -> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp]))))
-> (ImpState lore r op
    -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpM lore r op (Maybe ([VName], Maybe (Lambda lore, [SubExp])))
forall a b. (a -> b) -> a -> b
$ VName
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
-> Maybe ([VName], Maybe (Lambda lore, [SubExp]))
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VName
acc (Map VName ([VName], Maybe (Lambda lore, [SubExp]))
 -> Maybe ([VName], Maybe (Lambda lore, [SubExp])))
-> (ImpState lore r op
    -> Map VName ([VName], Maybe (Lambda lore, [SubExp])))
-> ImpState lore r op
-> Maybe ([VName], Maybe (Lambda lore, [SubExp]))
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
forall lore r op.
ImpState lore r op
-> Map VName ([VName], Maybe (Lambda lore, [SubExp]))
stateAccs
      case Maybe ([VName], Maybe (Lambda lore, [SubExp]))
acc' of
        Just ([], Maybe (Lambda lore, [SubExp])
_) ->
          [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"Accumulator with no arrays: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name
        Just (arrs :: [VName]
arrs@(VName
arr : [VName]
_), Just (Lambda lore
op, [SubExp]
_)) -> do
          Space
space <- VName -> ImpM lore r op Space
forall lore r op. VName -> ImpM lore r op Space
lookupArraySpace VName
arr
          let ([Param (LParamInfo lore)]
i_params, [Param (LParamInfo lore)]
ps) = Int
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a. Int -> [a] -> ([a], [a])
splitAt (Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
is) ([Param (LParamInfo lore)]
 -> ([Param (LParamInfo lore)], [Param (LParamInfo lore)]))
-> [Param (LParamInfo lore)]
-> ([Param (LParamInfo lore)], [Param (LParamInfo lore)])
forall a b. (a -> b) -> a -> b
$ Lambda lore -> [Param (LParamInfo lore)]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda lore
op
          (VName -> TExp Int64 -> ImpM lore r op ())
-> [VName] -> Shape (TExp Int64) -> ImpM lore r op ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. VName -> TExp t -> ImpM lore r op ()
dPrimV_ ((Param (LParamInfo lore) -> VName)
-> [Param (LParamInfo lore)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (LParamInfo lore) -> VName
forall dec. Param dec -> VName
paramName [Param (LParamInfo lore)]
i_params) Shape (TExp Int64)
is
          (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall (m :: * -> *) a. Monad m => a -> m a
return
            ( VName
acc,
              Space
space,
              [VName]
arrs,
              (SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
ispace),
              Lambda lore -> Maybe (Lambda lore)
forall a. a -> Maybe a
Just Lambda lore
op {lambdaParams :: [Param (LParamInfo lore)]
lambdaParams = [Param (LParamInfo lore)]
ps}
            )
        Just (arrs :: [VName]
arrs@(VName
arr : [VName]
_), Maybe (Lambda lore, [SubExp])
Nothing) -> do
          Space
space <- VName -> ImpM lore r op Space
forall lore r op. VName -> ImpM lore r op Space
lookupArraySpace VName
arr
          (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
acc, Space
space, [VName]
arrs, (SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
ispace), Maybe (Lambda lore)
forall a. Maybe a
Nothing)
        Maybe ([VName], Maybe (Lambda lore, [SubExp]))
Nothing ->
          [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"ImpGen.lookupAcc: unlisted accumulator: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name
    VarEntry lore
_ -> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a. HasCallStack => [Char] -> a
error ([Char]
 -> ImpM
      lore
      r
      op
      (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore)))
-> [Char]
-> ImpM
     lore
     r
     op
     (VName, Space, [VName], Shape (TExp Int64), Maybe (Lambda lore))
forall a b. (a -> b) -> a -> b
$ [Char]
"ImpGen.lookupAcc: not an accumulator: " [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty VName
name

-- | Build a 'Destination' for a pattern by looking up every pattern
-- element in the symbol table.  Arrays and accumulators become an
-- 'ArrayDestination' with no fixed location, memory blocks become a
-- 'MemoryDestination', and scalars a 'ScalarDestination'.  The tag of
-- the first pattern name (if there is one) is stored in the result.
destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat = do
  value_dests <- traverse inspect $ patternElements pat
  pure $ Destination tag value_dests
  where
    tag = baseTag <$> maybeHead (patternNames pat)

    inspect pe = do
      let name = patElemName pe
      entry <- lookupVar name
      pure $ case entry of
        ArrayVar _ (ArrayEntry MemLocation {} _) -> ArrayDestination Nothing
        MemVar {} -> MemoryDestination name
        ScalarVar {} -> ScalarDestination name
        AccVar {} -> ArrayDestination Nothing

-- | Look up the named array and compute, via 'fullyIndexArray'', the
-- memory block, its space, and the element offset for the given indices.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices =
  lookupArray name >>= \entry ->
    fullyIndexArray' (entryArrayLocation entry) indices

-- | Like 'fullyIndexArray', but starting from a 'MemLocation'.  Returns
-- the memory block, its space, and the element offset computed by the
-- location's index function.  When the memory lives in a 'ScalarSpace'
-- with @k@ dimensions, every index except the last @k@ is replaced by
-- zero before indexing.
fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let adjusted
        | ScalarSpace ds _ <- space =
          let (outer, inner) = splitFromEnd (length ds) indices
           in (0 <$ outer) ++ inner
        | otherwise = indices
  pure (mem, space, elements $ IxFun.index ixfun adjusted)

-- More complicated read/write operations that use index functions.

-- | Copy between two (possibly sliced) memory locations by delegating to
-- the 'CopyCompiler' stored in the environment ('envCopyCompiler').
copy :: CopyCompiler lore r op
copy bt dest destslice src srcslice =
  asks envCopyCompiler >>= \compiler ->
    compiler bt dest destslice src srcslice

-- | Is this copy really a mapping with transpose?
--
-- After applying both slices, this succeeds in either of two cases:
--
-- * the destination index function is a rearrangement whose permutation
--   'isMapTranspose' accepts and the source is linear, or
--
-- * the source is such a rearrangement and the destination is linear.
--
-- The result is @(dest_offset, src_offset, num_arrays, size_x, size_y)@.
-- The offsets are those returned by 'IxFun.rearrangeWithOffset' and
-- 'IxFun.linearWithOffset' when given the byte size of the element type.
-- The three sizes are the products of the mapped dimensions and of the
-- two transposed groups.
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy bt (MemLocation _ _ dest_ixfun) destslice (MemLocation _ _ src_ixfun) srcslice
  -- Rearranged destination, linear source.
  | Just (dest_offset, dest_perm_dims) <- IxFun.rearrangeWithOffset dest_view elem_size,
    (perm, dest_dims) <- unzip dest_perm_dims,
    Just src_offset <- IxFun.linearWithOffset src_view elem_size,
    Just (r1, r2, _) <- isMapTranspose perm =
    Just $ describe dest_dims flipPair r1 r2 dest_offset src_offset
  -- Linear destination, rearranged source.
  | Just dest_offset <- IxFun.linearWithOffset dest_view elem_size,
    Just (src_offset, src_perm_dims) <- IxFun.rearrangeWithOffset src_view elem_size,
    (perm, src_dims) <- unzip src_perm_dims,
    Just (r1, r2, _) <- isMapTranspose perm =
    Just $ describe src_dims id r1 r2 dest_offset src_offset
  | otherwise = Nothing
  where
    elem_size = primByteSize bt
    dest_view = IxFun.slice dest_ixfun destslice
    src_view = IxFun.slice src_ixfun srcslice

    flipPair (a, b) = (b, a)

    -- Split the dimensions at r1 into the mapped prefix and the rest, then
    -- split the rest at r2 (reordering the two halves with the supplied
    -- function), and multiply each group together.
    describe dims reorder r1 r2 dest_offset src_offset =
      let (mapped, rest) = splitAt r1 dims
          (before, after) = reorder $ splitAt r2 rest
       in (dest_offset, src_offset, product mapped, product before, product after)

-- | Base name of the builtin map-transpose function for an element type:
-- @"map_transpose_"@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("map_transpose_" ++) . pretty

-- | Name of the builtin map-transpose function for an element type
-- (@"builtin#"@ prefixed to 'mapTransposeName').  If no function with
-- that name has been emitted yet, 'mapTransposeFunction' is emitted
-- under it first.
mapTransposeForType :: PrimType -> ImpM lore r op Name
mapTransposeForType bt = do
  already_emitted <- hasFunction fname
  unless already_emitted $
    emitFunction fname (mapTransposeFunction fname bt)
  pure fname
  where
    fname = nameFromString $ "builtin#" <> mapTransposeName bt

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
--
-- The strategies are tried in this order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, emit a call to the
--    builtin map-transpose function for the element type.
--
-- 2. If both sliced index functions have a linear offset, emit a single
--    'Imp.Copy' of the source slice's element count.  If either memory
--    block is in a 'ScalarSpace', use 'copyElementWise' instead.
--
-- 3. Otherwise, use 'copyElementWise'.
defaultCopy :: CopyCompiler lore r op
defaultCopy pt dest destslice src srcslice
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy pt dest destslice src srcslice = do
    fname <- mapTransposeForType pt
    emit . Imp.Call [] fname $
      transposeArgs
        pt
        destmem
        (bytes destoffset)
        srcmem
        (bytes srcoffset)
        num_arrays
        size_x
        size_y
  | Just destoffset <- linearOffset dest_ixfun destslice,
    Just srcoffset <- linearOffset src_ixfun srcslice = do
    srcspace <- memSpace srcmem
    destspace <- memSpace destmem
    if any isScalarSpace [srcspace, destspace]
      then copyElementWise pt dest destslice src srcslice
      else
        emit $
          Imp.Copy
            destmem
            (bytes destoffset)
            destspace
            srcmem
            (bytes srcoffset)
            srcspace
            (num_elems `withElemType` pt)
  | otherwise =
    copyElementWise pt dest destslice src srcslice
  where
    MemLocation destmem _ dest_ixfun = dest
    MemLocation srcmem _ src_ixfun = src

    pt_size = primByteSize pt
    num_elems = Imp.elements $ product $ sliceDims srcslice

    -- Offset of a sliced index function, if it is linear.
    linearOffset ixfun slice = IxFun.linearWithOffset (IxFun.slice ixfun slice) pt_size

    memSpace mem = entryMemSpace <$> lookupMemory mem

    isScalarSpace ScalarSpace {} = True
    isScalarSpace _ = False

-- | Copy one element at a time.  One 'Imp.For' loop is emitted per
-- dimension of the source slice, nested outermost-first, around a single
-- 'Imp.Write' of the corresponding 'Imp.index' read.  Both accesses use
-- the volatility from the environment.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest destslice src srcslice = do
  let dims = sliceDims srcslice
  loop_vars <- replicateM (length dims) (newVName "i")
  let loop_idxs = map Imp.vi64 loop_vars
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest $ fixSlice destslice loop_idxs
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src $ fixSlice srcslice loop_idxs
  vol <- asks envVolatility
  -- Composing the loop wrappers with foldr nests the first dimension
  -- outermost, exactly as the left fold it replaces did.
  emit $
    foldr (.) id (zipWith Imp.For loop_vars (map untyped dims)) $
      Imp.Write destmem destidx bt destspace vol $
        Imp.index srcmem srcidx bt srcspace vol

-- | Copy from here to there; both destination and source may be
-- indexed.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op (Imp.Code op)
copyArrayDWIM :: forall lore r op.
PrimType
-> MemLocation
-> [DimIndex (TExp Int64)]
-> MemLocation
-> [DimIndex (TExp Int64)]
-> ImpM lore r op (Code op)
copyArrayDWIM
  PrimType
bt
  destlocation :: MemLocation
destlocation@(MemLocation VName
_ [SubExp]
destshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
destslice
  srclocation :: MemLocation
srclocation@(MemLocation VName
_ [SubExp]
srcshape IxFun (TExp Int64)
_)
  [DimIndex (TExp Int64)]
srcslice
    | Just Shape (TExp Int64)
destis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
destslice,
      Just Shape (TExp Int64)
srcis <- (DimIndex (TExp Int64) -> Maybe (TExp Int64))
-> [DimIndex (TExp Int64)] -> Maybe (Shape (TExp Int64))
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex (TExp Int64) -> Maybe (TExp Int64)
forall d. DimIndex d -> Maybe d
dimFix [DimIndex (TExp Int64)]
srcslice,
      Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
srcis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
srcshape,
      Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Shape (TExp Int64)
destis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
destshape = do
      (VName
targetmem, Space
destspace, Count Elements (TExp Int64)
targetoffset) <-
        MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
destlocation Shape (TExp Int64)
destis
      (VName
srcmem, Space
srcspace, Count Elements (TExp Int64)
srcoffset) <-
        MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
forall lore r op.
MemLocation
-> Shape (TExp Int64)
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray' MemLocation
srclocation Shape (TExp Int64)
srcis
      Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
      Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return (Code op -> ImpM lore r op (Code op))
-> Code op -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
        VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
targetmem Count Elements (TExp Int64)
targetoffset PrimType
bt Space
destspace Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$
          VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> Exp
Imp.index VName
srcmem Count Elements (TExp Int64)
srcoffset PrimType
bt Space
srcspace Volatility
vol
    | Bool
otherwise = do
      let destslice' :: [DimIndex (TExp Int64)]
destslice' =
            Shape (TExp Int64)
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
destshape) [DimIndex (TExp Int64)]
destslice
          srcslice' :: [DimIndex (TExp Int64)]
srcslice' =
            Shape (TExp Int64)
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((SubExp -> TExp Int64) -> [SubExp] -> Shape (TExp Int64)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
srcshape) [DimIndex (TExp Int64)]
srcslice
          destrank :: Int
destrank = Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Shape (TExp Int64) -> Int) -> Shape (TExp Int64) -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
destslice'
          srcrank :: Int
srcrank = Shape (TExp Int64) -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (Shape (TExp Int64) -> Int) -> Shape (TExp Int64) -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex (TExp Int64)] -> Shape (TExp Int64)
forall d. Slice d -> [d]
sliceDims [DimIndex (TExp Int64)]
srcslice'
      if Int
destrank Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
srcrank
        then
          [Char] -> ImpM lore r op (Code op)
forall a. HasCallStack => [Char] -> a
error ([Char] -> ImpM lore r op (Code op))
-> [Char] -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
            [Char]
"copyArrayDWIM: cannot copy to "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (MemLocation -> VName
memLocationName MemLocation
destlocation)
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" from "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> [Char]
forall a. Pretty a => a -> [Char]
pretty (MemLocation -> VName
memLocationName MemLocation
srclocation)
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" because ranks do not match ("
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Pretty a => a -> [Char]
pretty Int
destrank
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
" vs "
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> [Char]
forall a. Pretty a => a -> [Char]
pretty Int
srcrank
              [Char] -> ShowS
forall a. [a] -> [a] -> [a]
++ [Char]
")"
        else
          if MemLocation
destlocation MemLocation -> MemLocation -> Bool
forall a. Eq a => a -> a -> Bool
== MemLocation
srclocation Bool -> Bool -> Bool
&& [DimIndex (TExp Int64)]
destslice' [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)] -> Bool
forall a. Eq a => a -> a -> Bool
== [DimIndex (TExp Int64)]
srcslice'
            then Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return Code op
forall a. Monoid a => a
mempty -- Copy would be no-op.
            else ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
bt MemLocation
destlocation [DimIndex (TExp Int64)]
destslice' MemLocation
srclocation [DimIndex (TExp Int64)]
srcslice'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- A constant source can never be sliced.  It may be stored into a
-- scalar destination, or into a single element of an array destination
-- whose slice consists only of 'DimFix' indexes.
--
-- A variable source is dispatched on the kind of destination and the
-- kind of source:
--
-- * memory to memory becomes an 'Imp.SetMem';
-- * scalar to scalar (with no slices) becomes an 'Imp.SetScalar';
-- * a fully 'DimFix'-indexed array element can be read into a scalar;
-- * a scalar can be written into an array element whose destination
--   slice is all 'DimFix';
-- * array to array is delegated to 'copyArrayDWIM'.
--
-- Mismatched combinations are compiler errors ('error').  Nothing is
-- emitted when the target is @'ArrayDestination' 'Nothing'@ or the
-- source is an accumulator.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
copyDWIMDest target dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
        unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    Just dest_is ->
      case target of
        -- NOTE(review): any DimFix indexes in dest_slice are ignored
        -- here, whereas the variable-source path below rejects a
        -- non-empty slice on a scalar target.  Confirm this is intended.
        ScalarDestination name ->
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        MemoryDestination {} ->
          error $
            unwords
              [ "copyDWIMDest: constant source",
                pretty v,
                "cannot be written to memory destination."
              ]
        ArrayDestination (Just dest_loc) -> do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit . Imp.Write dest_mem dest_i (primValueType v) dest_space vol $
            Imp.ValueExp v
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
copyDWIMDest target dest_slice (Var src) src_slice = do
  src_entry <- lookupVar src
  -- The order of these alternatives matters: several patterns overlap.
  case (target, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", pretty src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not (null src_slice) ->
        error $
          unwords
            ["copyDWIMDest: prim-typed source", pretty src, "with slice", pretty src_slice]
    (ScalarDestination name, _)
      | not (null dest_slice) ->
        error $
          unwords
            ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
        -- Reading a single, fully indexed element into a scalar.
        (mem, space, i) <- fullyIndexArray' (entryArrayLocation arr) src_is
        vol <- asks envVolatility
        emit . Imp.SetScalar name $
          Imp.index mem i (entryArrayElemType arr) space vol
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: prim-typed target",
              pretty name,
              "and array-typed source",
              pretty src,
              "with slice",
              pretty src_slice
            ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) ->
      emit
        =<< copyArrayDWIM
          (entryArrayElemType src_arr)
          dest_loc
          dest_slice
          (entryArrayLocation src_arr)
          src_slice
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice -> do
        -- Writing a scalar into a single array element.
        (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
        vol <- asks envVolatility
        emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: array-typed target and prim-typed source",
              pretty src,
              "with slice",
              pretty dest_slice
            ]
    (ArrayDestination Nothing, _) ->
      -- Nothing to do; something else set some memory somewhere.
      return ()
    (_, AccVar {}) ->
      -- Nothing to do; accumulators are phantoms.
      return ()

-- | Copy from here to there; both the destination and the source may
-- be indexed.  If they are, they had better be arrays with enough
-- dimensions.  This function will generally just Do What I Mean, and
-- Do The Right Thing.  Both destination and source must be in scope.
--
-- The destination variable is looked up and turned into a
-- 'ValueDestination', after which the work is done by 'copyDWIMDest'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- asDestination <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    asDestination = \case
      ScalarVar _ _ -> ScalarDestination dest
      ArrayVar _ entry -> ArrayDestination $ Just $ entryArrayLocation entry
      MemVar _ _ -> MemoryDestination dest
      -- Does not matter; accumulators are phantoms.
      AccVar {} -> ArrayDestination Nothing

-- | As 'copyDWIM', but implicitly 'DimFix'es the indexes: every index
-- in both lists is wrapped in 'DimFix' before calling 'copyDWIM'.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixAll dest_is) src (fixAll src_is)
  where
    fixAll = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@ and binds the result to the memory block named by @pat@.
--
-- @pat@ must have no context elements and exactly one value element;
-- any other pattern is a compiler error.  If an 'AllocCompiler' is
-- registered for @space@ in the environment, the allocation is
-- delegated to it.  Otherwise a plain 'Imp.Allocate' is emitted.
compileAlloc ::
  Mem lore =>
  Pattern lore ->
  SubExp ->
  Space ->
  ImpM lore r op ()
compileAlloc (Pattern [] [mem]) e space = do
  let e' = Imp.bytes $ toInt64Exp e
  allocator <- asks $ M.lookup space . envAllocCompilers
  case allocator of
    Nothing -> emit $ Imp.Allocate (patElemName mem) e' space
    Just allocator' -> allocator' (patElemName mem) e'
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an t'Int64' expression.  This is
-- the byte size of the element type multiplied by the product of all
-- array dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes $ elemBytes * elemCount
  where
    elemBytes = primByteSize $ elemType t
    elemCount = product $ map toInt64Exp $ arrayDims t

-- | Is this indexing in-bounds for an array of the given shape?  This
-- is useful for things like scatter, which ignores out-of-bounds
-- writes.
--
-- A @'DimFix' i@ is in bounds for a dimension of size @d@ when
-- @0 <= i < d@.  A @'DimSlice' i n s@ touches the elements
-- @i, i+s, ..., i+(n-1)*s@, so it is considered in bounds when
-- @0 <= i@ and the last touched element @i + (n-1)*s@ is @< d@.
-- Checking @i + n*s@ instead would be off by one and reject valid
-- slices, such as a full-dimension slice @i=0, n=d, s=1@.
-- (The check assumes a non-negative stride.)
--
-- The slice must be non-empty, because the per-dimension conditions
-- are combined with 'foldl1'.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds slice dims =
  let condInBounds (DimFix i) d =
        0 .<=. i .&&. i .<. d
      condInBounds (DimSlice i n s) d =
        0 .<=. i .&&. i + (n - 1) * s .<. d
   in foldl1 (.&&.) $ zipWith condInBounds slice dims

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop whose loop variable is @i@ and whose bound
-- is @bound@.  The bound must have an integer type; otherwise this is a
-- compiler error.  The loop variable is registered with 'addLoopVar',
-- using the bound's integer type, before the body is collected.
sFor' :: VName -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i bound body = do
  let it = loopType $ primExpType bound
  addLoopVar i it
  emit . Imp.For i bound =<< collect body
  where
    loopType (IntType bound_t) = bound_t
    loopType t =
      error $ "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty t

-- | Emit a loop up to @bound@ with a fresh loop variable whose name is
-- derived from the given string (via 'newVName').  The body receives
-- the loop variable as a typed expression with the same primitive type
-- as the bound.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor desc bound body = do
  counter <- newVName desc
  sFor' counter (untyped bound) $ body (asLoopVar counter)
  where
    asLoopVar v = TPrimExp $ Imp.var v $ primExpType $ untyped bound

-- | Emit an 'Imp.While' loop with the given condition, whose body is
-- the code generated by the given action.
sWhile :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Emit the code generated by the given action wrapped in an
-- 'Imp.Comment' carrying the given text.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code = emit . Imp.Comment s =<< collect code

-- | Emit an 'Imp.If' with the given condition.  The true branch is
-- collected before the false branch.
sIf :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch =
  emit =<< Imp.If cond <$> collect tbranch <*> collect fbranch

-- | Emit a conditional that runs the given code when the condition holds,
-- with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond tbranch = sIf cond tbranch $ pure ()

-- | Emit a conditional that runs the given code when the condition does
-- /not/ hold, with an empty true branch.
sUnless :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit a single operation of the pluggable @op@ type as an 'Imp.Op'
-- statement.
sOp :: op -> ImpM lore r op ()
sOp = emit . Imp.Op

-- | Declare a fresh memory block in the given space, with a name derived
-- from the given hint.  Emits 'Imp.DeclareMem' and records the block with
-- 'addVar' as a 'MemVar'.  No allocation happens here; see 'sAlloc_'.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem desc space = do
  mem <- newVName desc
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate @size@ bytes for an already-declared memory block in the given
-- space.  If 'envAllocCompilers' has an 'AllocCompiler' registered for that
-- space, it handles the allocation.  Otherwise a plain 'Imp.Allocate' is
-- emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ mem size space =
  asks (M.lookup space . envAllocCompilers)
    >>= maybe (emit $ Imp.Allocate mem size space) (\allocate -> allocate mem size)

-- | Declare a new memory block (see 'sDeclareMem') and immediately allocate
-- @size@ bytes for it in the given space (see 'sAlloc_').  Returns the name
-- of the block.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op VName
sAlloc desc size space = do
  mem <- sDeclareMem desc space
  mem <$ sAlloc_ mem size space

-- | Declare an array with a fresh name derived from the given hint.  The
-- array has the given element type and shape, and is placed according to
-- the memory binding (via 'dArray').  Returns the fresh name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray desc pt shape membind = do
  arr <- newVName desc
  arr <$ dArray arr pt shape membind

-- | Declare an array in an existing memory block, laid out in row-major
-- order.  The index function is an 'IxFun.iota' over the array's own
-- dimensions, each converted to a 64-bit integer expression.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape . ArrayIn mem . IxFun.iota $
    [isInt64 $ primExpFromSubExp int64 d | d <- shapeDims shape]

-- | Like 'sAllocArray', but the in-memory order of the dimensions is
-- permuted as specified by @perm@.
--
-- A fresh block named @name ++ "_mem"@, sized by 'typeSize' of the array
-- type, is allocated in @space@.  The array keeps its logical @shape@.  Its
-- index function is a row-major 'IxFun.iota' over the permuted dimensions,
-- then permuted by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  let dims_in_mem = rearrangeShape perm $ shapeDims shape
  mem <- sAlloc (name ++ "_mem") (typeSize $ Array pt shape NoUniqueness) space
  let row_major_in_mem =
        IxFun.iota [isInt64 $ primExpFromSubExp int64 d | d <- dims_in_mem]
  sArray name pt shape . ArrayIn mem $
    IxFun.permute row_major_in_mem (rearrangeInverse perm)

-- | Allocate a fresh memory block in the given space and declare an array
-- of the given shape in it.  It uses a linear (iota) index function, i.e.
-- it is 'sAllocArrayPerm' with the identity permutation.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identity_perm
  where
    identity_perm = take (shapeRank shape) [0 ..]

-- | Declare a statically initialised one-dimensional array with the given
-- contents, using a linear (iota) index function.
--
-- The backing memory is named from the hint @name ++ "_mem"@ via
-- 'newVNameForFun'.  It is declared with 'Imp.DeclareArray' and bound as a
-- 'MemVar' before the array itself is declared in it with 'sArray'.  The
-- element count is the number of listed values, or the explicit count for
-- 'Imp.ArrayZeros'.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem . MemVar Nothing $ MemEntry space
  sArray name pt shape . ArrayIn mem $ IxFun.iota [fromIntegral num_elems]
  where
    num_elems :: Int
    num_elems = case vs of
      Imp.ArrayValues vals -> length vals
      Imp.ArrayZeros n -> n
    shape = Shape [intConst Int64 $ toInteger num_elems]

-- | Write a scalar expression to a single element of an array, addressed by
-- a full index (resolved with 'fullyIndexArray').  The element type of the
-- write is taken from the expression, and the volatility from
-- 'envVolatility'.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM lore r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  emit
    =<< asks
      (\env -> Imp.Write mem offset (primExpType v) space (envVolatility env) v)

-- | Copy @v@ into the part of @arr@ selected by @slice@, using 'copyDWIM'.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM lore r op ()
sUpdate arr slice v = copyDWIM arr slice v src_slice
  where
    -- The source is not indexed: an empty slice on the @v@ side.
    src_slice = []

-- | Emit a perfect nest of 'sFor' loops, one per dimension of the shape,
-- with the first dimension outermost.  The innermost body is generated by
-- applying the callback to the loop indices, ordered outermost first.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM lore r op ()) ->
  ImpM lore r op ()
sLoopNest shape f = go (shapeDims shape) []
  where
    -- Indices are accumulated innermost-first and reversed at the leaf.
    go [] acc = f $ reverse acc
    go (d : ds) acc =
      sFor "nest_i" (toInt64Exp d) $ \i -> go ds (i : acc)

-- | Untyped assignment: emit an 'Imp.SetScalar' that stores the expression
-- in the named scalar variable.
(<~~) :: VName -> Imp.Exp -> ImpM lore r op ()
(<~~) x e = emit $ Imp.SetScalar x e

infixl 3 <~~

-- | Typed assignment: store a typed expression in the variable underlying a
-- 'TV'.  The expression is untyped with 'untyped' and emitted as an
-- 'Imp.SetScalar'.
(<--) :: TV t -> Imp.TExp t -> ImpM lore r op ()
(<--) (TV x _) e = emit . Imp.SetScalar x $ untyped e

infixl 3 <--

-- | Construct an ad-hoc function that does not correspond to any of the IR
-- functions in the input program.
--
-- The body is generated by running the given action with 'envFunction' set
-- to the new function's name.  Every output and input parameter (outputs
-- first) is bound with 'addVar' beforehand.  The result is registered with
-- 'emitFunction' as an 'Imp.Function' with empty external-value lists.
-- NOTE(review): the leading 'False' presumably marks the function as not
-- an entry point; confirm against "Futhark.CodeGen.ImpCode".
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM lore r op () ->
  ImpM lore r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect $ mapM_ bindParam (outputs ++ inputs) >> m
    emitFunction fname $ Imp.Function False outputs inputs body [] []
  where
    bindParam = \case
      Imp.MemParam v space -> addVar v . MemVar Nothing $ MemEntry space
      Imp.ScalarParam v pt -> addVar v . ScalarVar Nothing $ ScalarEntry pt