{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, LambdaCase, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Trustworthy #-}
module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg

    -- * Pluggable Compiler
  , OpCompiler
  , ExpCompiler
  , CopyCompiler
  , StmsCompiler
  , AllocCompiler
  , Operations (..)
  , defaultOperations
  , MemLocation (..)
  , MemEntry (..)
  , ScalarEntry (..)

    -- * Monadic Compiler Interface
  , ImpM
  , localDefaultSpace, askFunction
  , newVNameForFun, nameForFun
  , askEnv, localEnv
  , localOps
  , VTable
  , getVTable
  , localVTable
  , subImpM
  , subImpM_
  , emit
  , emitFunction
  , hasFunction
  , collect
  , collect'
  , comment
  , VarEntry (..)
  , ArrayEntry (..)

    -- * Lookups
  , lookupVar
  , lookupArray
  , lookupMemory

    -- * Building Blocks
  , ToExp(..)
  , compileAlloc
  , everythingVolatile
  , compileBody
  , compileBody'
  , compileLoopBody
  , defCompileStms
  , compileStms
  , compileExp
  , defCompileExp
  , fullyIndexArray
  , fullyIndexArray'
  , copy
  , copyDWIM
  , copyDWIMFix
  , copyElementWise
  , typeSize

  -- * Constructing code.
  , dLParams
  , dFParams
  , dScope
  , dArray
  , dPrim, dPrimVol_, dPrim_, dPrimV_, dPrimV, dPrimVE

  , sFor, sWhile
  , sComment
  , sIf, sWhen, sUnless
  , sOp
  , sDeclareMem, sAlloc, sAlloc_
  , sArray, sArrayInMem, sAllocArray, sAllocArrayPerm, sStaticArray
  , sWrite, sUpdate
  , sLoopNest
  , (<--)

  , function

  , warn
  , module Language.Futhark.Warnings
  )
  where

import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List (find, sortOn, genericLength)

import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpCode
  (Count, Bytes, Elements,
   bytes, elements, withElemType)
import Futhark.IR.Mem
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.Construct (fullSliceNum)
import Futhark.MonadFreshNames
import Futhark.Util
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings

-- | How to compile an t'Op'.  The compiler is handed the 'Pattern'
-- together with the t'Op' itself, and emits code into 'ImpM'.
type OpCompiler lore r op =
  Pattern lore -> Op lore -> ImpM lore r op ()

-- | How to compile some 'Stms'.  Besides the statements, the compiler
-- receives a set of 'Names' and an action; presumably the names that
-- must stay live afterwards and an action to run in the scope of the
-- compiled statements (TODO: confirm against 'defCompileStms').
type StmsCompiler lore r op =
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | How to compile an 'Exp'.  The compiler is handed the 'Pattern'
-- together with the expression, and emits code into 'ImpM'.
type ExpCompiler lore r op =
  Pattern lore -> Exp lore -> ImpM lore r op ()

-- | How to copy array elements of the given 'PrimType' between two
-- memory locations, each restricted by a slice.  Presumably the first
-- location/slice pair is the destination and the second the source
-- (TODO: confirm against 'defaultCopy').
type CopyCompiler lore r op =
  PrimType ->
  MemLocation -> Slice Imp.Exp ->
  MemLocation -> Slice Imp.Exp ->
  ImpM lore r op ()

-- | An alternate way of compiling an allocation.  It receives a name
-- (presumably the memory block being allocated) and a size in bytes.
type AllocCompiler lore r op =
  VName -> Count Bytes Imp.Exp -> ImpM lore r op ()

-- | The set of pluggable compilers that determine how 'ImpM' translates
-- the various parts of a program.
data Operations lore r op = Operations
  { opsExpCompiler :: ExpCompiler lore r op
    -- ^ How to compile expressions.
  , opsOpCompiler :: OpCompiler lore r op
    -- ^ How to compile t'Op's.
  , opsStmsCompiler :: StmsCompiler lore r op
    -- ^ How to compile sequences of statements.
  , opsCopyCompiler :: CopyCompiler lore r op
    -- ^ How to copy between memory locations.
  , opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
    -- ^ Custom allocation compilers, keyed by memory space.
  }

-- | Build an operations set from an t'Op' compiler alone.  Every other
-- component uses the default implementation: 'defCompileExp' for
-- expressions, 'defCompileStms' for statements and 'defaultCopy' for
-- copies.  No custom allocation compilers are installed.
defaultOperations :: (Mem lore, FreeIn op) =>
                     OpCompiler lore r op -> Operations lore r op
defaultOperations opc =
  Operations
    { opsOpCompiler = opc
    , opsExpCompiler = defCompileExp
    , opsStmsCompiler = defCompileStms
    , opsCopyCompiler = defaultCopy
    , opsAllocCompilers = M.empty
    }

-- | When an array is declared, this is where it is stored.
data MemLocation = MemLocation
  { memLocationName :: VName
    -- ^ Name of the memory holding the array.
  , memLocationShape :: [Imp.DimSize]
    -- ^ The dimensions of the array.
  , memLocationIxFun :: IxFun.IxFun Imp.Exp
    -- ^ The array's index function, with 'Imp.Exp' leaves.
  }
  deriving (Eq, Show)

-- | Symbol-table information about an array variable.
data ArrayEntry = ArrayEntry
  { entryArrayLocation :: MemLocation
    -- ^ Where the array is stored.
  , entryArrayElemType :: PrimType
    -- ^ The primitive type of the array's elements.
  }
  deriving (Show)

-- | The shape of an array, as recorded in its 'MemLocation'.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocationShape (entryArrayLocation entry)

-- | Symbol-table information about a memory block.
newtype MemEntry = MemEntry
  { entryMemSpace :: Imp.Space
    -- ^ The memory space of the block.
  }
  deriving (Show)

-- | Symbol-table information about a scalar variable.
newtype ScalarEntry = ScalarEntry
  { entryScalarType :: PrimType
    -- ^ The primitive type of the scalar.
  }
  deriving (Show)

-- | The symbol-table entry for a variable: an array, a scalar, or a
-- memory block.  Each constructor also carries an optional expression;
-- presumably the expression that defined the variable, when known
-- (TODO: confirm against the code that populates the table).
data VarEntry lore
  = ArrayVar (Maybe (Exp lore)) ArrayEntry
  | ScalarVar (Maybe (Exp lore)) ScalarEntry
  | MemVar (Maybe (Exp lore)) MemEntry
  deriving (Show)

-- | Describes where the result of compiling an expression should end
-- up.  The optional integer refers to the construct that gave rise to
-- this destination; for patterns it is the tag of the first name in the
-- pattern.  It can be used to make the generated code easier to relate
-- to the original code.
data Destination = Destination
  { destinationTag :: Maybe Int
    -- ^ Tag of the originating construct, if any.
  , valueDestinations :: [ValueDestination]
    -- ^ Destinations of the individual values (presumably one per
    -- result).
  }
  deriving (Show)

-- | Where a single value should be placed.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | ArrayDestination (Maybe MemLocation)
    -- ^ The 'MemLocation' is 'Just' if a copy is required.  If it is
    -- 'Nothing', then a copy/assignment of a memory block somewhere
    -- takes care of this array.
  deriving (Show)

-- | The read-only environment of the 'ImpM' monad.
data Env lore r op = Env
  { envExpCompiler :: ExpCompiler lore r op
    -- ^ How to compile expressions.
  , envStmsCompiler :: StmsCompiler lore r op
    -- ^ How to compile sequences of statements.
  , envOpCompiler :: OpCompiler lore r op
    -- ^ How to compile t'Op's.
  , envCopyCompiler :: CopyCompiler lore r op
    -- ^ How to copy between memory locations.
  , envAllocCompilers :: M.Map Space (AllocCompiler lore r op)
    -- ^ Custom allocation compilers, keyed by memory space.
  , envDefaultSpace :: Imp.Space
    -- ^ The default memory space.
  , envVolatility :: Imp.Volatility
    -- ^ The current volatility setting.
  , envEnv :: r
    -- ^ User-extensible environment.
  , envFunction :: Maybe Name
    -- ^ Name of the function we are compiling, if any.
  , envAttrs :: Attrs
    -- ^ The set of attributes that are active on the enclosing
    -- statements (including the one we are currently compiling).
  }

-- | Construct the initial environment for a top-level 'ImpM'
-- computation.  It takes a user environment, a set of operations, and
-- the default memory space.  The environment starts out non-volatile,
-- outside of any function, and with no active attributes.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env { envExpCompiler = opsExpCompiler ops
      , envStmsCompiler = opsStmsCompiler ops
      , envOpCompiler = opsOpCompiler ops
      , envCopyCompiler = opsCopyCompiler ops
        -- Use the allocation compilers supplied in @ops@ (as 'subImpM'
        -- does) instead of discarding them in favour of an empty map.
      , envAllocCompilers = opsAllocCompilers ops
      , envDefaultSpace = ds
      , envVolatility = Imp.Nonvolatile
      , envEnv = r
      , envFunction = Nothing
      , envAttrs = mempty
      }

-- | The symbol table used during compilation: maps each variable name
-- to its 'VarEntry'.
type VTable lore =
  M.Map VName (VarEntry lore)

-- | The mutable state threaded through the 'ImpM' monad.
data ImpState lore r op = ImpState
  { stateVTable :: VTable lore
    -- ^ The symbol table.
  , stateFunctions :: Imp.Functions op
    -- ^ Functions generated so far.
  , stateCode :: Imp.Code op
    -- ^ Code emitted so far (reset and restored by 'collect'').
  , stateWarnings :: Warnings
    -- ^ Warnings accumulated so far.
  , stateNameSource :: VNameSource
    -- ^ Source of fresh names.
  }

-- | An initial state with an empty symbol table and no functions, code
-- or warnings.  It draws fresh names from the given source.
newState :: VNameSource -> ImpState lore r op
newState src =
  ImpState { stateVTable = mempty
           , stateFunctions = mempty
           , stateCode = mempty
           , stateWarnings = mempty
           , stateNameSource = src
           }

-- | The code generation monad: a reader over 'Env' layered on top of a
-- state monad over 'ImpState'.
newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadState (ImpState lore r op)
    , MonadReader (Env lore r op)
    )

-- | Fresh names are drawn from the name source kept in 'ImpState'.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\st -> st { stateNameSource = src })

-- This cannot be a KernelsMem scope, because the index functions have
-- the wrong leaves: the index functions stored here
-- ('memLocationIxFun') have 'Imp.Exp' leaves rather than 'VName' ones.
-- The symbol table is therefore exposed as a plain SOACS scope of types.
instance HasScope SOACS (ImpM lore r op) where
  askScope = gets (fmap (LetName . varEntryType) . stateVTable)
    where
      varEntryType :: VarEntry l -> Type
      varEntryType = \case
        MemVar _ mem ->
          Mem (entryMemSpace mem)
        ArrayVar _ arr ->
          Array (entryArrayElemType arr)
                (Shape (entryArrayShape arr))
                NoUniqueness
        ScalarVar _ scalar ->
          Prim (entryScalarType scalar)

-- | Run an 'ImpM' computation from scratch.  It takes the user
-- environment, the operations, the default memory space, and an initial
-- state, and yields the result together with the final state.
runImpM :: ImpM lore r op a
        -> r
        -> Operations lore r op
        -> Imp.Space
        -> ImpState lore r op
        -> (a, ImpState lore r op)
runImpM (ImpM m) r ops space =
  runState . runReaderT m $ newEnv r ops space

-- | Like 'subImpM', but keep only the generated code and discard the
-- action's result.
subImpM_ :: r' -> Operations lore r' op' -> ImpM lore r' op' a
         -> ImpM lore r op (Imp.Code op')
subImpM_ r ops m = do
  (_, code) <- subImpM r ops m
  pure code

-- | Run an action with a different user environment and set of
-- operations, and hence possibly a different operation type.  Returns
-- the action's result together with the code it generated.
--
-- The inner action inherits the current environment, except for the
-- compilers and user environment, which come from the arguments.  It
-- starts from the current symbol table and name source, with no code,
-- functions or warnings.
--
-- Afterwards only the name source and the warnings are propagated back
-- to the enclosing computation.  Symbol-table changes and any functions
-- emitted by the inner action are dropped; the latter have operation
-- type @op'@ and could not be stored in the outer state anyway.
subImpM :: r' -> Operations lore r' op' -> ImpM lore r' op' a
        -> ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  outer_env <- ask
  outer_state <- get

  let inner_env =
        outer_env { envExpCompiler = opsExpCompiler ops
                  , envStmsCompiler = opsStmsCompiler ops
                  , envCopyCompiler = opsCopyCompiler ops
                  , envOpCompiler = opsOpCompiler ops
                  , envAllocCompilers = opsAllocCompilers ops
                  , envEnv = r
                  }
      inner_state =
        ImpState { stateVTable = stateVTable outer_state
                 , stateNameSource = stateNameSource outer_state
                 , stateFunctions = mempty
                 , stateCode = mempty
                 , stateWarnings = mempty
                 }
      (x, final_state) = runState (runReaderT m inner_env) inner_state

  putNameSource (stateNameSource final_state)
  warnings (stateWarnings final_state)
  pure (x, stateCode final_state)

-- | Run a code generation action and return the code it emitted
-- instead of emitting it into the surrounding computation.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect m = snd <$> collect' m

-- | Run an action with an empty code buffer and return its result
-- together with the code it emitted.  The previously emitted code is
-- restored afterwards, so the action's code is not emitted into the
-- enclosing computation.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' m = do
  saved <- gets stateCode
  setCode mempty
  x <- m
  produced <- gets stateCode
  setCode saved
  pure (x, produced)
  where setCode code = modify $ \st -> st { stateCode = code }

-- | Run a code generation action and emit its code wrapped in an
-- 'Imp.Comment' carrying the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = emit . Imp.Comment desc =<< collect m

-- | Emit some generated imperative code, appending it after the code
-- emitted so far.
emit :: Imp.Code op -> ImpM lore r op ()
emit code = modify append
  where append s = s { stateCode = stateCode s <> code }

-- | Add warnings to the state.  The given warnings are combined in
-- front of those collected so far (@ws <> old@).
warnings :: Warnings -> ImpM lore r op ()
warnings ws = modify prepend
  where prepend s = s { stateWarnings = ws <> stateWarnings s }

-- | Emit a warning about something the user should be aware of.  The
-- warning is built with @singleWarning'@ from the source location of
-- @loc@, the source locations of @locs@, and the description @problem@.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn loc locs problem = warnings theWarning
  where theWarning =
          singleWarning' (srclocOf loc) (srclocOf <$> locs) problem

-- | Emit a function in the generated code.  The new function is put at
-- the front of the state's function list; existing functions with the
-- same name are not removed here.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun = modify $ \s ->
  case stateFunctions s of
    Imp.Functions fs ->
      s { stateFunctions = Imp.Functions $ (fname, fun) : fs }

-- | Check whether a function of the given name is present among the
-- functions in the current state.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = gets $ \s ->
  case stateFunctions s of
    Imp.Functions fs -> any ((fname ==) . fst) fs

-- | Build a symbol table for top-level constant statements.  Every name
-- bound in a statement's pattern is mapped to the entry that
-- 'memBoundToVarEntry' builds from its annotation and the statement's
-- expression.  The per-name tables are merged with 'Map''s monoid
-- (left-biased union), so an earlier binding of a name wins.
constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable = foldMap $ \(Let pat _ e) ->
  foldMap (patEntry e) $ patternElements pat
  where patEntry e (PatElem name dec) =
          M.singleton name $ memBoundToVarEntry (Just e) dec

-- | Compile a program into imperative code.
--
-- Each function is compiled on its own (via @parMap rpar@), starting
-- from a fresh state built from the current name source and whose
-- symbol table describes the top-level constants.  The per-function
-- states are then merged, and the constants are compiled on top of the
-- merged state.  The names free in the compiled functions are passed
-- to 'compileConsts' so the corresponding constants can be exposed.
compileProg :: (Mem lore, FreeIn op, MonadFreshNames m) =>
               r -> Operations lore r op -> Imp.Space
            -> Prog lore -> m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    let fun_states =
          map snd $ parMap rpar (compileFunDef' src) funs
        free_in_funs =
          freeIn $ foldMap stateFunctions fun_states
        (consts', final_state) =
          runImpM (compileConsts free_in_funs consts) r ops space $
          combineStates fun_states
        defs =
          Imp.Definitions consts' $ stateFunctions final_state
    in ((stateWarnings final_state, defs), stateNameSource final_state)
  where
    -- Compile one function in a fresh state that knows the constants.
    compileFunDef' src fdef =
      runImpM (compileFunDef fdef) r ops space $
      (newState src) { stateVTable = constsVTable consts }

    -- Merge the per-function states.  Functions are deduplicated by
    -- name with 'M.fromList' (the last occurrence wins) and come out in
    -- ascending name order; name sources and warnings are combined with
    -- their monoid instances.
    combineStates states =
      let Imp.Functions all_funs = foldMap stateFunctions states
          unique_funs = M.toList $ M.fromList all_funs
      in (newState $ foldMap stateNameSource states)
           { stateFunctions = Imp.Functions unique_funs
           , stateWarnings = foldMap stateWarnings states
           }

-- | Compile the top-level constants.  Their statements are compiled
-- into initialisation code.  Memory and scalar declarations on the
-- top-level ':>>:' spine of that code whose names are in @used_consts@
-- are removed from the code and become the parameters of the resulting
-- 'Imp.Constants' instead.
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  pure $ uncurry (Imp.Constants . DL.toList) $ extract code
  where
    -- Fish out those top-level declarations in the constant
    -- initialisation code that are free in the functions; everything
    -- else is kept as code.
    extract (x Imp.:>>: y) = extract x <> extract y
    extract (Imp.DeclareMem name space)
      | name `nameIn` used_consts = exposed $ Imp.MemParam name space
    extract (Imp.DeclareScalar name _ t)
      | name `nameIn` used_consts = exposed $ Imp.ScalarParam name t
    extract other = (mempty, other)

    -- A declaration turned into a parameter leaves no code behind.
    exposed param = (DL.singleton param, mempty)

-- | Compile a single function parameter.  Scalar and memory parameters
-- become an 'Imp.Param' ('Left').  An array parameter becomes an
-- 'ArrayDecl' ('Right') that records its memory block, its shape and
-- its index function, the latter converted with @toExp' int32@.
compileInParam :: Mem lore =>
                  FParam lore -> ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam = pure $
  case paramDec fparam of
    MemPrim bt ->
      Left $ Imp.ScalarParam name bt
    MemMem space ->
      Left $ Imp.MemParam name space
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      Right $ ArrayDecl name bt $
        MemLocation mem (shapeDims shape) (toExp' int32 <$> ixfun)
  where name = paramName fparam

-- | An array-typed parameter, as produced by 'compileInParam'.
data ArrayDecl =
  ArrayDecl
    VName        -- ^ The parameter's name.
    PrimType     -- ^ The array's element type.
    MemLocation  -- ^ Memory block, shape and index function of the array.

-- | The variables that occur as dimensions in the type of a parameter.
fparamSizes :: Typed dec => Param dec -> S.Set VName
fparamSizes param =
  S.fromList $ subExpVars $ arrayDims $ paramType param

-- | Compile the parameters of an entry point.
--
-- The last @sum (map entryPointSize orig_epts)@ parameters are the
-- value parameters described by @orig_epts@; any parameters before them
-- are context.  Every parameter is compiled with 'compileInParam'.  The
-- result holds the scalar/memory parameters, the array declarations,
-- and the external description of the value parameters.  Some value
-- parameters are left out of that description:
--
-- * scalar value parameters whose name occurs as a dimension in some
--   parameter's type;
-- * array parameters whose memory block is not a memory parameter in
--   the list.
compileInParams :: Mem lore =>
                   [FParam lore] -> [EntryPointType]
                -> ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let num_ept_params = sum $ map entryPointSize orig_epts
      val_params = drop (length params - num_ept_params) params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params

  let -- Names used as dimensions anywhere in the parameter list.
      sizes = foldMap fparamSizes params

      -- Memory space of every memory parameter.
      mem_spaces =
        M.fromList [ (paramName p, space)
                   | p <- params, MemMem space <- [paramDec p] ]

      findArray x = find (\(ArrayDecl y _ _) -> x == y) arrayds

      -- External description of one value parameter, if it has one.
      valueDesc signedness fparam
        | Just (ArrayDecl _ bt (MemLocation mem shape _)) <- findArray name =
            (\memspace -> Imp.ArrayValue mem memspace bt signedness shape)
              <$> M.lookup mem mem_spaces
        | Prim bt <- paramType fparam
        , not $ name `S.member` sizes =
            Just $ Imp.ScalarValue bt signedness name
        | otherwise =
            Nothing
        where name = paramName fparam

      transparent signedness =
        maybeToList . fmap Imp.TransparentValue . valueDesc signedness

      -- Walk the entry point types, consuming value parameters as we go.
      externals (TypeOpaque desc n : epts) fparams =
        let (opaque_params, rest) = splitAt n fparams
        in Imp.OpaqueValue desc (mapMaybe (valueDesc Imp.TypeDirect) opaque_params)
           : externals epts rest
      externals (TypeUnsigned : epts) (fparam : fparams) =
        transparent Imp.TypeUnsigned fparam ++ externals epts fparams
      externals (TypeDirect : epts) (fparam : fparams) =
        transparent Imp.TypeDirect fparam ++ externals epts fparams
      externals _ _ = []

  pure (inparams, arrayds, externals orig_epts val_params)

-- | Create the output parameters of a function from its return types
-- and the entry-point interpretation of each result.
--
-- Returns a triple:
--
-- * the external description of the results;
-- * the output parameters, in the order they were created;
-- * a 'Destination' whose value destinations are the existential
--   context destinations (ordered by existential index) followed by the
--   destinations of the results proper.
compileOutParams :: Mem lore =>
                    [RetType lore] -> [EntryPointType]
                 -> ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams orig_rts orig_epts = do
  ((extvs, dests), (outparams, ctx_dests)) <-
    runWriterT $ evalStateT (mkExts orig_epts orig_rts) M.empty
  -- Context destinations are keyed by existential index; put them in
  -- index order in front of the result destinations.
  let ctx_dests' = map snd $ sortOn fst $ M.toList ctx_dests
  return (extvs, outparams, Destination Nothing $ ctx_dests' <> dests)
  where -- The monad stack is StateT seen (WriterT (params, ctx) ImpM).
        -- 'seen' maps an existential size index to the output parameter
        -- already created for it, so each size gets one parameter.
        -- (Previously the state also carried a second map that was
        -- never read or extended; it has been removed.)
        imp = lift . lift

        -- An opaque value consumes 'n' return types, each compiled as
        -- 'Imp.TypeDirect' and grouped into one 'Imp.OpaqueValue'.
        mkExts (TypeOpaque desc n:epts) rts = do
          let (rts', rest) = splitAt n rts
          (evs, dests) <-
            unzip <$> mapM (\rt -> mkParam rt Imp.TypeDirect) rts'
          (more_values, more_dests) <- mkExts epts rest
          return (Imp.OpaqueValue desc evs : more_values,
                  dests ++ more_dests)
        mkExts (TypeUnsigned:epts) (rt:rts) =
          mkTransparent Imp.TypeUnsigned rt epts rts
        mkExts (TypeDirect:epts) (rt:rts) =
          mkTransparent Imp.TypeDirect rt epts rts
        mkExts _ _ = return ([], [])

        -- A single return type exposed as an 'Imp.TransparentValue'
        -- with the given signedness.
        mkTransparent ept rt epts rts = do
          (ev, dest) <- mkParam rt ept
          (more_values, more_dests) <- mkExts epts rts
          return (Imp.TransparentValue ev : more_values,
                  dest : more_dests)

        mkParam MemMem{} _ =
          error "Functions may not explicitly return memory blocks."
        mkParam (MemPrim t) ept = do
          out <- imp $ newVName "scalar_out"
          tell ([Imp.ScalarParam out t], mempty)
          return (Imp.ScalarValue t ept out, ScalarDestination out)
        mkParam (MemArray t shape _ dec) ept = do
          space <- asks envDefaultSpace
          memout <- case dec of
            -- NOTE(review): the space stored in ReturnsNewBlock is
            -- ignored and envDefaultSpace is used instead; confirm this
            -- is intended.
            ReturnsNewBlock _ x _ixfun -> do
              memout <- imp $ newVName "out_mem"
              tell ([Imp.MemParam memout space],
                    M.singleton x $ MemoryDestination memout)
              return memout
            ReturnsInBlock memout _ ->
              return memout
          resultshape <- mapM inspectExtSize $ shapeDims shape
          return (Imp.ArrayValue memout space t ept resultshape,
                  ArrayDestination Nothing)

        -- Existential sizes get one fresh int32 output parameter per
        -- index, shared by every array that mentions that index.
        inspectExtSize (Ext x) = do
          seen <- get
          case M.lookup x seen of
            Just out ->
              return $ Var out
            Nothing -> do
              out <- imp $ newVName "out_arrsize"
              tell ([Imp.ScalarParam out int32],
                    M.singleton x $ ScalarDestination out)
              put $ M.insert x out seen
              return $ Var out
        inspectExtSize (Free se) =
          return se

-- | Compile a function definition and register it with 'emitFunction'
-- under the function's own name.
--
-- While the function is being compiled, 'envFunction' is set to its
-- name. The first field of the emitted 'Imp.Function' is
-- @isJust entry@.
--
-- When no entry-point information is present, every parameter and every
-- result gets the entry-point type 'TypeDirect'.
compileFunDef :: Mem lore =>
                 FunDef lore
              -> ImpM lore r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local inThisFunction $ do
    ((outs, ins, res_values, arg_values), code) <- collect' compileParamsAndBody
    emitFunction fname $
      Imp.Function (isJust entry) outs ins code res_values arg_values
  where
    inThisFunction env = env { envFunction = Just fname }

    param_epts = case entry of
      Nothing -> replicate (length params) TypeDirect
      Just (ps, _) -> ps

    ret_epts = case entry of
      Nothing -> replicate (length rettype) TypeDirect
      Just (_, rs) -> rs

    compileParamsAndBody = do
      (ins, arrays, arg_values) <- compileInParams params param_epts
      (res_values, outs, Destination _ res_dests) <-
        compileOutParams rettype ret_epts
      addFParams params
      addArrays arrays

      let Body _ stms res = body
      -- As the continuation of the statements, copy every body result
      -- into its output destination.
      compileStms (freeIn res) stms $
        mapM_ (\(dest, se) -> copyDWIMDest dest [] se []) (zip res_dests res)

      return (outs, ins, res_values, arg_values)

-- | Compile a body and copy its results into the destinations derived
-- from the given pattern, pairing destinations and results positionally.
compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ stms results) = do
  Destination _ targets <- destinationFromPattern pat
  let copyResult target res = copyDWIMDest target [] res []
  compileStms (freeIn results) stms $
    mapM_ (uncurry copyResult) (zip targets results)

-- | Compile a body and copy each of its results into the variable named
-- by the corresponding parameter (paired positionally).
compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ stms results) =
  compileStms (freeIn results) stms $
    mapM_ assign (zip params results)
  where assign (p, res) = copyDWIM (paramName p) [] res []

-- | Compile a loop body and assign its results to the loop's merge
-- parameters.
--
-- We cannot write the results to the merge parameters immediately,
-- because some results may themselves *be* merge parameters and would
-- be clobbered. Instead, every result is first copied into a fresh
-- temporary that mirrors its merge parameter. Only after all
-- temporaries have been written are they copied into the merge
-- parameters. This is cheap because the copies are scalar or
-- memory-block assignments.
--
-- Only scalar merge parameters, and memory merge parameters whose result
-- is a variable, are handled here. For anything else nothing is emitted.
compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ stms results) = do
  tmps <- forM mergeparams $ \p ->
    newVName $ baseString (paramName p) ++ "_tmp"
  compileStms (freeIn results) stms $ do
    -- Phase one: fill the temporaries. Each step yields the action that
    -- will later move its temporary into the merge parameter.
    finishers <- sequence $ zipWith3 stage mergeparams tmps results
    -- Phase two: overwrite the merge parameters.
    sequence_ finishers
  where
    stage p tmp se = case (typeOf p, se) of
      (Prim pt, _) -> do
        emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
        emit $ Imp.SetScalar tmp $ toExp' pt se
        return $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
      (Mem space, Var v) -> do
        emit $ Imp.DeclareMem tmp space
        emit $ Imp.SetMem tmp v space
        return $ emit $ Imp.SetMem (paramName p) tmp space
      _ -> return (return ())

-- | Compile the given statements, followed by the continuation @m@,
-- using the statement compiler stored in the environment
-- ('envStmsCompiler'). The 'Names' argument is passed to that compiler
-- unchanged.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \stmsCompiler ->
    stmsCompiler alive_after_stms all_stms m

-- | The default statement compiler.  Statements are compiled in order and
-- the action @m@ is run after the last one.
--
-- Memory blocks bound by the pattern of an earlier statement are tracked.
-- An 'Imp.Free' for such a block is emitted right after the code of a
-- statement that references it, provided nothing compiled afterwards (later
-- statements, @m@, or the names in @alive_after_stms@) references it.  This
-- is very conservative, but can cut down on lifetimes in some cases.
defCompileStms :: (Mem lore, FreeIn op) =>
                  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms alive_after_stms all_stms m =
  void $ go mempty $ stmsToList all_stms
  where
    -- 'go' returns the names free in all code generated from this point on,
    -- together with 'alive_after_stms'.
    go _ [] = do
      tail_code <- collect m
      emit tail_code
      return $ freeIn tail_code <> alive_after_stms
    go tracked (Let pat aux e : rest) = do
      dVars (Just e) (patternElements pat)
      stm_code <- localAttrs (stmAuxAttrs aux) $ collect $ compileExp pat e
      -- The remaining statements are compiled (but not yet emitted) first,
      -- because we need to know which names they use before deciding what
      -- can be freed here.
      (used_later, rest_code) <- collect' $ go (patternAllocs pat <> tracked) rest
      let used_here = freeIn stm_code
          lastUse v = v `nameIn` used_here && not (v `nameIn` used_later)
      emit stm_code
      forM_ (S.toList (S.filter (lastUse . fst) tracked)) $ \(mem, space) ->
        emit $ Imp.Free mem space
      emit rest_code
      return $ used_here <> used_later

    -- The memory blocks (and their spaces) bound by a pattern.
    patternAllocs p = S.fromList $ mapMaybe memBinding $ patternElements p
    memBinding pe
      | Mem space <- patElemType pe = Just (patElemName pe, space)
      | otherwise = Nothing

-- | Compile an expression for the given pattern by dispatching to the
-- expression compiler installed in the environment ('envExpCompiler').
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e =
  asks envExpCompiler >>= \compile_exp -> compile_exp pat e

-- | The default expression compiler.  'BasicOp's are delegated to
-- 'defCompileBasicOp' and 'Op's to the operation compiler installed in the
-- environment ('envOpCompiler'); 'If', 'Apply' and 'DoLoop' are handled here.
defCompileExp :: (Mem lore) =>
                 Pattern lore -> Exp lore -> ImpM lore r op ()

-- Both branches are compiled to separate code fragments (true branch first)
-- and combined into one imperative conditional.
defCompileExp pat (If cond tbranch fbranch _) =
  emit =<< (Imp.If (toExp' Bool cond)
              <$> collect (compileBody pat tbranch)
              <*> collect (compileBody pat fbranch))

-- Function calls: the call targets come from the pattern's destination.
-- Scalar arguments are passed as expressions and memory arguments as memory
-- blocks; arguments of any other type are dropped from the call.
defCompileExp pat (Apply fname args _ _) = do
  targets <- funcallTargets =<< destinationFromPattern pat
  call_args <- catMaybes <$> traverse compileArg args
  emit $ Imp.Call targets fname call_args
  where compileArg (se, _) = argOf se <$> subExpType se
        argOf se (Prim pt)  = Just $ Imp.ExpArg $ toExp' pt se
        argOf (Var v) Mem{} = Just $ Imp.MemArg v
        argOf _ _           = Nothing

defCompileExp pat (BasicOp op) = defCompileBasicOp pat op

defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc::SrcLoc) [] "#[unroll] on loop with unknown number of iterations." -- FIXME: no location.

  -- Declare the merge parameters; the rank-0 ones are initialised by
  -- copying in their initial values.
  dFParams merge_params
  forM_ loop_merge $ \(p, initial) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] initial []

  let compileIteration = compileLoopBody merge_params body

  case form of
    WhileLoop cond ->
      sWhile (Imp.var cond Bool) compileIteration
    ForLoop i it bound loopvars -> do
      dLParams $ map fst loopvars
      sFor' i it (toExp' (IntType it) bound) $ do
        mapM_ (bindLoopVar i) loopvars
        compileIteration

  -- After the loop, copy the final merge parameter values into the
  -- pattern's destinations.
  Destination _ pat_dests <- destinationFromPattern pat
  sequence_ $ zipWith (\d p -> copyDWIMDest d [] (Var (paramName p)) [])
                      pat_dests merge_params

  where loop_merge = ctx ++ val
        merge_params = map fst loop_merge
        -- Scalar loop parameters are set to the current element of their
        -- array; non-scalar ones generate no code here.
        -- NOTE(review): the index uses Imp.vi32 regardless of the loop's
        -- IntType; presumably such loops always use Int32 - confirm.
        bindLoopVar i (p, arr)
          | Prim _ <- paramType p =
              copyDWIM (paramName p) [] (Var arr) [DimFix $ Imp.vi32 i]
          | otherwise = return ()

defCompileExp pat (Op op) =
  asks envOpCompiler >>= \op_compiler -> op_compiler pat op

-- | The default compiler for 'BasicOp's.  The scalar cases each compute a
-- single value and assign it to the sole pattern element.
defCompileBasicOp :: Mem lore =>
                     Pattern lore -> BasicOp -> ImpM lore r op ()

-- 'SubExp' and 'Opaque' both just copy their operand into the result.
defCompileBasicOp (Pattern _ [pe]) (SubExp se) =
  copyDWIM (patElemName pe) [] se []

defCompileBasicOp (Pattern _ [pe]) (Opaque se) =
  copyDWIM (patElemName pe) [] se []

defCompileBasicOp (Pattern _ [pe]) (UnOp op arg) =
  (patElemName pe <--) . Imp.UnOpExp op =<< toExp arg

defCompileBasicOp (Pattern _ [pe]) (ConvOp conv arg) =
  (patElemName pe <--) . Imp.ConvOpExp conv =<< toExp arg

-- Binary operators: the left operand is converted before the right one.
defCompileBasicOp (Pattern _ [pe]) (BinOp bop x y) =
  (patElemName pe <--) =<< (Imp.BinOpExp bop <$> toExp x <*> toExp y)

defCompileBasicOp (Pattern _ [pe]) (CmpOp bop x y) =
  (patElemName pe <--) =<< (Imp.CmpOpExp bop <$> toExp x <*> toExp y)

-- Assertions become a run-time check.  When the enclosing attributes contain
-- AttrComp "warn" ["safety_checks"], a warning is also issued at the
-- assertion's location.
defCompileBasicOp _ (Assert e msg (loc, locs)) = do
  check <- toExp e
  msg' <- traverse toExp msg
  emit $ Imp.Assert check msg' (loc, locs)

  attrs <- askAttrs
  when (AttrComp "warn" ["safety_checks"] `inAttrs` attrs) $
    warn loc locs "Safety check required at run-time."

-- A full index (every dimension fixed) copies a single element into the
-- result.  Every other 'Index' generates no code in this function
-- (presumably the result is a view established elsewhere - confirm).
defCompileBasicOp (Pattern _ [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName pe) [] (Var src) [DimFix (toExp' int32 idx) | idx <- idxs]

defCompileBasicOp _ Index{} =
  pure ()

-- In-place update: each slice component is converted to a 32-bit expression
-- and the update is delegated to 'sUpdate' on the pattern element.
defCompileBasicOp (Pattern _ [pe]) (Update _ slice se) =
  sUpdate (patElemName pe) [fmap (toExp' int32) d | d <- slice] se

-- Replicate: one loop per dimension in @ds@ (the first dimension is the
-- outermost loop), with an innermost body that copies @se@ into the result
-- at the current loop indices.
defCompileBasicOp (Pattern _ [pe]) (Replicate (Shape ds) se) = do
  ds' <- mapM toExp ds
  is <- replicateM (length ds) (newVName "i")
  copy_elem <- collect $ copyDWIM (patElemName pe) (map (DimFix . Imp.vi32) is) se []
  -- Compose one 'Imp.For' wrapper per dimension; with 'foldr' the first
  -- wrapper in the list ends up outermost.
  emit $ foldr (.) id (zipWith (`Imp.For` Int32) is ds') copy_elem

-- 'Scratch' generates no code.
defCompileBasicOp _ Scratch{} = pure ()

-- Iota: a loop bounded by @count@ writes @start + i * stride@ at index @i@
-- of the result, where @i@ is sign-extended to the element type @it@.
-- Only matches patterns with an empty context.
defCompileBasicOp (Pattern [] [pe]) (Iota count start stride it) = do
  count' <- toExp count
  start' <- toExp start
  stride' <- toExp stride
  sFor "i" count' $ \i -> do
    x <- dPrimV "x" $ start' + sExt it i * stride'
    copyDWIM (patElemName pe) [DimFix i] (Var x) []

-- Copy: copy the whole source array into the (single) pattern element.
defCompileBasicOp (Pattern _ [pe]) (Copy src) =
  copyDWIM (patElemName pe) [] (Var src) []

-- Manifest: the permutation argument is ignored here and the whole source
-- is copied into the destination.  NOTE(review): presumably the requested
-- layout is already carried by the destination's index function — confirm.
defCompileBasicOp (Pattern _ [pe]) (Manifest _ src) =
  copyDWIM (patElemName pe) [] (Var src) []

-- Concat: concatenate x:ys along dimension i.  Each input is copied into
-- the destination slice starting at the running offset 'offs_glb' along
-- dimension i (taking all of the leading i dimensions), after which the
-- offset is advanced by that input's extent in dimension i.
defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  offs_glb <- dPrim "tmp_offs" int32
  emit $ Imp.SetScalar offs_glb 0

  forM_ (x:ys) $ \y -> do
    y_dims <- arrayDims <$> lookupType y
    let rows = case drop i y_dims of
                 []  -> error $ "defCompileBasicOp Concat: empty array shape for " ++ pretty y
                 r:_ -> toExp' int32 r
        skip_dims = take i y_dims
        -- A slice covering an entire dimension of size d.
        sliceAllDim d = DimSlice 0 d 1
        skip_slices = map (sliceAllDim . toExp' int32) skip_dims
        destslice = skip_slices ++ [DimSlice (Imp.vi32 offs_glb) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    emit $ Imp.SetScalar offs_glb $ Imp.var offs_glb int32 + rows

-- ArrayLit: if every element is a constant (and there is at least one),
-- declare a static array holding the values in the destination's memory
-- space and copy it into the destination in one go.  Otherwise copy each
-- element into its position individually.
defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _)
  | Just vs@(v:_) <- mapM isLiteral es = do
      dest_mem <- entryArrayLocation <$> lookupArray (patElemName pe)
      dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
      let t = primValueType v
          num_elems = length es
      static_array <- newVNameForFun "static_array"
      emit $ Imp.DeclareArray static_array dest_space t $ Imp.ArrayValues vs
      let static_src = MemLocation static_array [intConst Int32 $ fromIntegral num_elems] $
                       IxFun.iota [fromIntegral num_elems]
          entry = MemVar Nothing $ MemEntry dest_space
      addVar static_array entry
      -- Use a single literal for the length, matching the shape above,
      -- instead of 'genericLength' (which builds it by repeated addition).
      let slice = [DimSlice 0 (fromIntegral num_elems) 1]
      copy t dest_mem slice static_src slice
  | otherwise =
    forM_ (zip [0..] es) $ \(i,e) ->
      copyDWIM (patElemName pe) [DimFix $ fromInteger i] e []

  where isLiteral (Constant v) = Just v
        isLiteral _ = Nothing

-- Rearrange, Rotate and Reshape emit no code in this default compiler.
-- NOTE(review): presumably their effect lives entirely in index
-- functions — confirm.
defCompileBasicOp _ Rearrange{} =
  return ()

defCompileBasicOp _ Rotate{} =
  return ()

defCompileBasicOp _ Reshape{} =
  return ()

-- Any other pattern/expression combination is unsupported: fail loudly,
-- showing both the pattern and the expression.
defCompileBasicOp pat e =
  error $ "ImpGen.defCompileBasicOp: Invalid pattern\n  " ++
  pretty pat ++ "\nfor expression\n  " ++ pretty e

-- | Register each 'ArrayDecl' as an array variable (via 'addVar')
-- without emitting any declaration code.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays = mapM_ addArray
  where addArray (ArrayDecl name bt location) =
          addVar name $
          ArrayVar Nothing ArrayEntry
          { entryArrayLocation = location
          , entryArrayElemType = bt
          }

-- | Like 'dFParams', but does not create new declarations: each parameter
-- is only registered (via 'addVar'), with its uniqueness information
-- stripped.
--
-- Note: a hack to be used only for functions.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams = mapM_ addFParam
  where addFParam fparam =
          addVar (paramName fparam) $
          memBoundToVarEntry Nothing $ noUniquenessReturns $ paramDec fparam

-- | Register @i@ as a scalar of integer type @it@, without emitting a
-- declaration.
--
-- Another hack.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it = addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Declare every pattern element via 'dScope', associating each with the
-- (optional) expression @e@.
dVars :: Mem lore =>
            Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars e = mapM_ dVar
  where dVar = dScope e . scopeOfPatElem

-- | Declare function parameters (emitting declarations) via 'dScope'.
dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Declare lambda parameters (emitting declarations) via 'dScope'.
dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Declare an existing name as a /volatile/ scalar of the given type and
-- register it as a scalar variable.
dPrimVol_ :: VName -> PrimType -> ImpM lore r op ()
dPrimVol_ name t = do
  emit $ Imp.DeclareScalar name Imp.Volatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Declare an existing name as a non-volatile scalar of the given type
-- and register it as a scalar variable.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Declare a fresh non-volatile scalar of the given type, with a name
-- derived from the given string, and return that name.
dPrim :: String -> PrimType -> ImpM lore r op VName
dPrim name t = do
  name' <- newVName name
  dPrim_ name' t
  return name'

-- | Declare an existing name as a scalar of the expression's type and
-- assign the expression to it.
dPrimV_ :: VName -> Imp.Exp -> ImpM lore r op ()
dPrimV_ name e = do
  dPrim_ name $ primExpType e
  name <-- e

-- | Declare a fresh scalar of the expression's type, assign the
-- expression to it, and return the new name.
dPrimV :: String -> Imp.Exp -> ImpM lore r op VName
dPrimV name e = do
  name' <- dPrim name $ primExpType e
  name' <-- e
  return name'

-- | Like 'dPrimV', but return a variable expression referring to the new
-- scalar instead of its name.
dPrimVE :: String -> Imp.Exp -> ImpM lore r op Imp.Exp
dPrimVE name e = do
  let t = primExpType e
  name' <- dPrim name t
  name' <-- e
  return $ Imp.var name' t

-- | Convert a memory annotation into a variable-table entry.  Scalars and
-- memory blocks map directly; arrays get a 'MemLocation' built from their
-- memory block, shape, and index function (converted with
-- @'toExp'' int32@).
memBoundToVarEntry :: Maybe (Exp lore) -> MemBound NoUniqueness
                   -> VarEntry lore
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry { entryScalarType = bt }
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem ixfun)) =
  let location = MemLocation mem (shapeDims shape) $ fmap (toExp' int32) ixfun
  in ArrayVar e ArrayEntry { entryArrayLocation = location
                           , entryArrayElemType = bt
                           }

-- | Extract the (uniqueness-free) memory annotation from a 'NameInfo'.
-- Index names become primitive scalars of their integer type.
infoDec :: Mem lore =>
           NameInfo lore
        -> MemInfo SubExp NoUniqueness MemBind
infoDec (LetName dec) = dec
infoDec (FParamName dec) = noUniquenessReturns dec
infoDec (LParamName dec) = dec
infoDec (IndexName it) = MemPrim $ IntType it

-- | Declare a single name according to its 'NameInfo' and record it via
-- 'addVar'.  Memory gets a 'Imp.DeclareMem' in its space, scalars a
-- non-volatile 'Imp.DeclareScalar'; arrays emit no declaration and are
-- only recorded.
dInfo :: Mem lore =>
         Maybe (Exp lore) -> VName -> NameInfo lore
      -> ImpM lore r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoDec info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      return ()
  addVar name entry

-- | Declare every name in a scope with 'dInfo', in ascending key order
-- (the order of 'M.toList').
dScope :: Mem lore =>
          Maybe (Exp lore) -> Scope lore -> ImpM lore r op ()
dScope e = mapM_ (uncurry $ dInfo e) . M.toList

-- | Record an array variable with the given element type, shape and
-- memory binding.  Emits no code.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind =
  addVar name $
  memBoundToVarEntry Nothing $ MemArray bt shape NoUniqueness membind

-- | Run an action with the environment's 'envVolatility' set to
-- 'Imp.Volatile'.
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local $ \env -> env { envVolatility = Imp.Volatile }

-- | The names a function call writes to directly: scalar and memory
-- destinations are kept, array destinations are removed.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) =
  return $ concatMap funcallTarget dests
  where funcallTarget (ScalarDestination name) = [name]
        funcallTarget (ArrayDestination _)     = []
        funcallTarget (MemoryDestination name) = [name]

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM lore r op Imp.Exp
  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

instance ToExp SubExp where
  -- A constant compiles directly.  A variable must be bound as a scalar,
  -- and its primitive type is taken from that symbol table entry.
  toExp = \case
    Constant value -> return $ Imp.ValueExp value
    Var name -> do
      entry <- lookupVar name
      case entry of
        ScalarVar _ (ScalarEntry pt) -> return $ Imp.var name pt
        _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty name

  -- The caller supplies the type, so no lookup is needed.
  toExp' _ (Constant value) = Imp.ValueExp value
  toExp' pt (Var name) = Imp.var name pt

instance ToExp (PrimExp VName) where
  -- Every variable leaf becomes an 'Imp.ScalarVar' leaf.  No lookup is
  -- performed, and the type argument of toExp' is ignored.
  toExp = return . (Imp.ScalarVar <$>)
  toExp' = const (Imp.ScalarVar <$>)

-- | Bind a name in the symbol table ('stateVTable'), replacing any
-- existing binding for that name.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry = modify insertEntry
  where insertEntry s =
          s { stateVTable = M.insert name entry (stateVTable s) }

-- | Run an action with 'envDefaultSpace' set to the given space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local setSpace
  where setSpace env = env { envDefaultSpace = space }

-- | The optional function name stored in 'envFunction'.  It is used as a
-- name prefix by 'newVNameForFun' and 'nameForFun'.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = envFunction <$> ask

-- | Generate a 'VName', prefixed with 'askFunction' if it exists.  The
-- prefix and the given string are separated by a dot.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun s = do
  prefix <- maybe "" ((++ ".") . nameToString) <$> askFunction
  newVName $ prefix ++ s

-- | Generate a 'Name', prefixed with 'askFunction' if it exists.  The
-- prefix and the given string are separated by a dot.
nameForFun :: String -> ImpM lore r op Name
nameForFun s = do
  prefix <- maybe "" (<> ".") <$> askFunction
  return $ prefix <> nameFromString s

-- | The extra environment of type @r@, stored in 'envEnv'.
askEnv :: ImpM lore r op r
askEnv = envEnv <$> ask

-- | Run an action with the extra environment ('envEnv') transformed by
-- the given function.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local updateEnv
  where updateEnv env = env { envEnv = f (envEnv env) }

-- | The active attributes ('envAttrs'), including those for the
-- statement currently being compiled.
askAttrs :: ImpM lore r op Attrs
askAttrs = envAttrs <$> ask

-- | Run an action with extra attributes added to those returned by
-- 'askAttrs'.  The new attributes are combined with the existing ones
-- using '<>', with the new ones on the left.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs attrs = local addAttrs
  where addAttrs env = env { envAttrs = attrs <> envAttrs env }

-- | Run an action with the expression, statement, copy, op, and
-- allocation compilers replaced by those in the given 'Operations'.
localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps ops = local installOps
  where installOps env =
          env { envExpCompiler = opsExpCompiler ops
              , envStmsCompiler = opsStmsCompiler ops
              , envCopyCompiler = opsCopyCompiler ops
              , envOpCompiler = opsOpCompiler ops
              , envAllocCompilers = opsAllocCompilers ops
              }

-- | Get the current symbol table ('stateVTable').
getVTable :: ImpM lore r op (VTable lore)
getVTable = stateVTable <$> get

-- | Replace the current symbol table wholesale.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable vtable = modify replaceTable
  where replaceTable s = s { stateVTable = vtable }

-- | Run an action with a modified symbol table.  All changes to the
-- symbol table will be reverted once the action is done!  The table in
-- effect before the call is restored afterwards, discarding both the
-- modification and any bindings the action itself added.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m =
  getVTable >>= \saved -> do
    putVTable $ f saved
    result <- m
    putVTable saved
    return result

-- | Look up a variable in the symbol table.  Calls 'error' if the name is
-- not bound.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name = do
  vtable <- gets stateVTable
  case M.lookup name vtable of
    Just entry -> return entry
    Nothing -> error $ "Unknown variable: " ++ pretty name

-- | Look up a variable that must be bound as an array.  Calls 'error' if
-- it is bound to anything else.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> return entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Look up a variable that must be bound to a memory block.  Unbound
-- names are already reported by 'lookupVar'.  This function reports
-- names that are bound, but not to memory, so the message says "not a
-- memory block" rather than "unknown" (matching 'lookupArray').
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> return entry
    _              -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name

-- | Build a 'Destination' for a pattern.
--
-- Each pattern element becomes an array, memory, or scalar destination,
-- according to how its name is bound in the symbol table.  Array
-- destinations get no location ('Nothing').  The destination's tag is
-- the 'baseTag' of the first pattern name, if the pattern has any.
destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat =
  Destination tag <$> mapM inspect (patternElements pat)
  where tag = baseTag <$> maybeHead (patternNames pat)
        inspect patElem = do
          let name = patElemName patElem
          entry <- lookupVar name
          case entry of
            ArrayVar _ (ArrayEntry MemLocation{} _) ->
              return $ ArrayDestination Nothing
            MemVar{} ->
              return $ MemoryDestination name
            ScalarVar{} ->
              return $ ScalarDestination name

-- | Compute the memory block, memory space, and element offset of an
-- element of the named array, given its indices.  See 'fullyIndexArray''.
fullyIndexArray :: VName -> [Imp.Exp]
                -> ImpM lore r op (VName, Imp.Space, Count Elements Imp.Exp)
fullyIndexArray name indices =
  lookupArray name >>= \arr ->
    fullyIndexArray' (entryArrayLocation arr) indices

-- | Like 'fullyIndexArray', but starting from a 'MemLocation'.
--
-- Returns the memory block, its space, and the element offset obtained
-- by applying the location's index function to the indices.  When the
-- memory lives in a 'ScalarSpace' with dimensions @ds@, every index
-- except the last @length ds@ is replaced by zero first.
fullyIndexArray' :: MemLocation -> [Imp.Exp]
                 -> ImpM lore r op (VName, Imp.Space, Count Elements Imp.Exp)
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let indices' = adjustForSpace space
  return (mem, space, elements $ IxFun.index ixfun indices')
  where
    adjustForSpace (ScalarSpace ds _) =
      let (outer, inner) = splitFromEnd (length ds) indices
      in (0 <$ outer) ++ inner
    adjustForSpace _ = indices

-- More complicated read/write operations that use index functions.

-- | Copy between two sliced memory locations, using the 'CopyCompiler'
-- currently installed in the environment ('envCopyCompiler').
copy :: CopyCompiler lore r op
copy bt dest destslice src srcslice =
  asks envCopyCompiler >>= \compiler ->
    compiler bt dest destslice src srcslice

-- | Use an 'Imp.Copy' if possible, otherwise 'copyElementWise'.
--
-- A single 'Imp.Copy' is emitted only when two conditions hold:
--
-- * both sliced index functions have a linear offset, as computed by
--   'IxFun.linearWithOffset';
-- * neither memory block lives in a 'ScalarSpace'.
--
-- The amount copied is the product of the source slice's dimensions,
-- converted to bytes with 'withElemType'.
defaultCopy :: CopyCompiler lore r op
defaultCopy bt dest destslice src srcslice =
  case (linearOffset destIxFun destslice, linearOffset srcIxFun srcslice) of
    (Just destoffset, Just srcoffset) -> do
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      if not (isScalarSpace srcspace) && not (isScalarSpace destspace)
        then emit $
             Imp.Copy destmem (bytes destoffset) destspace
                      srcmem (bytes srcoffset) srcspace $
             num_elems `withElemType` bt
        else copyElementWise bt dest destslice src srcslice
    _ -> copyElementWise bt dest destslice src srcslice
  where MemLocation destmem _ destIxFun = dest
        MemLocation srcmem _ srcIxFun = src
        -- Byte offset of the slice, if it is linear in memory.
        linearOffset ixfun slice =
          IxFun.linearWithOffset (IxFun.slice ixfun slice) (primByteSize bt)
        num_elems = Imp.elements $ product $ sliceDims srcslice
        isScalarSpace ScalarSpace{} = True
        isScalarSpace _ = False

-- | Copy element by element.
--
-- One 'Imp.For' loop with an 'Int32' counter is emitted per dimension of
-- the source slice.  The innermost body reads one element from the
-- source and writes it to the destination, using the volatility from
-- 'envVolatility' for both accesses.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest destslice src srcslice = do
  let bounds = sliceDims srcslice
  counters <- replicateM (length bounds) (newVName "i")
  let ivars = map Imp.vi32 counters
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest $ fixSlice destslice ivars
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src $ fixSlice srcslice ivars
  vol <- asks envVolatility
  -- The first counter belongs to the outermost loop.
  let nest body = foldr (\(i, bound) -> Imp.For i Int32 bound) body $
                  zip counters bounds
  emit $ nest $
    Imp.Write destmem destidx bt destspace vol $
    Imp.index srcmem srcidx bt srcspace vol

-- | Copy from here to there; both destination and source may be
-- indexed.
copyArrayDWIM :: PrimType
              -> MemLocation -> [DimIndex Imp.Exp]
              -> MemLocation -> [DimIndex Imp.Exp]
              -> ImpM lore r op (Imp.Code op)
copyArrayDWIM :: PrimType
-> MemLocation
-> [DimIndex Exp]
-> MemLocation
-> [DimIndex Exp]
-> ImpM lore r op (Code op)
copyArrayDWIM PrimType
bt
  destlocation :: MemLocation
destlocation@(MemLocation VName
_ [DimSize]
destshape IxFun Exp
_) [DimIndex Exp]
destslice
  srclocation :: MemLocation
srclocation@(MemLocation VName
_ [DimSize]
srcshape IxFun Exp
_) [DimIndex Exp]
srcslice

  | Just [Exp]
destis <- (DimIndex Exp -> Maybe Exp) -> [DimIndex Exp] -> Maybe [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex Exp -> Maybe Exp
forall d. DimIndex d -> Maybe d
dimFix [DimIndex Exp]
destslice,
    Just [Exp]
srcis <- (DimIndex Exp -> Maybe Exp) -> [DimIndex Exp] -> Maybe [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex Exp -> Maybe Exp
forall d. DimIndex d -> Maybe d
dimFix [DimIndex Exp]
srcslice,
    [Exp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Exp]
srcis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
srcshape,
    [Exp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Exp]
destis Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
destshape = do
  (VName
targetmem, Space
destspace, Count Elements Exp
targetoffset) <-
    MemLocation
-> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
forall lore r op.
MemLocation
-> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
fullyIndexArray' MemLocation
destlocation [Exp]
destis
  (VName
srcmem, Space
srcspace, Count Elements Exp
srcoffset) <-
    MemLocation
-> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
forall lore r op.
MemLocation
-> [Exp] -> ImpM lore r op (VName, Space, Count Elements Exp)
fullyIndexArray' MemLocation
srclocation [Exp]
srcis
  Volatility
vol <- (Env lore r op -> Volatility) -> ImpM lore r op Volatility
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Volatility
forall lore r op. Env lore r op -> Volatility
envVolatility
  Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return (Code op -> ImpM lore r op (Code op))
-> Code op -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ VName
-> Count Elements Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code op
forall a.
VName
-> Count Elements Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> Code a
Imp.Write VName
targetmem Count Elements Exp
targetoffset PrimType
bt Space
destspace Volatility
vol (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$
    VName
-> Count Elements Exp -> PrimType -> Space -> Volatility -> Exp
Imp.index VName
srcmem Count Elements Exp
srcoffset PrimType
bt Space
srcspace Volatility
vol

  | Bool
otherwise = do
      let destslice' :: [DimIndex Exp]
destslice' =
            [Exp] -> [DimIndex Exp] -> [DimIndex Exp]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((DimSize -> Exp) -> [DimSize] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> DimSize -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [DimSize]
destshape) [DimIndex Exp]
destslice
          srcslice' :: [DimIndex Exp]
srcslice'  =
            [Exp] -> [DimIndex Exp] -> [DimIndex Exp]
forall d. Num d => [d] -> [DimIndex d] -> [DimIndex d]
fullSliceNum ((DimSize -> Exp) -> [DimSize] -> [Exp]
forall a b. (a -> b) -> [a] -> [b]
map (PrimType -> DimSize -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32) [DimSize]
srcshape) [DimIndex Exp]
srcslice
          destrank :: Int
destrank = [Exp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Exp] -> Int) -> [Exp] -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex Exp] -> [Exp]
forall d. Slice d -> [d]
sliceDims [DimIndex Exp]
destslice'
          srcrank :: Int
srcrank = [Exp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Exp] -> Int) -> [Exp] -> Int
forall a b. (a -> b) -> a -> b
$ [DimIndex Exp] -> [Exp]
forall d. Slice d -> [d]
sliceDims [DimIndex Exp]
srcslice'
      if Int
destrank Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
srcrank
        then String -> ImpM lore r op (Code op)
forall a. HasCallStack => String -> a
error (String -> ImpM lore r op (Code op))
-> String -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ String
"copyArrayDWIM: cannot copy to " String -> ShowS
forall a. [a] -> [a] -> [a]
++
             VName -> String
forall a. Pretty a => a -> String
pretty (MemLocation -> VName
memLocationName MemLocation
destlocation) String -> ShowS
forall a. [a] -> [a] -> [a]
++
             String
" from " String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty (MemLocation -> VName
memLocationName MemLocation
srclocation) String -> ShowS
forall a. [a] -> [a] -> [a]
++
             String
" because ranks do not match (" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Pretty a => a -> String
pretty Int
destrank String -> ShowS
forall a. [a] -> [a] -> [a]
++
             String
" vs " String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Pretty a => a -> String
pretty Int
srcrank String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
")"
      else if MemLocation
destlocation MemLocation -> MemLocation -> Bool
forall a. Eq a => a -> a -> Bool
== MemLocation
srclocation Bool -> Bool -> Bool
&& [DimIndex Exp]
destslice' [DimIndex Exp] -> [DimIndex Exp] -> Bool
forall a. Eq a => a -> a -> Bool
== [DimIndex Exp]
srcslice'
        then Code op -> ImpM lore r op (Code op)
forall (m :: * -> *) a. Monad m => a -> m a
return Code op
forall a. Monoid a => a
mempty -- Copy would be no-op.
        else ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
bt MemLocation
destlocation [DimIndex Exp]
destslice' MemLocation
srclocation [DimIndex Exp]
srcslice'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- A 'Constant' source can never be indexed.  It may be written to a
-- 'ScalarDestination', or to an @ArrayDestination (Just loc)@ when every
-- index of the destination slice is a 'DimFix'; every other constant
-- case is rejected with 'error'.
--
-- A 'Var' source is looked up in the symbol table and handled according
-- to the kind of destination and the kind of the source entry:
--
--   * memory block to 'MemoryDestination': 'Imp.SetMem';
--   * unsliced scalar to unsliced 'ScalarDestination': 'Imp.SetScalar';
--   * array element (all dimensions 'DimFix'ed) to 'ScalarDestination':
--     a read via 'Imp.index';
--   * array to @ArrayDestination (Just loc)@: delegated to 'copyArrayDWIM';
--   * unsliced scalar to @ArrayDestination (Just loc)@ with an all-'DimFix'
--     destination slice: 'Imp.Write';
--   * anything to @ArrayDestination Nothing@: no code is emitted.
--
-- All remaining combinations are rejected with 'error'.  The order of the
-- case alternatives below matters, because several patterns overlap.
copyDWIMDest :: ValueDestination -> [DimIndex Imp.Exp] -> SubExp -> [DimIndex Imp.Exp]
             -> ImpM lore r op ()

copyDWIMDest _ _ (Constant v) (_:_) =
  error $ unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]

copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $ unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    Just dest_is ->
      -- The constructors below are disjoint, so their order is irrelevant.
      case pat of
        ScalarDestination name ->
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        ArrayDestination (Just dest_loc) -> do
          (mem, space, i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write mem i (primValueType v) space vol $ Imp.ValueExp v
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
        MemoryDestination{} ->
          error $ unwords ["copyDWIMDest: constant source", pretty v,
                           "cannot be written to memory destination."]

copyDWIMDest dest dest_slice (Var src) src_slice = do
  entry <- lookupVar src
  case (dest, entry) of
    -- Memory blocks can only be copied into memory destinations.
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination{}, _) ->
      error $ unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
    (_, MemVar{}) ->
      error $ unwords ["copyDWIMDest: source", pretty src, "is a memory block."]

    -- Scalars admit no slicing on either side.
    (_, ScalarVar _ (ScalarEntry _))
      | not (null src_slice) ->
          error $ unwords ["copyDWIMDest: prim-typed source", pretty src,
                           "with slice", pretty src_slice]
    (ScalarDestination name, _)
      | not (null dest_slice) ->
          error $ unwords ["copyDWIMDest: prim-typed target", pretty name,
                           "with slice", pretty dest_slice]

    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt

    -- Reading a single element requires every array dimension to be fixed.
    (ScalarDestination name, ArrayVar _ arr)
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
          (mem, space, i) <- fullyIndexArray' (entryArrayLocation arr) src_is
          vol <- asks envVolatility
          emit $ Imp.SetScalar name $
            Imp.index mem i (entryArrayElemType arr) space vol
      | otherwise ->
          error $ unwords ["copyDWIMDest: prim-typed target", pretty name,
                           "and array-typed source", pretty src,
                           "with slice", pretty src_slice]

    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      code <- copyArrayDWIM (entryArrayElemType src_arr)
                dest_loc dest_slice (entryArrayLocation src_arr) src_slice
      emit code

    -- Writing a scalar into an array requires a fully fixed destination slice.
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice -> do
          (mem, space, i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write mem i bt space vol $ Imp.var src bt
      | otherwise ->
          error $ unwords ["copyDWIMDest: array-typed target and prim-typed source", pretty src,
                           "with slice", pretty dest_slice]

    (ArrayDestination Nothing, _) ->
      return () -- Nothing to do; something else set some memory
                -- somewhere.

-- | Copy from here to there.  Both the destination and the source may
-- be indexed; if so, they had better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination variable is looked up and turned into a
-- 'ValueDestination', after which the work is done by 'copyDWIMDest'.
copyDWIM :: VName -> [DimIndex Imp.Exp] -> SubExp -> [DimIndex Imp.Exp]
         -> ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- toDestination <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    -- Classify the destination by its symbol-table entry.
    toDestination ScalarVar{} = ScalarDestination dest
    toDestination (ArrayVar _ arr) = ArrayDestination $ Just $ entryArrayLocation arr
    toDestination MemVar{} = MemoryDestination dest

-- | As 'copyDWIM', but every index of both the destination and the
-- source is wrapped in 'DimFix'.
copyDWIMFix :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp] -> ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest [DimFix d | d <- dest_is] src [DimFix s | s <- src_is]

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@ and binds the result to the name of the single value element
-- of @pat@.  The pattern must have no context elements and exactly one
-- value element; any other pattern is rejected with 'error'.  (That the
-- element is a memory block is presumably guaranteed by the caller; it
-- is not checked here.)
--
-- The allocation itself is delegated to 'sAlloc_'.  That uses the
-- 'AllocCompiler' registered for @space@ in the environment when there
-- is one, and otherwise emits a plain 'Imp.Allocate'.
compileAlloc :: Mem lore =>
                Pattern lore -> SubExp -> Space
             -> ImpM lore r op ()
compileAlloc (Pattern [] [mem]) e space = do
  e' <- Imp.bytes <$> toExp e
  sAlloc_ (patElemName mem) e' space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format, as an 'Int64' expression.
-- This is the element size multiplied by the product of all dimensions.
typeSize :: Type -> Count Bytes Imp.Exp
typeSize t = Imp.bytes $ elem_size * num_elems
  where
    -- Size in bytes of one element, widened to 64 bits.
    elem_size = sExt Int64 $ Imp.LeafExp (Imp.SizeOf (elemType t)) int32
    -- Number of elements: product of the dimensions, each widened to 64 bits.
    num_elems = product [sExt Int64 (toExp' int32 d) | d <- arrayDims t]

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop whose loop variable @i@ (of integer type
-- @it@) has already been named by the caller.  The variable is
-- registered with 'addLoopVar' before @body@ is compiled, and @bound@
-- is passed to 'Imp.For' unchanged.
sFor' :: VName -> IntType -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i it bound body = do
  addLoopVar i it
  emit . Imp.For i it bound =<< collect body

-- | @sFor name bound body@ emits a for-loop with a fresh loop variable
-- derived from @name@.  The loop variable's integer type is taken from
-- the type of @bound@, and the loop body is the code generated by
-- applying @body@ to the loop variable as an expression.
--
-- Fails with 'error' if @bound@ is not of an integer type.
sFor :: String -> Imp.Exp -> (Imp.Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor i bound body = do
  i' <- newVName i
  it <- case primExpType bound of
          IntType it -> return it
          t -> error $ "sFor: bound " ++ pretty bound ++ " is of type " ++ pretty t
  -- sFor' registers the loop variable and emits the Imp.For, so the
  -- loop-construction sequence lives in one place.
  sFor' i' it bound $ body $ Imp.var i' $ IntType it

-- | Emit an 'Imp.While' loop with condition @cond@, whose body is the
-- code generated by @body@.
sWhile :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body = collect body >>= emit . Imp.While cond

-- | Emit the code generated by @code@ wrapped in an 'Imp.Comment'
-- carrying the text @s@.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code = emit . Imp.Comment s =<< collect code

-- | Emit an 'Imp.If' on @cond@.  The true branch is generated first and
-- then the false branch; each is collected separately.
sIf :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch =
  -- Applicative effects run left to right, so tbranch is compiled before fbranch.
  emit =<< (Imp.If cond <$> collect tbranch <*> collect fbranch)

-- | Emit @tbranch@ guarded by @cond@, with an empty false branch.
sWhen :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())

-- | Emit @fbranch@ so that it runs only when @cond@ is false, with an
-- empty true branch.
sUnless :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit the given operation wrapped in 'Imp.Op'.
sOp :: op -> ImpM lore r op ()
sOp x = emit (Imp.Op x)

-- | Declare a memory block in @space@ under a fresh name derived from
-- @name@.  The block is registered in the symbol table as a 'MemVar',
-- and the fresh name is returned.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem $ MemVar Nothing $ MemEntry space
  pure mem

-- | Allocate @size'@ bytes in @space@ for the memory block named
-- @name'@.  If the environment registers an 'AllocCompiler' for
-- @space@, that compiler is used; otherwise a plain 'Imp.Allocate'
-- is emitted.
sAlloc_ :: VName -> Count Bytes Imp.Exp -> Space -> ImpM lore r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Just allocator -> allocator name' size'
    Nothing -> emit $ Imp.Allocate name' size' space

-- | Declare a fresh memory block (see 'sDeclareMem'), allocate @size@
-- bytes for it in @space@ (see 'sAlloc_'), and return its name.
sAlloc :: String -> Count Bytes Imp.Exp -> Space -> ImpM lore r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space

-- | Bind a fresh array variable, named after @name@, with element type
-- @bt@ and the given @shape@, placed according to @membind@ (via
-- 'dArray').  Returns the fresh variable name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray name bt shape membind = do
  arr <- newVName name
  arr <$ dArray arr bt shape membind

-- | Declare an array in row-major order in the given memory block.
--
-- The array is placed in @mem@ with an iota index function over the
-- dimensions of @shape@, each converted to an 'int32' expression.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape (ArrayIn mem row_major)
  where row_major = IxFun.iota [primExpFromSubExp int32 d | d <- shapeDims shape]

-- | Like 'sAllocArray', but permute the in-memory representation of the
-- indices as specified by @perm@.
--
-- A memory block named @name ++ "_mem"@ and sized for the whole array is
-- allocated in @space@.  The data is laid out as an iota over the
-- permuted dimensions, and the array is then viewed through that layout
-- permuted by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  let dims_in_mem = rearrangeShape perm $ shapeDims shape
      mem_size = typeSize $ Array pt shape NoUniqueness
  mem <- sAlloc (name ++ "_mem") mem_size space
  let layout = IxFun.iota $ map (primExpFromSubExp int32) dims_in_mem
  sArray name pt shape . ArrayIn mem . IxFun.permute layout $ rearrangeInverse perm

-- | Allocate a fresh memory block in @space@ and declare an array in it.
-- Uses a linear/iota index function: this is 'sAllocArrayPerm' with the
-- identity permutation over all dimensions of @shape@.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identity_perm
  where identity_perm = take (shapeRank shape) [0..]

-- | Declare a one-dimensional array whose contents @vs@ are known
-- statically.
--
-- A memory name @name ++ "_mem"@ is created with 'newVNameForFun'.  An
-- 'Imp.DeclareArray' holding @vs@ in @space@ is emitted for it, and the
-- name is registered as a memory variable.  The array is then bound to
-- that memory with a linear/iota index function whose single dimension
-- is the element count (an 'Int32' constant).
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  let num_elems = contentsLength vs
      shape = Shape [intConst Int32 $ toInteger num_elems]
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem . MemVar Nothing $ MemEntry space
  sArray name pt shape . ArrayIn mem $ IxFun.iota [fromIntegral num_elems]
  where
    -- Number of elements described by the static contents.
    contentsLength (Imp.ArrayValues vs') = length vs'
    contentsLength (Imp.ArrayZeros n) = n

-- | Emit a write of the scalar expression @v@ into array @arr@ at the
-- full index @is@.
--
-- The element type written is taken from @v@ itself, and the write
-- uses the volatility currently recorded in the environment.
sWrite :: VName -> [Imp.Exp] -> PrimExp Imp.ExpLeaf -> ImpM lore r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  let elem_t = primExpType v
  asks envVolatility >>= \vol ->
    emit $ Imp.Write mem offset elem_t space vol v

-- | Update the part of @arr@ selected by @slice@ with @v@.
--
-- This is 'copyDWIM' with an empty slice on the @v@ side.  That
-- presumably means @v@ is used as a whole; confirm against 'copyDWIM'.
sUpdate :: VName -> Slice Imp.Exp -> SubExp -> ImpM lore r op ()
sUpdate arr slice v = copyDWIM arr slice v whole_src
  where whole_src = []

-- | Generate a nest of 'sFor' loops, one per dimension of the given
-- shape, with the first dimension outermost.
--
-- The callback is run in the innermost body.  It receives the loop
-- indices ordered from the outermost loop to the innermost.
sLoopNest :: Shape
          -> ([Imp.Exp] -> ImpM lore r op ())
          -> ImpM lore r op ()
sLoopNest shape f = go [] $ shapeDims shape
  where
    -- Indices are accumulated innermost-first, hence the final 'reverse'.
    go acc [] = f $ reverse acc
    go acc (d:ds) = do
      bound <- toExp d
      sFor "nest_i" bound $ \i -> go (i : acc) ds

-- | Assignment: @x <-- e@ emits an 'Imp.SetScalar' that sets the scalar
-- variable @x@ to the value of @e@.
(<--) :: VName -> Imp.Exp -> ImpM lore r op ()
(<--) x e = emit (Imp.SetScalar x e)
infixl 3 <--

-- | Construct an ad-hoc function that does not correspond to any of the
-- IR functions in the input program, and emit it under the name @fname@.
--
-- While @m@ runs, the environment's current function is set to @fname@.
-- All @outputs@ and @inputs@ parameters are also in scope as variables.
-- The code collected from @m@ becomes the body of the emitted
-- 'Imp.Function'.
function :: Name -> [Imp.Param] -> [Imp.Param] -> ImpM lore r op ()
         -> ImpM lore r op ()
function fname outputs inputs m =
  local inFunction $ do
    code <- collect $ mapM_ bindParam (outputs ++ inputs) >> m
    -- NOTE(review): the leading 'False' flag is presumably "not an entry
    -- point"; confirm against the definition of 'Imp.Function'.
    emitFunction fname $ Imp.Function False outputs inputs code [] []
  where
    inFunction env = env { envFunction = Just fname }

    -- Register a parameter as a memory or scalar variable.
    bindParam (Imp.MemParam pname space) =
      addVar pname $ MemVar Nothing $ MemEntry space
    bindParam (Imp.ScalarParam pname bt) =
      addVar pname $ ScalarVar Nothing $ ScalarEntry bt