{-# LANGUAGE GeneralizedNewtypeDeriving, FlexibleContexts, LambdaCase, FlexibleInstances, MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen
  ( -- * Entry Points
    compileProg

    -- * Pluggable Compiler
  , OpCompiler
  , ExpCompiler
  , CopyCompiler
  , StmsCompiler
  , AllocCompiler
  , Operations (..)
  , defaultOperations
  , MemLocation (..)
  , MemEntry (..)
  , ScalarEntry (..)

    -- * Monadic Compiler Interface
  , ImpM
  , localDefaultSpace, askFunction
  , askEnv, localEnv
  , VTable
  , getVTable
  , localVTable
  , subImpM
  , subImpM_
  , emit
  , emitFunction
  , hasFunction
  , collect
  , collect'
  , comment
  , VarEntry (..)
  , ArrayEntry (..)

    -- * Lookups
  , lookupVar
  , lookupArray
  , lookupMemory

    -- * Building Blocks
  , ToExp(..)
  , compileAlloc
  , everythingVolatile
  , compileBody
  , compileBody'
  , compileLoopBody
  , defCompileStms
  , compileStms
  , compileExp
  , defCompileExp
  , fullyIndexArray
  , fullyIndexArray'
  , copy
  , copyDWIM
  , copyDWIMFix
  , copyElementWise
  , typeSize

  -- * Constructing code.
  , dLParams
  , dFParams
  , dScope
  , dArray
  , dPrim, dPrimVol_, dPrim_, dPrimV_, dPrimV, dPrimVE

  , sFor, sWhile
  , sComment
  , sIf, sWhen, sUnless
  , sOp
  , sDeclareMem, sAlloc, sAlloc_
  , sArray, sAllocArray, sAllocArrayPerm, sStaticArray
  , sWrite, sUpdate
  , sLoopNest
  , (<--)

  , function
  )
  where

import Control.Monad.RWS    hiding (mapM, forM)
import Control.Monad.State  hiding (mapM, forM, State)
import Control.Monad.Writer hiding (mapM, forM)
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.Traversable
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List (find, sortOn)

import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpCode
  (Bytes, Elements,
   bytes, elements, withElemType)
import Futhark.Representation.ExplicitMemory
import Futhark.Representation.SOACS (SOACS)
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.Construct (fullSliceNum)
import Futhark.MonadFreshNames
import Futhark.Util

-- | A pluggable compiler for a single 'Op', given the 'Pattern' and
-- the 'Op' itself.
type OpCompiler lore r op =
  Pattern lore -> Op lore -> ImpM lore r op ()

-- | A pluggable compiler for some 'Stms'.  It receives a set of
-- 'Names', the statements, and a further 'ImpM' action.
-- NOTE(review): the exact roles of the 'Names' and of the trailing
-- action are defined by 'defCompileStms', which is not shown here.
type StmsCompiler lore r op =
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | A pluggable compiler for an 'Exp', given the 'Pattern' it is
-- bound to and the expression itself.
type ExpCompiler lore r op =
  Pattern lore -> Exp lore -> ImpM lore r op ()

-- | A pluggable compiler for copying elements of the given
-- 'PrimType' between two 'MemLocation's.
-- NOTE(review): presumably the first location is the destination and
-- the second the source — confirm against 'defaultCopy'.
type CopyCompiler lore r op =
     PrimType
  -> MemLocation
  -> MemLocation
  -> ImpM lore r op ()

-- | An alternative way of compiling an allocation of the given size
-- (in bytes) into the named memory block.
type AllocCompiler lore r op =
  VName -> Count Bytes Imp.Exp -> ImpM lore r op ()

-- | The full set of pluggable compilers used by 'ImpM'.  Back-ends
-- customise code generation by supplying their own 'Operations'
-- (see 'defaultOperations' for a starting point).
data Operations lore r op =
  Operations { opsExpCompiler :: ExpCompiler lore r op
               -- ^ Compiles expressions.
             , opsOpCompiler :: OpCompiler lore r op
               -- ^ Compiles lore-specific 'Op's.
             , opsStmsCompiler :: StmsCompiler lore r op
               -- ^ Compiles statement sequences.
             , opsCopyCompiler :: CopyCompiler lore r op
               -- ^ Compiles copies between memory locations.
             , opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
               -- ^ Per-'Space' overrides for compiling allocations.
             }

-- | An operations set that uses the default compilers for
-- expressions ('defCompileExp'), statements ('defCompileStms') and
-- copies ('defaultCopy'), the given compiler for 'Op's, and no
-- space-specific allocation compilers.
defaultOperations :: (ExplicitMemorish lore, FreeIn op) =>
                     OpCompiler lore r op -> Operations lore r op
defaultOperations opc =
  Operations { opsExpCompiler = defCompileExp
             , opsOpCompiler = opc
             , opsStmsCompiler = defCompileStms
             , opsCopyCompiler = defaultCopy
             , opsAllocCompilers = mempty
             }

-- | When an array is declared, this is where it is stored.
data MemLocation = MemLocation { memLocationName :: VName
                                 -- ^ Name of the memory holding the array.
                               , memLocationShape :: [Imp.DimSize]
                                 -- ^ Shape of the array.
                               , memLocationIxFun :: IxFun.IxFun Imp.Exp
                                 -- ^ Index function describing the layout.
                               }
                   deriving (Eq, Show)

-- | Symbol-table information for an array variable.
data ArrayEntry = ArrayEntry {
    entryArrayLocation :: MemLocation
    -- ^ Where the array is stored.
  , entryArrayElemType :: PrimType
    -- ^ The element type of the array.
  }
  deriving (Show)

-- | The shape of an array entry, as recorded in its 'MemLocation'.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocationShape . entryArrayLocation

-- | Symbol-table information for a memory block: the 'Imp.Space' it
-- resides in.
newtype MemEntry = MemEntry { entryMemSpace :: Imp.Space }
  deriving (Show)

-- | Symbol-table information for a scalar variable: its primitive type.
newtype ScalarEntry = ScalarEntry {
    entryScalarType    :: PrimType
  }
  deriving (Show)

-- | The symbol-table entry for a variable: an array, a scalar, or a
-- memory block.  Every constructor also carries a 'Maybe (Exp lore)'.
-- NOTE(review): presumably the expression that bound the variable,
-- when known — confirm against the code that constructs entries.
data VarEntry lore = ArrayVar (Maybe (Exp lore)) ArrayEntry
                     -- ^ An array variable.
                   | ScalarVar (Maybe (Exp lore)) ScalarEntry
                     -- ^ A scalar variable.
                   | MemVar (Maybe (Exp lore)) MemEntry
                     -- ^ A memory block.
                   deriving (Show)

-- | When compiling an expression, this is a description of where the
-- result should end up.  The integer is a reference to the construct
-- that gave rise to this destination (for patterns, this will be the
-- tag of the first name in the pattern).  This can be used to make
-- the generated code easier to relate to the original code.
data Destination = Destination { destinationTag :: Maybe Int
                               , valueDestinations :: [ValueDestination] }
                    deriving (Show)

-- | Where a single value of a 'Destination' should end up.
data ValueDestination = ScalarDestination VName
                      | MemoryDestination VName
                      | ArrayDestination (Maybe MemLocation)
                        -- ^ The 'MemLocation' is 'Just' if a copy is
                        -- required.  If it is 'Nothing', then a
                        -- copy/assignment of a memory block somewhere
                        -- takes care of this array.
                      deriving (Show)

-- | The read-only environment of 'ImpM': the active pluggable
-- compilers plus contextual settings.
data Env lore r op = Env {
    envExpCompiler :: ExpCompiler lore r op
  , envStmsCompiler :: StmsCompiler lore r op
  , envOpCompiler :: OpCompiler lore r op
  , envCopyCompiler :: CopyCompiler lore r op
  , envAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  , envDefaultSpace :: Imp.Space
  , envVolatility :: Imp.Volatility
  , envEnv :: r
    -- ^ User-extensible environment.
  , envFunction :: Maybe Name
    -- ^ Name of the function we are compiling, if any.
  }

-- | Build the initial environment from a user environment, a set of
-- 'Operations' and a default memory space.  All compilers, including
-- the space-specific allocation compilers, are taken from the
-- 'Operations' (consistent with 'subImpM').  Code starts out
-- non-volatile and outside of any function.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env { envExpCompiler = opsExpCompiler ops
      , envStmsCompiler = opsStmsCompiler ops
      , envOpCompiler = opsOpCompiler ops
      , envCopyCompiler = opsCopyCompiler ops
        -- Previously 'mempty', which silently discarded any allocation
        -- compilers supplied by the caller at the top level.
      , envAllocCompilers = opsAllocCompilers ops
      , envDefaultSpace = ds
      , envVolatility = Imp.Nonvolatile
      , envEnv = r
      , envFunction = Nothing
      }

-- | The compiler's symbol table, mapping variable names to their
-- 'VarEntry'.
type VTable lore =
  M.Map VName (VarEntry lore)

-- | The mutable state of 'ImpM'.
data State lore r op = State { stateVTable :: VTable lore
                               -- ^ The symbol table.
                             , stateFunctions :: Imp.Functions op
                               -- ^ Functions emitted so far.
                             , stateNameSource :: VNameSource
                               -- ^ Source of fresh names.
                             }

-- | An initial state with an empty symbol table and no functions,
-- drawing fresh names from the given source.
newState :: VNameSource -> State lore r op
newState = State mempty mempty

-- | The code generation monad: reads an 'Env', accumulates emitted
-- 'Imp.Code', and threads a 'State'.  All instances are obtained via
-- @GeneralizedNewtypeDeriving@ from the underlying 'RWS'.
newtype ImpM lore r op a = ImpM (RWS (Env lore r op) (Imp.Code op) (State lore r op) a)
  deriving (Functor, Applicative, Monad,
            MonadState (State lore r op),
            MonadReader (Env lore r op),
            MonadWriter (Imp.Code op))

-- | Fresh names are drawn from, and stored back into, 'stateNameSource'.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s { stateNameSource = src }

-- Cannot be an ExplicitMemory scope because the index functions have
-- the wrong leaves (VName instead of Imp.Exp).
--
-- The scope is derived from the current symbol table: every entry is
-- converted to a plain SOACS type and wrapped as a 'LetInfo'.
instance HasScope SOACS (ImpM lore r op) where
  askScope = M.map (LetInfo . entryType) <$> gets stateVTable
    where entryType (MemVar _ memEntry) =
            Mem (entryMemSpace memEntry)
          entryType (ArrayVar _ arrayEntry) =
            Array
            (entryArrayElemType arrayEntry)
            (Shape $ entryArrayShape arrayEntry)
            NoUniqueness
          entryType (ScalarVar _ scalarEntry) =
            Prim $ entryScalarType scalarEntry

-- | Run an 'ImpM' computation from the given user environment,
-- operations, default space and initial state, returning the result,
-- the final state and the emitted code.
runImpM :: ImpM lore r op a
        -> r -> Operations lore r op -> Imp.Space -> State lore r op
        -> (a, State lore r op, Imp.Code op)
runImpM (ImpM m) r ops space = runRWS m $ newEnv r ops space

-- | Like 'subImpM', but discard the result of the sub-computation
-- and return only the code it generated.
subImpM_ :: r' -> Operations lore r' op' -> ImpM lore r' op' a
         -> ImpM lore r op (Imp.Code op')
subImpM_ r ops m = snd <$> subImpM r ops m

-- | Run a computation with a different user environment and a
-- different set of operations (possibly with a different 'op' type),
-- returning its result and the code it generated.
--
-- The sub-computation inherits the rest of the current environment
-- and starts from the current symbol table with an empty function
-- table.  Afterwards only its name source is propagated back: any
-- functions it emitted and any symbol-table changes are discarded, and
-- its code is returned rather than emitted.
subImpM :: r' -> Operations lore r' op' -> ImpM lore r' op' a
        -> ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM m) = do
  env <- ask
  s <- get
  -- Type-changing record updates: every field mentioning 'r' or 'op'
  -- is replaced, so the new environment and state have types over
  -- 'r'' and 'op''.
  let env' = env { envExpCompiler = opsExpCompiler ops
                 , envStmsCompiler = opsStmsCompiler ops
                 , envCopyCompiler = opsCopyCompiler ops
                 , envOpCompiler = opsOpCompiler ops
                 , envAllocCompilers = opsAllocCompilers ops
                 , envEnv = r
                 }
      s0 = s { stateVTable = stateVTable s
             , stateFunctions = mempty }
      (x, s', code) = runRWS m env' s0
  -- Keep fresh names unique across the two computations.
  putNameSource $ stateNameSource s'
  return (x, code)

-- | Execute a code generation action, returning the code that was
-- emitted.  The code is returned instead of being added to the
-- enclosing computation's output.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect m = pass $ do
  ((), code) <- listen m
  -- 'const mempty' suppresses the captured code from the outer output.
  return (code, const mempty)

-- | Like 'collect', but for an action with an arbitrary result, which
-- is returned together with the emitted code.  The code is not added
-- to the enclosing computation's output.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' m = pass $ do
  (x, code) <- listen m
  -- 'const mempty' suppresses the captured code from the outer output.
  return ((x, code), const mempty)

-- | Execute a code generation action, wrapping the generated code
-- within a 'Imp.Comment' with the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = collect m >>= emit . Imp.Comment desc

-- | Emit some generated imperative code by appending it to the
-- writer output of the code generation monad.
emit :: Imp.Code op -> ImpM lore r op ()
emit = tell

-- | Emit a function in the generated code.  The function is prepended
-- to the functions recorded in the state; no check is made for an
-- already existing function of the same name.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun = modify $ \s ->
  let Imp.Functions fs = stateFunctions s
  in s { stateFunctions = Imp.Functions $ (fname, fun) : fs }

-- | Check if a function of a given name exists among the functions
-- recorded in the code generation state.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = gets $ \s ->
  let Imp.Functions fs = stateFunctions s
  in isJust $ lookup fname fs

-- | Build the symbol table for the top-level constants.  Every name
-- bound by a pattern in the given statements is mapped to a variable
-- entry that also remembers the expression that bound it.
constsVTable :: LetAttr lore ~ LetAttr ExplicitMemory =>
                Stms lore -> VTable lore
constsVTable = foldMap stmVtable
  where stmVtable (Let pat _ e) =
          foldMap (peVtable e) $ M.toList $
          foldMap scopeOfPatElem $ patternElements pat
        peVtable e (name, info) =
          M.singleton name $ memBoundToVarEntry (Just e) $ infoAttr info

-- | Compile a whole program to imperative code.
--
-- All function definitions are compiled first (each with
-- 'envFunction' set to its own name), starting from a state whose
-- symbol table already contains the top-level constants.  The
-- constants are compiled afterwards, because which of them must be
-- exposed depends on the names that are free in the compiled
-- functions.
compileProg :: (ExplicitMemorish lore, FreeIn op, MonadFreshNames m) =>
               r -> Operations lore r op -> Imp.Space
            -> Prog lore -> m (Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource $ \src ->
    let (consts', s', _) =
          runImpM compile r ops space
          (newState src) { stateVTable = constsVTable consts }
    in (Imp.Definitions consts' (stateFunctions s'),
        stateNameSource s')
  where compile = do
          mapM_ compileFunDef' funs
          -- Names referenced by the functions decide which constant
          -- declarations are lifted out by 'compileConsts'.
          free_in_funs <- gets (freeIn . stateFunctions)
          compileConsts free_in_funs consts

        compileFunDef' fdef =
          local (\env -> env { envFunction = Just $ funDefName fdef }) $
            compileFunDef fdef

-- | Compile the top-level constant statements.
--
-- Memory and scalar declarations in the resulting initialisation code
-- whose names are in @used_consts@ are removed from the code and
-- returned instead as the parameters of the 'Imp.Constants'.
-- ('compileProg' passes the names free in the compiled functions.)
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  pure $ uncurry Imp.Constants $ first DL.toList $ extract code
  where -- Fish out those top-level declarations in the constant
        -- initialisation code that are free in the functions.  Only
        -- declarations reachable through ':>>:' sequencing are
        -- inspected; any other code is kept unchanged.
        extract (x Imp.:>>: y) =
          extract x <> extract y
        extract (Imp.DeclareMem name space)
          | name `nameIn` used_consts =
              (DL.singleton $ Imp.MemParam name space, mempty)
        extract (Imp.DeclareScalar name _ t)
          | name `nameIn` used_consts =
              (DL.singleton $ Imp.ScalarParam name t, mempty)
        extract s =
          (mempty, s)

-- | Classify a single function parameter.  Scalars and memory blocks
-- become ordinary 'Imp.Param's ('Left'); arrays become an 'ArrayDecl'
-- recording the memory block, shape and index function they are
-- stored with ('Right').
compileInParam :: ExplicitMemorish lore =>
                  FParam lore -> ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam = pure $ case paramAttr fparam of
  MemPrim bt ->
    Left $ Imp.ScalarParam name bt
  MemMem space ->
    Left $ Imp.MemParam name space
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    Right $ ArrayDecl name bt $
      MemLocation mem (shapeDims shape) $ fmap (toExp' int32) ixfun
  where name = paramName fparam

-- | An array-typed parameter: its name, its element type, and the
-- memory location it is stored in (built by 'compileInParam').
data ArrayDecl = ArrayDecl VName PrimType MemLocation

-- | The variables used as array dimensions in the type of a parameter.
fparamSizes :: Typed attr => Param attr -> S.Set VName
fparamSizes = S.fromList . subExpVars . arrayDims . paramType

-- | Compile the parameters of an entry point.
--
-- The last @sum (map entryPointSize orig_epts)@ parameters are the
-- value parameters described by the entry point types; any parameters
-- before them are context.  Returns the scalar and memory parameters,
-- the array declarations, and the externally visible description of
-- the value parameters.
compileInParams :: ExplicitMemorish lore =>
                   [FParam lore] -> [EntryPointType]
                -> ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let num_ctx = length params - sum (map entryPointSize orig_epts)
      -- Everything before this suffix is context.
      val_params = drop num_ctx params
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let findArray x = find (isArrayDecl x) arrayds
      -- Names used as array dimensions of any parameter.
      sizes = foldMap fparamSizes params

      summaries = M.fromList $ mapMaybe memSummary params
        where memSummary param
                | MemMem space <- paramAttr param =
                    Just (paramName param, space)
                | otherwise =
                    Nothing

      findMemInfo :: VName -> Maybe Space
      findMemInfo = flip M.lookup summaries

      -- An array is only described if its memory block is itself a
      -- parameter; a scalar is omitted if it is used as a dimension of
      -- some parameter.
      mkValueDesc fparam signedness =
        case (findArray $ paramName fparam, paramType fparam) of
          (Just (ArrayDecl _ bt (MemLocation mem shape _)), _) -> do
            memspace <- findMemInfo mem
            Just $ Imp.ArrayValue mem memspace bt signedness shape
          (_, Prim bt)
            | paramName fparam `S.member` sizes ->
              Nothing
            | otherwise ->
              Just $ Imp.ScalarValue bt signedness $ paramName fparam
          _ ->
            Nothing

      -- A single non-opaque value, if it has an external description.
      transparent signedness fparam =
        maybeToList $ Imp.TransparentValue <$> mkValueDesc fparam signedness

      mkExts (TypeOpaque desc n : epts) fparams =
        let (fparams', rest) = splitAt n fparams
        in Imp.OpaqueValue desc (mapMaybe (`mkValueDesc` Imp.TypeDirect) fparams')
           : mkExts epts rest
      mkExts (TypeUnsigned : epts) (fparam : fparams) =
        transparent Imp.TypeUnsigned fparam ++ mkExts epts fparams
      mkExts (TypeDirect : epts) (fparam : fparams) =
        transparent Imp.TypeDirect fparam ++ mkExts epts fparams
      mkExts _ _ = []

  return (inparams, arrayds, mkExts orig_epts val_params)
  where isArrayDecl x (ArrayDecl y _ _) = x == y

compileOutParams :: ExplicitMemorish lore =>
                    [RetType lore] -> [EntryPointType]
                 -> ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: [RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
orig_rts [EntryPointType]
orig_epts = do
  (([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams,Map Int ValueDestination
ctx_dests)) <-
    WriterT
  ([Param], Map Int ValueDestination)
  (ImpM lore r op)
  ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
   ([Param], Map Int ValueDestination)
   (ImpM lore r op)
   ([ExternalValue], [ValueDestination])
 -> ImpM
      lore
      r
      op
      (([ExternalValue], [ValueDestination]),
       ([Param], Map Int ValueDestination)))
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
-> ImpM
     lore
     r
     op
     (([ExternalValue], [ValueDestination]),
      ([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  ([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
     ([Param], Map Int ValueDestination)
     (ImpM lore r op)
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [FunReturns]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType lore]
[FunReturns]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
  let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
  ([ExternalValue], [Param], Destination)
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
  where imp :: ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      a)
-> (ImpM lore r op a
    -> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a)
-> ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift

        mkExts :: [EntryPointType]
-> [FunReturns]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts (TypeOpaque String
desc Int
n:[EntryPointType]
epts) [FunReturns]
rts = do
          let ([FunReturns]
rts',[FunReturns]
rest) = Int -> [FunReturns] -> ([FunReturns], [FunReturns])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [FunReturns]
rts
          ([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
 -> ([ValueDesc], [ValueDestination]))
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (FunReturns
 -> Signedness
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      (ValueDesc, ValueDestination))
-> [FunReturns]
-> [Signedness]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM FunReturns
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam [FunReturns]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
          ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [FunReturns]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [FunReturns]
rest
          ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return (String -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue String
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
                  [ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests)
        mkExts (EntryPointType
TypeUnsigned:[EntryPointType]
epts) (FunReturns
rt:[FunReturns]
rts) = do
          (ValueDesc
ev,ValueDestination
dest) <- FunReturns
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam FunReturns
rt Signedness
Imp.TypeUnsigned
          ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [FunReturns]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [FunReturns]
rts
          ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
                  ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests)
        mkExts (EntryPointType
TypeDirect:[EntryPointType]
epts) (FunReturns
rt:[FunReturns]
rts) = do
          (ValueDesc
ev,ValueDestination
dest) <- FunReturns
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam FunReturns
rt Signedness
Imp.TypeDirect
          ([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [FunReturns]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [FunReturns]
rts
          ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return (ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
                  ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests)
        mkExts [EntryPointType]
_ [FunReturns]
_ = ([ExternalValue], [ValueDestination])
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])

        mkParam :: FunReturns
-> Signedness
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
mkParam MemMem{} Signedness
_ =
          String
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall a. HasCallStack => String -> a
error String
"Functions may not explicitly return memory blocks."
        mkParam (MemPrim PrimType
t) Signedness
ept = do
          VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"scalar_out"
          ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
          (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
        mkParam (MemArray PrimType
t ShapeBase (Ext DimSize)
shape Uniqueness
_ MemReturn
attr) Signedness
ept = do
          Space
space <- (Env lore r op -> Space)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Space
forall lore r op. Env lore r op -> Space
envDefaultSpace
          VName
memout <- case MemReturn
attr of
            ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
              VName
memout <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"out_mem"
              ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
                    Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout)
              VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
            ReturnsInBlock VName
memout ExtIxFun
_ ->
              VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
          [DimSize]
resultshape <- (Ext DimSize
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      DimSize)
-> [Ext DimSize]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [DimSize]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
inspectExtSize ([Ext DimSize]
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      [DimSize])
-> [Ext DimSize]
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     [DimSize]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext DimSize) -> [Ext DimSize]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext DimSize)
shape
          (ValueDesc, ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     (ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName -> Space -> PrimType -> Signedness -> [DimSize] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [DimSize]
resultshape,
                  Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing)

        inspectExtSize :: Ext DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
inspectExtSize (Ext Int
x) = do
          (Map Any Any
memseen,Map Int VName
arrseen) <- StateT
  (Map Any Any, Map Int VName)
  (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
  (Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
          case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
            Maybe VName
Nothing -> do
              VName
out <- ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a.
ImpM lore r op a
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     a
imp (ImpM lore r op VName
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      VName)
-> ImpM lore r op VName
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"out_arrsize"
              ([Param], Map Int ValueDestination)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int32],
                    Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out)
              (Map Any Any, Map Int VName)
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     ()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
              DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return (DimSize
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      DimSize)
-> DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
forall a b. (a -> b) -> a -> b
$ VName -> DimSize
Var VName
out
            Just VName
out ->
              DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return (DimSize
 -> StateT
      (Map Any Any, Map Int VName)
      (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
      DimSize)
-> DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
forall a b. (a -> b) -> a -> b
$ VName -> DimSize
Var VName
out
        inspectExtSize (Free DimSize
se) =
          DimSize
-> StateT
     (Map Any Any, Map Int VName)
     (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
     DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return DimSize
se

-- | Compile a source-level function definition into an Imp function
-- and register it under the function's name via 'emitFunction'.
--
-- Parameter and return entry-point types are taken from the
-- entry-point annotation when one is present; otherwise every
-- parameter and every return value is treated as 'TypeDirect'.  The
-- 'Bool' given to 'Imp.Function' is 'True' exactly when such an
-- annotation exists.
compileFunDef :: ExplicitMemorish lore =>
                 FunDef lore
              -> ImpM lore r op ()
compileFunDef (FunDef entry fname rettype params body) = do
  ((outs, ins, ret_vals, arg_vals), fun_code) <- collect' genFunction
  emitFunction fname $
    Imp.Function (isJust entry) outs ins fun_code ret_vals arg_vals
  where
    -- Entry-point types of the parameters: one 'TypeDirect' per
    -- parameter when there is no annotation.
    in_epts = case entry of
      Just (param_epts, _) -> param_epts
      Nothing              -> map (const TypeDirect) params

    -- Entry-point types of the results: one 'TypeDirect' per return
    -- type when there is no annotation.
    out_epts = case entry of
      Just (_, ret_epts) -> ret_epts
      Nothing            -> map (const TypeDirect) rettype

    -- Set up input/output parameters, compile the body while copying
    -- its results into the output destinations, and hand back the
    -- pieces in the order 'Imp.Function' takes them.
    genFunction = do
      (ins, array_decls, arg_vals) <- compileInParams params in_epts
      (ret_vals, outs, Destination _ ret_dests) <-
        compileOutParams rettype out_epts
      addFParams params
      addArrays array_decls
      let Body _ stms res = body
          copyResult (dest, se) = copyDWIMDest dest [] se []
      compileStms (freeIn res) stms $ mapM_ copyResult (zip ret_dests res)
      return (outs, ins, ret_vals, arg_vals)

-- | Compile a body whose results are written into the destinations
-- derived from the given pattern.  Result @i@ is copied into
-- destination @i@ (pairing stops at the shorter of the two lists).
-- The names free in the results are passed to 'compileStms' as the
-- names alive after the statements.
compileBody :: (ExplicitMemorish lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ stms res) = do
  Destination _ pat_dests <- destinationFromPattern pat
  let copyOut (dest, se) = copyDWIMDest dest [] se []
  compileStms (freeIn res) stms $ mapM_ copyOut $ zip pat_dests res

-- | Compile a body, copying each result into the parameter at the
-- same position (pairing stops at the shorter of the two lists).  The
-- names free in the results are passed to 'compileStms' as the names
-- alive after the statements.
compileBody' :: [Param attr] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ stms res) =
  compileStms (freeIn res) stms $ mapM_ assign $ zip params res
  where assign (param, se) = copyDWIM (paramName param) [] se []

-- | Compile the body of a loop, writing its results back into the
-- loop's merge parameters.
--
-- Only two kinds of (parameter, result) pair generate code here:
-- merge parameters of primitive type, and memory-typed merge
-- parameters whose result is a variable.  Every other pair produces
-- no code.
compileLoopBody :: Typed attr => [Param attr] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ stms res) = do
  -- We cannot write the results to the merge parameters immediately,
  -- as some of the results may actually *be* merge parameters, and
  -- would thus be clobbered.  Therefore, we first copy to new
  -- variables mirroring the merge parameters, and then copy this
  -- buffer to the merge parameters.  This is efficient, because the
  -- operations are all scalar operations.
  tmps <- forM mergeparams $ \p -> newVName $ baseString (paramName p) ++ "_tmp"
  compileStms (freeIn res) stms $ do
    -- Phase one: stash every result in its temporary.  Each call
    -- yields the deferred action that later copies the temporary into
    -- the merge parameter.
    write_backs <- sequence $ zipWith3 stash mergeparams tmps res
    -- Phase two: only now overwrite the merge parameters.
    sequence_ write_backs
  where
    stash p tmp se = case (typeOf p, se) of
      (Prim pt, _) -> do
        emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
        emit $ Imp.SetScalar tmp $ toExp' pt se
        return $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
      (Mem space, Var v) -> do
        emit $ Imp.DeclareMem tmp space
        emit $ Imp.SetMem tmp v space
        return $ emit $ Imp.SetMem (paramName p) tmp space
      _ -> return (return ())

-- | Compile a sequence of statements together with the action @m@ by
-- delegating to the statement compiler installed in the environment
-- ('envStmsCompiler').  The 'Names' argument is forwarded unchanged
-- as the set of names alive after the statements.  How @m@ is
-- scheduled relative to the statements is up to the installed
-- compiler.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \stmsCompiler ->
    stmsCompiler alive_after_stms all_stms m

-- | The default 'StmsCompiler'.  It compiles each statement in turn and
-- then the trailing action @m@.
--
-- Memory blocks bound (with a 'Mem' type) by the patterns of these
-- statements are tracked.  Suppose such a block is free in a statement's
-- code but is not referenced afterwards: not by the following
-- statements, not by @m@, and not listed in @alive_after_stms@.  Then an
-- 'Imp.Free' for it is emitted directly after that statement's code.
-- A block only becomes a candidate from the statement after the one that
-- binds it.  The result of the internal traversal (names free in the
-- emitted code) is discarded.
defCompileStms :: (ExplicitMemorish lore, FreeIn op) =>
                  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms alive_after_stms all_stms m =
  -- We keep track of any memory blocks produced by the statements,
  -- and after the last time that memory block is used, we insert a
  -- Free.  This is very conservative, but can cut down on lifetimes
  -- in some cases.
  void $ compileStms' mempty $ stmsToList all_stms
  where
    -- Compile the remaining statements.  @allocs@ holds the memory
    -- blocks bound by earlier statements.  The result is the set of
    -- names that are free in the emitted code, together with
    -- @alive_after_stms@.
    compileStms' allocs (Let pat _ e : bs) = do
      dVars (Just e) (patternElements pat)

      e_code <- collect $ compileExp pat e
      -- The rest is compiled (and collected) first, so that we know
      -- which names it still refers to before deciding what to free.
      (live_after, bs_code) <-
        collect' $ compileStms' (patternAllocs pat <> allocs) bs
      let dies_here v = not (v `nameIn` live_after) &&
                        v `nameIn` freeIn e_code
          to_free = S.filter (dies_here . fst) allocs

      emit e_code
      mapM_ (emit . uncurry Imp.Free) to_free
      emit bs_code

      return $ freeIn e_code <> live_after
    compileStms' _ [] = do
      code <- collect m
      emit code
      return $ freeIn code <> alive_after_stms

    -- The memory blocks (name and space) bound by a pattern.
    patternAllocs = S.fromList . mapMaybe isMemPatElem . patternElements

    isMemPatElem pe = case patElemType pe of
                        Mem space -> Just (patElemName pe, space)
                        _         -> Nothing

-- | Compile an expression for the given pattern, using the
-- 'ExpCompiler' currently installed in the environment
-- ('envExpCompiler').
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e

-- | The default 'ExpCompiler'.  Conditionals, function calls and loops
-- are compiled directly.  'BasicOp's go to 'defCompileBasicOp', and
-- 'Op's go to the 'OpCompiler' found in the environment.
defCompileExp :: (ExplicitMemorish lore) =>
                 Pattern lore -> Exp lore -> ImpM lore r op ()

-- A conditional.  Both branches are compiled into separate code
-- fragments (each binding into @pat@) and placed in an 'Imp.If' on
-- @cond@.
defCompileExp pat (If cond tbranch fbranch _) = do
  tcode <- collect $ compileBody pat tbranch
  fcode <- collect $ compileBody pat fbranch
  emit $ Imp.If (toExp' Bool cond) tcode fcode

-- A function call.  The result targets are taken from the destination
-- of @pat@.  Arguments are passed as follows:
--   * primitive-typed arguments become 'Imp.ExpArg';
--   * memory-block variables become 'Imp.MemArg';
--   * any other argument (e.g. an array) is not passed at all.
-- NOTE(review): presumably an array is represented by its memory-block
-- argument, which appears separately in @args@; confirm.
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where compileArg (se, _) = do
          t <- subExpType se
          case (se, t) of
            (_, Prim pt)   -> return $ Just $ Imp.ExpArg $ toExp' pt se
            (Var v, Mem{}) -> return $ Just $ Imp.MemArg v
            _              -> return Nothing

-- Basic operations are handled by 'defCompileBasicOp'.
defCompileExp pat (BasicOp op) = defCompileBasicOp pat op

-- A loop.  The merge parameters (context first, then values) are
-- declared, and those of rank 0 are initialised from their initial
-- values.  Parameters of higher rank are not copied here (presumably
-- their memory is set up by the parameter declaration; TODO confirm).
-- After the loop, the final merge-parameter values are copied into the
-- destinations of the pattern.
defCompileExp pat (DoLoop ctx val form body) = do
  dFParams mergepat
  forM_ merge $ \(p, se) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody mergepat body

  case form of
    ForLoop i it bound loopvars -> do
      -- A primitive-typed loop parameter is read from element @i@ of its
      -- array at the start of every iteration.  Other loop parameters get
      -- no copy here.
      -- NOTE(review): the index is built with 'Imp.vi32' although the
      -- loop variable has type @it@; confirm @it@ is always Int32 here.
      let setLoopParam (p, a)
            | Prim _ <- paramType p =
                copyDWIM (paramName p) [] (Var a) [DimFix $ Imp.vi32 i]
            | otherwise =
                return ()

      dLParams $ map fst loopvars
      sFor' i it (toExp' (IntType it) bound) $
        mapM_ setLoopParam loopvars >> doBody
    WhileLoop cond ->
      sWhile (Imp.var cond Bool) doBody

  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []

  where merge = ctx ++ val
        mergepat = map fst merge

-- Lore-specific operations are delegated to the 'OpCompiler' stored in
-- the environment ('envOpCompiler').
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op

-- | The default compilation of a 'BasicOp' bound to a pattern.  Most
-- cases expect a pattern with exactly one value element.
defCompileBasicOp :: ExplicitMemorish lore =>
                     Pattern lore -> BasicOp lore -> ImpM lore r op ()

-- A plain subexpression is copied into the single pattern element.
defCompileBasicOp (Pattern _ [pe]) (SubExp se) =
  copyDWIM (patElemName pe) [] se []

-- An opaque subexpression is copied the same way.
defCompileBasicOp (Pattern _ [pe]) (Opaque se) =
  copyDWIM (patElemName pe) [] se []

-- Scalar operators: the operands are converted with 'toExp' and the
-- resulting primitive expression is assigned to the pattern element.
defCompileBasicOp (Pattern _ [pe]) (UnOp op e) = do
  e' <- toExp e
  patElemName pe <-- Imp.UnOpExp op e'

defCompileBasicOp (Pattern _ [pe]) (ConvOp conv e) = do
  e' <- toExp e
  patElemName pe <-- Imp.ConvOpExp conv e'

defCompileBasicOp (Pattern _ [pe]) (BinOp bop x y) = do
  x' <- toExp x
  y' <- toExp y
  patElemName pe <-- Imp.BinOpExp bop x' y'

defCompileBasicOp (Pattern _ [pe]) (CmpOp bop x y) = do
  x' <- toExp x
  y' <- toExp y
  patElemName pe <-- Imp.CmpOpExp bop x' y'

-- An assertion.  The condition and every expression inside the error
-- message are converted with 'toExp', and an 'Imp.Assert' is emitted.
-- The pattern is ignored.
defCompileBasicOp _ (Assert e msg loc) = do
  e' <- toExp e
  msg' <- traverse toExp msg
  emit $ Imp.Assert e' msg' loc

-- A fully indexed array.  The single element is copied into the pattern
-- element, with every index converted to an int32 expression.
defCompileBasicOp (Pattern _ [pe]) (Index src slice)
  | Just idxs <- sliceIndices slice =
      copyDWIM (patElemName pe) [] (Var src) $ map (DimFix . toExp' int32) idxs

-- Any other slice emits no code here.
-- NOTE(review): presumably the result is a view into @src@'s memory that
-- is already described by its memory annotation; confirm.
defCompileBasicOp _ Index{} =
  return ()

-- An in-place update.  @se@ is written into the given slice of the
-- pattern element's array, with the slice indices converted to int32
-- expressions.  The name of the array being updated is not consulted.
-- NOTE(review): presumably the pattern element shares its memory with
-- that array; confirm.
defCompileBasicOp (Pattern _ [pe]) (Update _ slice se) =
  sUpdate (patElemName pe) (map (fmap (toExp' int32)) slice) se

-- Replicate.  A fresh index variable is created for each dimension of
-- the shape.  The copy of @se@ into one element is then wrapped in a
-- nest of int32 'Imp.For' loops, with the first dimension outermost.
defCompileBasicOp (Pattern _ [pe]) (Replicate (Shape ds) se) = do
  ds' <- mapM toExp ds
  is <- replicateM (length ds) (newVName "i")
  copy_elem <- collect $
    copyDWIM (patElemName pe) (map (DimFix . Imp.vi32) is) se []
  -- Each entry wraps a body in one loop level.  Folding from the right
  -- applies the first dimension's loop last, so it ends up outermost.
  let loops = zipWith (`Imp.For` Int32) is ds'
  emit $ foldr ($) copy_elem loops

-- Scratch emits no code, so the array's contents are not initialised by
-- this clause.
defCompileBasicOp _ Scratch{} =
  return ()

-- iota.  For each @i@ in @[0, n)@, element @i@ of the result is set to
-- @e + i*s@.  The loop counter is sign-extended from Int32 to the
-- element type @it@ before the arithmetic.  Only patterns with an empty
-- context match this clause.
defCompileBasicOp (Pattern [] [pe]) (Iota n e s it) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  sFor "i" n' $ \i -> do
    let i' = ConvOpExp (SExt Int32 it) i
    x <- dPrimV "x" $ e' + i' * s'
    copyDWIM (patElemName pe) [DimFix i] (Var x) []

defCompileBasicOp (Pattern [PatElemT (LetAttr lore)]
_ [PatElemT (LetAttr lore)
pe]) (Copy VName
src) =
  VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (PatElemT (MemInfo DimSize NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe) [] (VName -> DimSize
Var VName
src) []

defCompileBasicOp (Pattern [PatElemT (LetAttr lore)]
_ [PatElemT (LetAttr lore)
pe]) (Manifest [Int]
_ VName
src) =
  VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (PatElemT (MemInfo DimSize NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe) [] (VName -> DimSize
Var VName
src) []

defCompileBasicOp (Pattern [PatElemT (LetAttr lore)]
_ [PatElemT (LetAttr lore)
pe]) (Concat Int
i VName
x [VName]
ys DimSize
_) = do
    MemLocation VName
destmem [DimSize]
destshape IxFun Exp
destixfun <-
      ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM lore r op ArrayEntry -> ImpM lore r op MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT (MemInfo DimSize NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe)
    VName
offs_glb <- String -> PrimType -> ImpM lore r op VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"tmp_offs" PrimType
int32
    Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
offs_glb Exp
0
    let perm :: [Int]
perm = [Int
i] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
0..Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1..[DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
destshapeInt -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
        invperm :: [Int]
invperm = [Int] -> [Int]
rearrangeInverse [Int]
perm
        destloc :: MemLocation
destloc = VName -> [DimSize] -> IxFun Exp -> MemLocation
MemLocation VName
destmem [DimSize]
destshape
                  (IxFun Exp -> [Int] -> IxFun Exp
forall num. IntegralExp num => IxFun num -> [Int] -> IxFun num
IxFun.permute (IxFun Exp -> Exp -> IxFun Exp
forall num.
(Eq num, IntegralExp num) =>
IxFun num -> num -> IxFun num
IxFun.offsetIndex (IxFun Exp -> [Int] -> IxFun Exp
forall num. IntegralExp num => IxFun num -> [Int] -> IxFun num
IxFun.permute IxFun Exp
destixfun [Int]
perm) (Exp -> IxFun Exp) -> Exp -> IxFun Exp
forall a b. (a -> b) -> a -> b
$
                                  VName -> Exp
Imp.vi32 VName
offs_glb)
                   [Int]
invperm)

    [VName] -> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (VName
xVName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
:[VName]
ys) ((VName -> ImpM lore r op ()) -> ImpM lore r op ())
-> (VName -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \VName
y -> do
      ArrayEntry
yentry <- VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
y
      let srcloc :: MemLocation
srcloc = ArrayEntry -> MemLocation
entryArrayLocation ArrayEntry
yentry
          rows :: Exp
rows = case Int -> [DimSize] -> [DimSize]
forall a. Int -> [a] -> [a]
drop Int
i ([DimSize] -> [DimSize]) -> [DimSize] -> [DimSize]
forall a b. (a -> b) -> a -> b
$ ArrayEntry -> [DimSize]
entryArrayShape ArrayEntry
yentry of
                  []  -> String -> Exp
forall a. HasCallStack => String -> a
error (String -> Exp) -> String -> Exp
forall a b. (a -> b) -> a -> b
$ String
"defCompileBasicOp Concat: empty array shape for " String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty VName
y
                  DimSize
r:[DimSize]
_ -> PrimType -> DimSize -> Exp
forall a. ToExp a => PrimType -> a -> Exp
toExp' PrimType
int32 DimSize
r
      CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType (Type -> PrimType) -> Type -> PrimType
forall a b. (a -> b) -> a -> b
$ PatElemT (MemInfo DimSize NoUniqueness MemBind) -> Type
forall attr. Typed attr => PatElemT attr -> Type
patElemType PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe) MemLocation
destloc MemLocation
srcloc
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Exp -> Code op
forall a. VName -> Exp -> Code a
Imp.SetScalar VName
offs_glb (Exp -> Code op) -> Exp -> Code op
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> Exp
Imp.var VName
offs_glb PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
rows

defCompileBasicOp (Pattern [] [PatElemT (LetAttr lore)
pe]) (ArrayLit [DimSize]
es Type
_)
  | Just vs :: [PrimValue]
vs@(PrimValue
v:[PrimValue]
_) <- (DimSize -> Maybe PrimValue) -> [DimSize] -> Maybe [PrimValue]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimSize -> Maybe PrimValue
isLiteral [DimSize]
es = do
      MemLocation
dest_mem <- ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM lore r op ArrayEntry -> ImpM lore r op MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT (MemInfo DimSize NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe)
      Space
dest_space <- MemEntry -> Space
entryMemSpace (MemEntry -> Space)
-> ImpM lore r op MemEntry -> ImpM lore r op Space
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM lore r op MemEntry
forall lore r op. VName -> ImpM lore r op MemEntry
lookupMemory (MemLocation -> VName
memLocationName MemLocation
dest_mem)
      let t :: PrimType
t = PrimValue -> PrimType
primValueType PrimValue
v
      VName
static_array <- String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"static_array"
      Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> PrimType -> ArrayContents -> Code op
forall a. VName -> Space -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
static_array Space
dest_space PrimType
t (ArrayContents -> Code op) -> ArrayContents -> Code op
forall a b. (a -> b) -> a -> b
$ [PrimValue] -> ArrayContents
Imp.ArrayValues [PrimValue]
vs
      let static_src :: MemLocation
static_src = VName -> [DimSize] -> IxFun Exp -> MemLocation
MemLocation VName
static_array [IntType -> Integer -> DimSize
intConst IntType
Int32 (Integer -> DimSize) -> Integer -> DimSize
forall a b. (a -> b) -> a -> b
$ Int -> Integer
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Integer) -> Int -> Integer
forall a b. (a -> b) -> a -> b
$ [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
es] (IxFun Exp -> MemLocation) -> IxFun Exp -> MemLocation
forall a b. (a -> b) -> a -> b
$
                       [Exp] -> IxFun Exp
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [Int -> Exp
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int -> Exp) -> Int -> Exp
forall a b. (a -> b) -> a -> b
$ [DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
es]
          entry :: VarEntry lore
entry = Maybe (Exp lore) -> MemEntry -> VarEntry lore
forall lore. Maybe (Exp lore) -> MemEntry -> VarEntry lore
MemVar Maybe (Exp lore)
forall a. Maybe a
Nothing (MemEntry -> VarEntry lore) -> MemEntry -> VarEntry lore
forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
dest_space
      VName -> VarEntry lore -> ImpM lore r op ()
forall lore r op. VName -> VarEntry lore -> ImpM lore r op ()
addVar VName
static_array VarEntry lore
entry
      CopyCompiler lore r op
forall lore r op. CopyCompiler lore r op
copy PrimType
t MemLocation
dest_mem MemLocation
static_src
  | Bool
otherwise =
    [(Integer, DimSize)]
-> ((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Integer] -> [DimSize] -> [(Integer, DimSize)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Integer
0..] [DimSize]
es) (((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ())
-> ((Integer, DimSize) -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \(Integer
i,DimSize
e) ->
      VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex Exp] -> DimSize -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (PatElemT (MemInfo DimSize NoUniqueness MemBind) -> VName
forall attr. PatElemT attr -> VName
patElemName PatElemT (LetAttr lore)
PatElemT (MemInfo DimSize NoUniqueness MemBind)
pe) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Integer -> Exp
forall a. Num a => Integer -> a
fromInteger Integer
i] DimSize
e []

  where isLiteral :: DimSize -> Maybe PrimValue
isLiteral (Constant PrimValue
v) = PrimValue -> Maybe PrimValue
forall a. a -> Maybe a
Just PrimValue
v
        isLiteral DimSize
_ = Maybe PrimValue
forall a. Maybe a
Nothing

defCompileBasicOp Pattern lore
_ Rearrange{} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

defCompileBasicOp Pattern lore
_ Rotate{} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

defCompileBasicOp Pattern lore
_ Reshape{} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

defCompileBasicOp Pattern lore
_ Repeat{} =
  () -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()

defCompileBasicOp Pattern lore
pat BasicOp lore
e =
  String -> ImpM lore r op ()
forall a. HasCallStack => String -> a
error (String -> ImpM lore r op ()) -> String -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ String
"ImpGen.defCompileBasicOp: Invalid pattern\n  " String -> ShowS
forall a. [a] -> [a] -> [a]
++
  PatternT (MemInfo DimSize NoUniqueness MemBind) -> String
forall a. Pretty a => a -> String
pretty Pattern lore
PatternT (MemInfo DimSize NoUniqueness MemBind)
pat String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"\nfor expression\n  " String -> ShowS
forall a. [a] -> [a] -> [a]
++ BasicOp lore -> String
forall a. Pretty a => a -> String
pretty BasicOp lore
e

-- | Register every given array declaration with 'addVar' as an
-- 'ArrayVar' with no defining expression.  Unlike the @d*@ helpers,
-- no declaration code is emitted here.
--
-- Note: a hack to be used only for functions.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays = mapM_ addArray
  where addArray (ArrayDecl name bt location) =
          addVar name $
          ArrayVar Nothing ArrayEntry
          { entryArrayLocation = location
          , entryArrayElemType = bt
          }

-- | Like 'dFParams', but does not create new declarations: each
-- parameter is only registered with 'addVar', using the entry built
-- by 'memBoundToVarEntry' from its annotation (with uniqueness
-- information discarded first).
--
-- Note: a hack to be used only for functions.
addFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore r op ()
addFParams = mapM_ addFParam
  where addFParam fparam =
          addVar (paramName fparam) $
          memBoundToVarEntry Nothing $ noUniquenessReturns $ paramAttr fparam

-- | Another hack: register @i@ with 'addVar' as a scalar of integer
-- type @it@, without emitting a declaration for it.  Presumably the
-- declaration comes from the enclosing loop construct (TODO confirm
-- at call sites).
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it = addVar i $ ScalarVar Nothing $ ScalarEntry $ IntType it

-- | Declare the names bound by the given pattern elements, one at a
-- time via 'dScope'.  The optional expression is recorded in every
-- resulting variable entry.
dVars :: ExplicitMemorish lore =>
            Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars e = mapM_ dVar
  where dVar = dScope e . scopeOfPatElem

-- | Declare function parameters through 'dScope' (with no defining
-- expression).  Memory and scalar parameters get a declaration
-- emitted; array parameters are only registered.
dFParams :: ExplicitMemorish lore => [FParam lore] -> ImpM lore r op ()
dFParams = dScope Nothing . scopeOfFParams

-- | Declare lambda parameters through 'dScope' (with no defining
-- expression).  Memory and scalar parameters get a declaration
-- emitted; array parameters are only registered.
dLParams :: ExplicitMemorish lore => [LParam lore] -> ImpM lore r op ()
dLParams = dScope Nothing . scopeOfLParams

-- | Emit a /volatile/ scalar declaration for an existing name and
-- register it with 'addVar' as a scalar of the given type.
dPrimVol_ :: VName -> PrimType -> ImpM lore r op ()
dPrimVol_ name t = do
  emit $ Imp.DeclareScalar name Imp.Volatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Emit a non-volatile scalar declaration for an existing name and
-- register it with 'addVar' as a scalar of the given type.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t = do
  emit $ Imp.DeclareScalar name Imp.Nonvolatile t
  addVar name $ ScalarVar Nothing $ ScalarEntry t

-- | Create a fresh name based on the given string, declare it as a
-- non-volatile scalar of the given type (see 'dPrim_'), and return it.
dPrim :: String -> PrimType -> ImpM lore r op VName
dPrim name t = do
  name' <- newVName name
  dPrim_ name' t
  return name'

-- | Declare an existing name as a scalar whose type is that of the
-- expression, then assign the expression to it with '<--'.
dPrimV_ :: VName -> Imp.Exp -> ImpM lore r op ()
dPrimV_ name e = do
  dPrim_ name $ primExpType e
  name <-- e

-- | Create a fresh name based on the given string, declare it as a
-- scalar initialised to the expression (see 'dPrimV_'), and return
-- the new name.
dPrimV :: String -> Imp.Exp -> ImpM lore r op VName
dPrimV name e = do
  -- Same effect order as 'dPrim' followed by '<--': fresh name,
  -- declaration, then assignment.
  name' <- newVName name
  dPrimV_ name' e
  return name'

-- | Like 'dPrimV', but return the new variable as an expression of
-- the same type as the initialising expression instead of its name.
dPrimVE :: String -> Imp.Exp -> ImpM lore r op Imp.Exp
dPrimVE name e = do
  name' <- dPrimV name e
  return $ Imp.var name' $ primExpType e

-- | Turn a memory annotation into a variable-table entry:
--
-- * 'MemPrim' becomes a 'ScalarVar' of that primitive type;
-- * 'MemMem' becomes a 'MemVar' in that memory space;
-- * 'MemArray' becomes an 'ArrayVar' whose location pairs the backing
--   memory block and shape with the index function, whose expressions
--   are converted with @toExp' int32@.
--
-- The optional expression is stored unchanged in the entry.
memBoundToVarEntry :: Maybe (Exp lore) -> MemBound NoUniqueness
                   -> VarEntry lore
memBoundToVarEntry e (MemPrim bt) =
  ScalarVar e ScalarEntry { entryScalarType = bt }
memBoundToVarEntry e (MemMem space) =
  MemVar e $ MemEntry space
memBoundToVarEntry e (MemArray bt shape _ (ArrayIn mem ixfun)) =
  let location = MemLocation mem (shapeDims shape) $ fmap (toExp' int32) ixfun
  in ArrayVar e ArrayEntry { entryArrayLocation = location
                           , entryArrayElemType = bt
                           }

-- | Extract the memory annotation of a name, whatever kind of binding
-- introduced it.  Uniqueness is discarded for function parameters, and
-- loop indices are treated as primitive integers of their type.
infoAttr :: NameInfo ExplicitMemory
         -> MemInfo SubExp NoUniqueness MemBind
infoAttr (LetInfo attr) = attr
infoAttr (FParamInfo attr) = noUniquenessReturns attr
infoAttr (LParamInfo attr) = attr
infoAttr (IndexInfo it) = MemPrim $ IntType it

-- | Declare a single name from its 'NameInfo'.
--
-- What gets emitted depends on the entry kind:
--
-- * memory blocks get an 'Imp.DeclareMem' in their space;
-- * scalars get a non-volatile 'Imp.DeclareScalar';
-- * arrays emit no code.
--
-- In every case the entry is then registered with 'addVar'.
dInfo :: Maybe (Exp lore) -> VName -> NameInfo ExplicitMemory
         -> ImpM lore r op ()
dInfo e name info = do
  let entry = memBoundToVarEntry e $ infoAttr info
  case entry of
    MemVar _ entry' ->
      emit $ Imp.DeclareMem name $ entryMemSpace entry'
    ScalarVar _ entry' ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile $ entryScalarType entry'
    ArrayVar _ _ ->
      return ()
  addVar name entry

-- | Declare every name in a scope with 'dInfo', in ascending key order
-- (the order of 'M.toList').
dScope :: Maybe (Exp lore) -> Scope ExplicitMemory -> ImpM lore r op ()
dScope e scope =
  forM_ (M.toList scope) $ \(name, info) -> dInfo e name info

-- | Register an array with 'addVar', given its element type, shape and
-- memory binding.  No code is emitted; the backing memory must be
-- declared separately.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind =
  addVar name $
  memBoundToVarEntry Nothing $ MemArray bt shape NoUniqueness membind

-- | Run an action with the environment's 'envVolatility' set to
-- 'Imp.Volatile'.  The setting is scoped to the action via 'local'.
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local $ \env -> env { envVolatility = Imp.Volatile }

-- | The names a function call writes its results into, in destination
-- order.  Scalar and memory destinations contribute their name; array
-- destinations are removed.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) =
  -- The per-destination mapping is pure, so no monadic traversal is
  -- needed.
  return $ concatMap funcallTarget dests
  where funcallTarget (ScalarDestination name) = [name]
        funcallTarget (ArrayDestination _)     = []
        funcallTarget (MemoryDestination name) = [name]

-- | Compile things to 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp', where the type (which must still be a
  -- primitive) is deduced monadically.
  toExp :: a -> ImpM lore r op Imp.Exp
  -- | Compile where we know the type in advance.
  toExp' :: PrimType -> a -> Imp.Exp

-- | A constant 'SubExp' becomes a literal.  A variable must be bound to
-- a scalar in the symbol table; its type is taken from that binding.
instance ToExp SubExp where
  toExp se = case se of
    Constant v -> pure $ Imp.ValueExp v
    Var v -> do
      entry <- lookupVar v
      case entry of
        ScalarVar _ (ScalarEntry pt) -> pure $ Imp.var v pt
        _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  -- The caller supplies the type here, so no symbol table lookup is done.
  toExp' t se = case se of
    Constant v -> Imp.ValueExp v
    Var v      -> Imp.var v t

-- | Every variable leaf of the expression is turned into an
-- 'Imp.ScalarVar'.  Neither method consults the symbol table.
instance ToExp (PrimExp VName) where
  toExp e = pure $ Imp.ScalarVar <$> e
  toExp' _ e = Imp.ScalarVar <$> e

-- | Bind a name in the symbol table, replacing any existing binding
-- for that name.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry = modify insertEntry
  where insertEntry st = st { stateVTable = M.insert name entry (stateVTable st) }

-- | Run an action with 'envDefaultSpace' set to the given memory space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local withSpace
  where withSpace env = env { envDefaultSpace = space }

-- | Read the environment's 'envFunction' field.  Presumably this is the
-- name of the function currently being compiled, if any.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = envFunction <$> ask

-- | Read the @r@-typed environment stored in 'envEnv'.
askEnv :: ImpM lore r op r
askEnv = envEnv <$> ask

-- | Run an action with the @r@-typed environment ('envEnv') transformed
-- by the given function.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local $ \env ->
  let inner = envEnv env
  in env { envEnv = f inner }

-- | Retrieve the current symbol table.
getVTable :: ImpM lore r op (VTable lore)
getVTable = stateVTable <$> get

-- | Replace the whole symbol table.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable new_vtable = modify setVTable
  where setVTable st = st { stateVTable = new_vtable }

-- | Run an action with a modified symbol table.  Any changes the action
-- makes to the symbol table are discarded: the table in effect before
-- the call is restored before the action's result is returned.  Other
-- state changes made by the action are kept.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  m <* putVTable saved

-- | Find the symbol table entry for a name.  Calls 'error' if the name
-- is not bound.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name =
  gets (M.lookup name . stateVTable) >>= maybe unknown return
  where unknown = error $ "Unknown variable: " ++ pretty name

-- | Find the 'ArrayEntry' for a name.  Unbound names are reported by
-- 'lookupVar'.  Names bound to something other than an array cause a
-- call to 'error'.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> return entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name

-- | Find the 'MemEntry' for a memory block.  Unbound names are reported
-- by 'lookupVar'.  Names bound to something other than a memory block
-- cause a call to 'error'.
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> return entry
    -- 'lookupVar' has already rejected unbound names, so reaching this
    -- case means the name is bound but is not a memory block.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name

-- | Build a 'Destination' with one 'ValueDestination' per pattern
-- element.  The kind of each destination is chosen by looking the
-- element up in the symbol table: arrays get an 'ArrayDestination'
-- without a location, memory blocks a 'MemoryDestination', and scalars
-- a 'ScalarDestination'.  The destination's tag is the 'baseTag' of the
-- first pattern name, if there is one.
destinationFromPattern :: ExplicitMemorish lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat =
  Destination (baseTag <$> maybeHead (patternNames pat))
    <$> traverse toValueDest (patternElements pat)
  where toValueDest pe = do
          let name = patElemName pe
          lookupVar name >>= \case
            ArrayVar _ (ArrayEntry MemLocation{} _) ->
              return $ ArrayDestination Nothing
            MemVar{} ->
              return $ MemoryDestination name
            ScalarVar{} ->
              return $ ScalarDestination name

-- | Like 'fullyIndexArray'', but the array is given by name and its
-- location is fetched with 'lookupArray'.
fullyIndexArray :: VName -> [Imp.Exp]
                -> ImpM lore r op (VName, Imp.Space, Count Elements Imp.Exp)
fullyIndexArray name indices =
  lookupArray name >>= \arr -> fullyIndexArray' (entryArrayLocation arr) indices

-- | Return the memory block, its space, and the element offset that a
-- complete list of indices addresses at the given location.  The
-- offset is computed with 'IxFun.index'.  If the memory block lives in
-- a @'ScalarSpace' ds _@, the indices are split with
-- @'splitFromEnd' (length ds)@ and the leading part is replaced by
-- zeros before indexing.
fullyIndexArray' :: MemLocation -> [Imp.Exp]
                 -> ImpM lore r op (VName, Imp.Space, Count Elements Imp.Exp)
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let offset = IxFun.index ixfun (adjustFor space)
  return (mem, space, elements offset)
  where
    adjustFor (ScalarSpace ds _) =
      let (outer, inner) = splitFromEnd (length ds) indices
      in map (const 0) outer ++ inner
    adjustFor _ = indices

-- | Apply a slice to a 'MemLocation'.  The index function is sliced
-- with 'IxFun.slice'.  The recorded shape keeps one entry for each
-- 'DimSlice' and drops each 'DimFix' dimension.  Pairing stops as soon
-- as either the shape or the slice runs out.
--
-- NOTE(review): a kept dimension retains its original size, not the
-- length of the slice.  Callers that read 'memLocationShape' after
-- slicing (for anything other than the rank) should confirm this is
-- what they want.
sliceArray :: MemLocation
           -> Slice Imp.Exp
           -> MemLocation
sliceArray (MemLocation mem shape ixfun) slice =
  MemLocation mem kept (IxFun.slice ixfun slice)
  where kept = [ d | (d, DimSlice{}) <- zip shape slice ]

-- More complicated read/write operations that use index functions.

-- | Copy between two memory locations using the 'CopyCompiler'
-- installed in the environment ('envCopyCompiler').
copy :: CopyCompiler lore r op
copy bt dest src = asks envCopyCompiler >>= \compileCopy -> compileCopy bt dest src

-- | Emit a single bulk 'Imp.Copy' when both index functions are linear
-- with an offset ('IxFun.linearWithOffset') and neither memory block is
-- in a 'ScalarSpace'.  Otherwise fall back to 'copyElementWise'.  The
-- number of elements copied is the product of the source location's
-- shape.
defaultCopy :: CopyCompiler lore r op
defaultCopy bt dest src
  | Just dest_off <- IxFun.linearWithOffset dest_ixfun elem_size,
    Just src_off <- IxFun.linearWithOffset src_ixfun elem_size = do
      src_space <- entryMemSpace <$> lookupMemory src_mem
      dest_space <- entryMemSpace <$> lookupMemory dest_mem
      let touches_scalar_space =
            isScalarSpace src_space || isScalarSpace dest_space
      if touches_scalar_space
        then copyElementWise bt dest src
        else emit $
             Imp.Copy dest_mem (bytes dest_off) dest_space
                      src_mem (bytes src_off) src_space
                      (elem_count `withElemType` bt)
  | otherwise = copyElementWise bt dest src
  where elem_size = primByteSize bt
        elem_count = Imp.elements $ product $ map (toExp' int32) src_shape
        MemLocation dest_mem _ dest_ixfun = dest
        MemLocation src_mem src_shape src_ixfun = src
        isScalarSpace ScalarSpace{} = True
        isScalarSpace _ = False

-- | Copy one element at a time.  Emits a loop nest with one 'Imp.For'
-- per dimension of the source location's shape (the first dimension is
-- the outermost loop).  The body reads from @src@ and writes to @dest@
-- at the same indices, using the environment's 'envVolatility' for both
-- accesses.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest src = do
  let dims = map (toExp' int32) $ memLocationShape src
  loop_vars <- replicateM (length dims) (newVName "i")
  let idxs = map Imp.vi32 loop_vars
  (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest idxs
  (src_mem, src_space, src_i) <- fullyIndexArray' src idxs
  vol <- asks envVolatility
  let wrapInLoops = foldr (.) id $ zipWith (\v d -> Imp.For v Int32 d) loop_vars dims
      element_copy =
        Imp.Write dest_mem dest_i bt dest_space vol $
        Imp.index src_mem src_i bt src_space vol
  emit $ wrapInLoops element_copy

-- | Copy from here to there; both destination and source may be
-- indexed.
--
-- * If both slices consist only of 'DimFix' indices and each covers
--   every dimension of its location, the result is a single element
--   write.
-- * Otherwise both locations are sliced with 'sliceArray'.  If the
--   sliced ranks differ, 'error' is called.  If the sliced locations
--   are identical, no code is produced.  Otherwise the result is the
--   code that 'copy' collects for the two sliced locations.
copyArrayDWIM :: PrimType
              -> MemLocation -> [DimIndex Imp.Exp]
              -> MemLocation -> [DimIndex Imp.Exp]
              -> ImpM lore r op (Imp.Code op)
copyArrayDWIM bt
  dest_loc@(MemLocation _ dest_shape _) dest_slice
  src_loc@(MemLocation _ src_shape _) src_slice
  | Just dest_is <- mapM dimFix dest_slice,
    Just src_is <- mapM dimFix src_slice,
    length src_is == length src_shape,
    length dest_is == length dest_shape = do
      (dest_mem, dest_space, dest_off) <- fullyIndexArray' dest_loc dest_is
      (src_mem, src_space, src_off) <- fullyIndexArray' src_loc src_is
      vol <- asks envVolatility
      let read_src = Imp.index src_mem src_off bt src_space vol
      return $ Imp.Write dest_mem dest_off bt dest_space vol read_src
  | dest_rank /= src_rank =
      error $ "copyArrayDWIM: cannot copy to " ++
              pretty (memLocationName dest_loc') ++
              " from " ++ pretty (memLocationName src_loc') ++
              " because ranks do not match (" ++ pretty dest_rank ++
              " vs " ++ pretty src_rank ++ ")"
  | dest_loc' == src_loc' =
      return mempty -- Copy would be a no-op.
  | otherwise =
      collect $ copy bt dest_loc' src_loc'
  where
    -- Complete a possibly partial slice against the location's shape,
    -- then apply it.
    sliceWith loc shape slice =
      sliceArray loc $ fullSliceNum (map (toExp' int32) shape) slice
    dest_loc' = sliceWith dest_loc dest_shape dest_slice
    src_loc' = sliceWith src_loc src_shape src_slice
    dest_rank = length $ memLocationShape dest_loc'
    src_rank = length $ memLocationShape src_loc'

-- | Like 'copyDWIM', but the target is a 'ValueDestination'
-- instead of a variable name.
--
-- * A constant source must have an empty source slice.  The
--   destination slice must consist solely of 'DimFix' indices.  The
--   constant can be stored into a 'ScalarDestination', in which case any
--   destination indices are ignored.  It can also be stored into one
--   element of an @'ArrayDestination' ('Just' loc)@.
--
-- * A variable source is dispatched on the pair (destination, source
--   entry):
--
--     * memory to memory becomes 'Imp.SetMem';
--     * reads from an array into a scalar become 'Imp.SetScalar';
--     * writes of a scalar into an array element become 'Imp.Write';
--     * array to array delegates to 'copyArrayDWIM'.
--
-- * An @'ArrayDestination' 'Nothing'@ with a variable source is a
--   no-op.
--
-- Every other combination is reported with 'error'.
copyDWIMDest :: ValueDestination -> [DimIndex Imp.Exp] -> SubExp -> [DimIndex Imp.Exp]
             -> ImpM lore r op ()
copyDWIMDest _ _ (Constant v) (_:_) =
  error $
  unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]

copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
      unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    Just dest_is ->
      storeConstant dest_is pat
  where bt = primValueType v

        -- Store the constant into the (already fully fixed) destination.
        storeConstant _ (ScalarDestination name) =
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        storeConstant _ MemoryDestination{} =
          error $
          unwords ["copyDWIMDest: constant source", pretty v,
                   "cannot be written to memory destination."]
        storeConstant dest_is (ArrayDestination (Just dest_loc)) = do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
        storeConstant _ (ArrayDestination Nothing) =
          error "copyDWIMDest: ArrayDestination Nothing"

copyDWIMDest dest dest_slice (Var src) src_slice = do
  entry <- lookupVar src
  -- The order of these alternatives matters: the guarded error cases
  -- must be tried before the general scalar/array cases below them.
  case (dest, entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space

    (MemoryDestination{}, _) ->
      error $
      unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]

    (_, MemVar{}) ->
      error $
      unwords ["copyDWIMDest: source", pretty src, "is a memory block."]

    (_, ScalarVar _ (ScalarEntry _))
      | not (null src_slice) ->
          error $
          unwords ["copyDWIMDest: prim-typed source", pretty src,
                   "with slice", pretty src_slice]

    (ScalarDestination name, _)
      | not (null dest_slice) ->
          error $
          unwords ["copyDWIMDest: prim-typed target", pretty name,
                   "with slice", pretty dest_slice]

    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt

    (ScalarDestination name, ArrayVar _ arr) ->
      case mapM dimFix src_slice of
        Just src_is -> do
          (mem, space, i) <- fullyIndexArray' (entryArrayLocation arr) src_is
          vol <- asks envVolatility
          emit $ Imp.SetScalar name $
            Imp.index mem i (entryArrayElemType arr) space vol
        Nothing ->
          error $
          unwords ["copyDWIMDest: prim-typed target and array-typed source", pretty src,
                   "with slice", pretty src_slice]

    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) ->
      emit =<< copyArrayDWIM (entryArrayElemType src_arr)
                             dest_loc dest_slice
                             (entryArrayLocation src_arr) src_slice

    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt)) ->
      case mapM dimFix dest_slice of
        Just dest_is -> do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.var src bt
        Nothing ->
          error $
          unwords ["copyDWIMDest: array-typed target and prim-typed source", pretty src,
                   "with slice", pretty dest_slice]

    (ArrayDestination Nothing, _) ->
      -- Nothing to do; something else set some memory somewhere.
      return ()

-- | Copy from here to there.  Both destination and source may be
-- indexed; if so, they had better be arrays of enough dimensions.
-- This function will generally just Do What I Mean, and Do The Right
-- Thing.  Both destination and source must be in scope.
--
-- The destination variable is classified by its entry in the symbol
-- table and the actual work is delegated to 'copyDWIMDest'.
copyDWIM :: VName -> [DimIndex Imp.Exp] -> SubExp -> [DimIndex Imp.Exp]
         -> ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- destinationOf <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    -- Map the destination's symbol-table entry to a 'ValueDestination'.
    destinationOf ScalarVar{} = ScalarDestination dest
    destinationOf (ArrayVar _ arr) = ArrayDestination $ Just $ entryArrayLocation arr
    destinationOf MemVar{} = MemoryDestination dest

-- | As 'copyDWIM', but every index is wrapped in 'DimFix'.  Both the
-- destination and the source are therefore indexed at single points
-- rather than sliced.
copyDWIMFix :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp] -> ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixed dest_is) src (fixed src_is)
  where fixed = map DimFix

-- | @compileAlloc pat size space@ allocates @size@ bytes of memory in
-- @space@ and binds the result to the memory element of @pat@.
--
-- @pat@ must be of the form @Pattern [] [mem]@; any other pattern is a
-- compiler error.
--
-- The allocation itself is performed by 'sAlloc_'.  A custom allocator
-- registered for @space@ is therefore used exactly as for every other
-- allocation.  Otherwise a plain 'Imp.Allocate' is emitted.
compileAlloc :: ExplicitMemorish lore =>
                Pattern lore -> SubExp -> Space
             -> ImpM lore r op ()
compileAlloc (Pattern [] [mem]) size space = do
  size' <- Imp.bytes <$> toExp size
  sAlloc_ (patElemName mem) size' space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat

-- | The number of bytes needed to represent the array in a
-- straightforward contiguous format.  This is the size of the element
-- type multiplied by the product of all dimensions.  Everything is
-- computed as an @int32@ expression.
typeSize :: Type -> Count Bytes Imp.Exp
typeSize t = Imp.bytes $ elem_bytes * num_elems
  where elem_bytes = Imp.LeafExp (Imp.SizeOf $ elemType t) int32
        num_elems = product [toExp' int32 d | d <- arrayDims t]

--- Building blocks for constructing code.

-- | Emit an 'Imp.For' loop over the already-named loop variable @i@ of
-- integer type @it@, bounded by @bound@.  The variable is registered
-- with 'addLoopVar' before @body@ is compiled.
sFor' :: VName -> IntType -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i it bound body = do
  addLoopVar i it
  emit . Imp.For i it bound =<< collect body

-- | Emit an 'Imp.For' loop bounded by @bound@, with a fresh loop
-- variable whose name is derived from @desc@.  @body@ receives the loop
-- variable as an expression.  The loop variable takes the integer type
-- of @bound@; a bound of any other type is a compiler error.
sFor :: String -> Imp.Exp -> (Imp.Exp -> ImpM lore r op ()) -> ImpM lore r op ()
sFor desc bound body = do
  i <- newVName desc
  it <- case primExpType bound of
          IntType int_t -> return int_t
          t -> error $ "sFor: bound " ++ pretty bound ++ " is of type " ++ pretty t
  sFor' i it bound $ body $ Imp.var i $ IntType it

-- | Emit an 'Imp.While' loop guarded by @cond@.  Its body is the code
-- generated by @body@.
sWhile :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body = emit . Imp.While cond =<< collect body

-- | Wrap the code generated by @code@ in an 'Imp.Comment' carrying the
-- text @s@.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code = emit . Imp.Comment s =<< collect code

-- | Emit an 'Imp.If' on @cond@.  The true branch is collected before
-- the false branch.
sIf :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch =
  emit =<< Imp.If cond <$> collect tbranch <*> collect fbranch

-- | Run @tbranch@ only when @cond@ holds.  This is 'sIf' with an empty
-- false branch.
sWhen :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond tbranch = sIf cond tbranch (pure ())

-- | Run the given code only when @cond@ does not hold.  This is 'sIf'
-- with an empty true branch.
sUnless :: Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch

-- | Emit a single backend-specific operation as 'Imp.Op'.
sOp :: op -> ImpM lore r op ()
sOp operation = emit $ Imp.Op operation

-- | Declare a fresh memory block in @space@ and return its name.  The
-- name is derived from @name@.  The block is emitted as
-- 'Imp.DeclareMem' and recorded in the symbol table as a 'MemVar'.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  mem <$ addVar mem (MemVar Nothing (MemEntry space))

-- | Allocate @size@ bytes in @space@ into the already-declared memory
-- block @mem@.  If an allocator is registered for @space@ in
-- 'envAllocCompilers', it is used.  Otherwise a plain 'Imp.Allocate'
-- is emitted.
sAlloc_ :: VName -> Count Bytes Imp.Exp -> Space -> ImpM lore r op ()
sAlloc_ mem size space = do
  custom <- asks $ M.lookup space . envAllocCompilers
  maybe (emit $ Imp.Allocate mem size space) (\alloc -> alloc mem size) custom

-- | Declare a fresh memory block with 'sDeclareMem', allocate @size@
-- bytes into it with 'sAlloc_', and return the block's name.
sAlloc :: String -> Count Bytes Imp.Exp -> Space -> ImpM lore r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space

-- | Declare a fresh array with element type @bt@, shape @shape@ and
-- memory binding @membind@ via 'dArray', and return the array's name.
-- The name is derived from @name@.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray name bt shape membind = do
  arr <- newVName name
  arr <$ dArray arr bt shape membind

-- | Like 'sAllocArray', but permute the in-memory representation of the
-- indices as specified.
--
-- A memory block named @name ++ "_mem"@ is allocated, large enough for
-- the whole array according to 'typeSize'.  The array's index function
-- is an iota over the permuted dimensions, permuted back with the
-- inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  mem <- sAlloc (name ++ "_mem") (typeSize (Array pt shape NoUniqueness)) space
  sArray name pt shape $ ArrayIn mem ixfun
  where ixfun = IxFun.permute iota_ixfun (rearrangeInverse perm)
        iota_ixfun = IxFun.iota $ map (primExpFromSubExp int32) $
                     rearrangeShape perm $ shapeDims shape

-- | Allocate a fresh array with a linear/iota index function.  This is
-- 'sAllocArrayPerm' with the identity permutation.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identity_perm
  where identity_perm = take (shapeRank shape) [0..]

-- | Declare a one-dimensional array in @space@ whose contents are given
-- statically by @vs@, and return its name.  The array uses a
-- linear/iota index function.
--
-- The backing memory block is named @name ++ "_mem"@.  It is emitted
-- with 'Imp.DeclareArray' and recorded as a 'MemVar'.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  mem <- newVName $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem $ MemVar Nothing $ MemEntry space
  sArray name pt (Shape [intConst Int32 $ toInteger num_elems]) $
    ArrayIn mem $ IxFun.iota [fromIntegral num_elems]
  where num_elems = case vs of
          Imp.ArrayValues vs' -> length vs'
          Imp.ArrayZeros n -> fromIntegral n

-- | Emit a store of the scalar expression @v@ into array @arr@ at the
-- element addressed by the full index list @is@.  The store uses the
-- volatility currently held in the environment and the primitive type of
-- @v@.
sWrite :: VName -> [Imp.Exp] -> PrimExp Imp.ExpLeaf -> ImpM lore r op ()
sWrite arr is v = do
  vol <- asks envVolatility
  (mem, space, offset) <- fullyIndexArray arr is
  let elemType = primExpType v
  emit $ Imp.Write mem offset elemType space vol v

-- | Copy the value @v@ into the part of array @arr@ selected by @slice@.
-- The destination is the array's memory location restricted by
-- 'sliceArray'.
sUpdate :: VName -> Slice Imp.Exp -> SubExp -> ImpM lore r op ()
sUpdate arr slice v = do
  entry <- lookupArray arr
  let dest = sliceArray (entryArrayLocation entry) slice
  copyDWIMDest (ArrayDestination (Just dest)) [] v []

-- | Generate a perfect nest of @for@ loops, one per dimension of the
-- shape.  The first dimension gets the outermost loop.  The callback is
-- invoked in the innermost position with the loop indices listed in
-- dimension order.
sLoopNest :: Shape
          -> ([Imp.Exp] -> ImpM lore r op ())
          -> ImpM lore r op ()
sLoopNest shape body = nest [] (shapeDims shape)
  where
    -- Indices are accumulated innermost-first and reversed at the bottom.
    nest acc [] = body (reverse acc)
    nest acc (d : ds) = do
      bound <- toExp d
      sFor "nest_i" bound $ \i -> nest (i : acc) ds

-- | Assignment: emit code that stores the value of the expression @e@ in
-- the scalar variable @x@ (an 'Imp.SetScalar' statement).
(<--) :: VName -> Imp.Exp -> ImpM lore r op ()
x <-- e = emit $ Imp.SetScalar x e
infixl 3 <--

-- | Constructing a non-entry point function.
--
-- All parameters, outputs first and then inputs, are registered in the
-- variable table before @m@ runs. The code emitted by @m@ is collected
-- and becomes the function body. The resulting 'Imp.Function' has its
-- leading 'Bool' field set to 'False' and empty external-value lists.
-- That 'Bool' is presumably the entry-point flag; confirm in ImpCode.
function :: [Imp.Param] -> [Imp.Param] -> ImpM lore r op ()
         -> ImpM lore r op (Imp.Function op)
function outputs inputs m =
  mkFunction <$> collect (mapM_ addParam (outputs ++ inputs) >> m)
  where
    mkFunction body = Imp.Function False outputs inputs body [] []

    -- Bind a parameter as a memory or scalar variable, per its kind.
    addParam p = case p of
      Imp.MemParam name space ->
        addVar name $ MemVar Nothing $ MemEntry space
      Imp.ScalarParam name bt ->
        addVar name $ ScalarVar Nothing $ ScalarEntry bt