{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen
(
compileProg,
OpCompiler,
ExpCompiler,
CopyCompiler,
StmsCompiler,
AllocCompiler,
Operations (..),
defaultOperations,
MemLocation (..),
MemEntry (..),
ScalarEntry (..),
ImpM,
localDefaultSpace,
askFunction,
newVNameForFun,
nameForFun,
askEnv,
localEnv,
localOps,
VTable,
getVTable,
localVTable,
subImpM,
subImpM_,
emit,
emitFunction,
hasFunction,
collect,
collect',
comment,
VarEntry (..),
ArrayEntry (..),
lookupVar,
lookupArray,
lookupMemory,
TV,
mkTV,
tvSize,
tvExp,
tvVar,
ToExp (..),
compileAlloc,
everythingVolatile,
compileBody,
compileBody',
compileLoopBody,
defCompileStms,
compileStms,
compileExp,
defCompileExp,
fullyIndexArray,
fullyIndexArray',
copy,
copyDWIM,
copyDWIMFix,
copyElementWise,
typeSize,
isMapTransposeCopy,
dLParams,
dFParams,
dScope,
dArray,
dPrim,
dPrimVol,
dPrim_,
dPrimV_,
dPrimV,
dPrimVE,
sFor,
sWhile,
sComment,
sIf,
sWhen,
sUnless,
sOp,
sDeclareMem,
sAlloc,
sAlloc_,
sArray,
sArrayInMem,
sAllocArray,
sAllocArrayPerm,
sStaticArray,
sWrite,
sUpdate,
sLoopNest,
(<--),
(<~~),
function,
warn,
module Language.Futhark.Warnings,
)
where
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Either
import Data.List (find, genericLength, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.CodeGen.ImpCode
( Bytes,
Count,
Elements,
bytes,
elements,
withElemType,
)
import qualified Futhark.CodeGen.ImpCode as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.Loc (noLoc)
import Language.Futhark.Warnings
-- | Compiles an 'Op' bound to the given pattern.
type OpCompiler lore r op =
  Pattern lore -> Op lore -> ImpM lore r op ()

-- | Compiles a sequence of statements.  Besides the 'Stms' it receives a
-- set of 'Names' and a further action; NOTE(review): presumably the names
-- used by, and the action to run after, the statements — confirm against
-- 'defCompileStms'.
type StmsCompiler lore r op =
  Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()

-- | Compiles an expression bound to the given pattern.
type ExpCompiler lore r op =
  Pattern lore -> Exp lore -> ImpM lore r op ()

-- | Compiles a copy of elements of the given primitive type between two
-- sliced memory locations.  NOTE(review): the first location/slice pair is
-- assumed to be the destination and the second the source — confirm
-- against 'defaultCopy'.
type CopyCompiler lore r op =
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  ImpM lore r op ()

-- | Compiles an allocation of the given number of bytes; the 'VName'
-- presumably names the memory block being allocated.
type AllocCompiler lore r op =
  VName -> Count Bytes (Imp.TExp Int64) -> ImpM lore r op ()
-- | The set of compilers used during code generation.  'defaultOperations'
-- fills in every field except the 'Op' compiler.
data Operations lore r op = Operations
  { -- | How expressions are compiled.
    opsExpCompiler :: ExpCompiler lore r op,
    -- | How backend-specific 'Op's are compiled.
    opsOpCompiler :: OpCompiler lore r op,
    -- | How statement sequences are compiled.
    opsStmsCompiler :: StmsCompiler lore r op,
    -- | How copies between memory locations are compiled.
    opsCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers, keyed by memory space.
    opsAllocCompilers :: M.Map Space (AllocCompiler lore r op)
  }
-- | Default 'Operations' built around a caller-supplied 'OpCompiler'.
-- Expressions, statements and copies use 'defCompileExp',
-- 'defCompileStms' and 'defaultCopy' respectively, and no allocation
-- compilers are registered.
defaultOperations ::
  (Mem lore, FreeIn op) =>
  OpCompiler lore r op ->
  Operations lore r op
defaultOperations opCompiler =
  Operations
    { opsAllocCompilers = M.empty,
      opsCopyCompiler = defaultCopy,
      opsStmsCompiler = defCompileStms,
      opsOpCompiler = opCompiler,
      opsExpCompiler = defCompileExp
    }
-- | A named memory block together with a shape and an index function.
data MemLocation = MemLocation
  { -- | Name of the memory block.
    memLocationName :: VName,
    -- | Dimensions of the data stored at this location.
    memLocationShape :: [Imp.DimSize],
    -- | Index function describing the layout within the block.
    memLocationIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Show, Eq)
-- | Symbol table entry for an array: where it lives and its element type.
data ArrayEntry = ArrayEntry
  { -- | Memory location holding the array.
    entryArrayLocation :: MemLocation,
    -- | Primitive type of the array elements.
    entryArrayElemType :: PrimType
  }
  deriving (Show)
-- | The shape of an array entry, read from its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocationShape (entryArrayLocation entry)
-- | Symbol table entry for a memory block: the space it resides in.
newtype MemEntry = MemEntry
  { entryMemSpace :: Imp.Space
  }
  deriving (Show)
-- | Symbol table entry for a scalar: its primitive type.
newtype ScalarEntry = ScalarEntry {entryScalarType :: PrimType}
  deriving (Show)
-- | An entry in the symbol table: an array, a scalar, or a memory block.
-- Every variant also carries an optional expression; NOTE(review):
-- presumably the expression that defined the variable, when known
-- (cf. 'constsVTable') — confirm.
data VarEntry lore
  = ArrayVar (Maybe (Exp lore)) ArrayEntry
  | ScalarVar (Maybe (Exp lore)) ScalarEntry
  | MemVar (Maybe (Exp lore)) MemEntry
  deriving (Show)
-- | Where the results of a computation should be written: one
-- 'ValueDestination' per value, plus an optional integer tag whose
-- meaning is not visible here (TODO confirm).
data Destination = Destination
  { destinationTag :: Maybe Int,
    valueDestinations :: [ValueDestination]
  }
  deriving (Show)
-- | The destination of a single value: a scalar variable, a memory block,
-- or an array.
data ValueDestination
  = -- | Write to the named scalar variable.
    ScalarDestination VName
  | -- | Write to the named memory block.
    MemoryDestination VName
  | -- | Write an array; NOTE(review): 'Nothing' presumably means no
    -- explicit target location is given — confirm against callers.
    ArrayDestination (Maybe MemLocation)
  deriving (Show)
-- | The read-only environment of 'ImpM'.
data Env lore r op = Env
  { -- | Compiler for expressions.
    envExpCompiler :: ExpCompiler lore r op,
    -- | Compiler for statement sequences.
    envStmsCompiler :: StmsCompiler lore r op,
    -- | Compiler for backend 'Op's.
    envOpCompiler :: OpCompiler lore r op,
    -- | Compiler for copies.
    envCopyCompiler :: CopyCompiler lore r op,
    -- | Allocation compilers keyed by memory space.
    envAllocCompilers :: M.Map Space (AllocCompiler lore r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | The current volatility setting.
    envVolatility :: Imp.Volatility,
    -- | Backend-specific payload.
    envEnv :: r,
    -- | NOTE(review): presumably the function currently being compiled.
    envFunction :: Maybe Name,
    -- | NOTE(review): presumably the attributes currently in effect.
    envAttrs :: Attrs
  }
-- | Build the initial environment for a code generation run.
--
-- All compilers are taken from the given 'Operations', including the
-- per-space allocation compilers.  This matches how 'subImpM' installs a
-- new set of operations.  The volatility starts as 'Imp.Nonvolatile', no
-- function is active, and no attributes are set.
newEnv :: r -> Operations lore r op -> Imp.Space -> Env lore r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      -- Use the caller's allocation compilers (as 'subImpM' does) instead
      -- of an empty map, which silently discarded them.
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }
-- | The symbol table, mapping variable names to their entries.
type VTable lore = M.Map VName (VarEntry lore)

-- | The mutable state threaded through 'ImpM'.
data ImpState lore r op = ImpState
  { -- | Current symbol table.
    stateVTable :: VTable lore,
    -- | Functions emitted so far (see 'emitFunction').
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far (see 'emit').
    stateCode :: Imp.Code op,
    -- | Accumulated warnings.
    stateWarnings :: Warnings,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | An initial state: empty symbol table, functions, code and warnings,
-- using the given name source.
newState :: VNameSource -> ImpState lore r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateNameSource = src
    }
-- | The code generation monad: a reader over 'Env' layered on a state
-- monad over 'ImpState'.  All instances are derived from that stack.
newtype ImpM lore r op a
  = ImpM (ReaderT (Env lore r op) (State (ImpState lore r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState lore r op),
      MonadReader (Env lore r op)
    )
-- | Fresh names are drawn from, and stored back into, the
-- 'stateNameSource' field of the state.
instance MonadFreshNames (ImpM lore r op) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify (\st -> st {stateNameSource = src})
-- | Exposes the symbol table as a 'SOACS' scope.  Every entry becomes a
-- let-bound name whose type is derived from the entry: memory entries
-- give 'Mem' types, array entries 'Array' types (with 'NoUniqueness'),
-- and scalar entries 'Prim' types.
instance HasScope SOACS (ImpM lore r op) where
  askScope = M.map (LetName . varType) <$> gets stateVTable
    where
      varType = \case
        MemVar _ mem ->
          Mem $ entryMemSpace mem
        ArrayVar _ arr ->
          Array
            (entryArrayElemType arr)
            (Shape (entryArrayShape arr))
            NoUniqueness
        ScalarVar _ scalar ->
          Prim $ entryScalarType scalar
-- | Run an 'ImpM' action from the given initial state.  The environment
-- is built with 'newEnv' from the payload, operations and default space.
-- Returns the action's result together with the final state.
runImpM ::
  ImpM lore r op a ->
  r ->
  Operations lore r op ->
  Imp.Space ->
  ImpState lore r op ->
  (a, ImpState lore r op)
runImpM (ImpM action) payload ops space initial =
  let env = newEnv payload ops space
   in runState (runReaderT action env) initial
-- | Like 'subImpM', but discards the sub-action's result and returns only
-- the code it generated.
subImpM_ ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (Imp.Code op')
subImpM_ r ops m = fmap snd (subImpM r ops m)
-- | Run an action under a different payload and set of operations (and
-- therefore possibly a different op type).  Returns the action's result
-- and the code it generated.
--
-- Behaviour:
--
-- * The action starts with the current symbol table and name source, and
--   with empty code, functions and warnings.
-- * Afterwards, the updated name source and any warnings are propagated
--   to the enclosing computation.
-- * Functions emitted by the action are discarded, since their op type
--   may differ.
subImpM ::
  r' ->
  Operations lore r' op' ->
  ImpM lore r' op' a ->
  ImpM lore r op (a, Imp.Code op')
subImpM r ops (ImpM sub) = do
  outerEnv <- ask
  vtable <- gets stateVTable
  src <- getNameSource

  let innerEnv =
        outerEnv
          { envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envOpCompiler = opsOpCompiler ops,
            envAllocCompilers = opsAllocCompilers ops,
            envEnv = r
          }
      innerState =
        ImpState
          { stateVTable = vtable,
            stateFunctions = mempty,
            stateCode = mempty,
            stateNameSource = src,
            stateWarnings = mempty
          }
      (result, finalState) = runState (runReaderT sub innerEnv) innerState

  putNameSource (stateNameSource finalState)
  warnings (stateWarnings finalState)
  pure (result, stateCode finalState)
-- | Run an action and return the code it emitted instead of appending it
-- to the current code.
collect :: ImpM lore r op () -> ImpM lore r op (Imp.Code op)
collect m = snd <$> collect' m
-- | Run an action and capture the code it emits, returning it alongside
-- the action's result.  The code accumulated before the call is restored
-- afterwards, so the captured code is not added to it.
collect' :: ImpM lore r op a -> ImpM lore r op (a, Imp.Code op)
collect' action = do
  saved <- state $ \st -> (stateCode st, st {stateCode = mempty})
  result <- action
  produced <- state $ \st -> (stateCode st, st {stateCode = saved})
  pure (result, produced)
-- | Run an action and emit the code it generates wrapped in an
-- 'Imp.Comment' carrying the given description.
comment :: String -> ImpM lore r op () -> ImpM lore r op ()
comment desc m = collect m >>= emit . Imp.Comment desc
-- | Append code to the end of the code generated so far.
emit :: Imp.Code op -> ImpM lore r op ()
emit code = modify appendCode
  where
    appendCode st = st {stateCode = stateCode st <> code}
-- | Record warnings.  New warnings are combined in front of the existing
-- ones.
warnings :: Warnings -> ImpM lore r op ()
warnings ws = modify addWarnings
  where
    addWarnings st = st {stateWarnings = ws <> stateWarnings st}
-- | Record a single warning at a primary location, with a list of related
-- locations and a message.
warn :: Located loc => loc -> [loc] -> String -> ImpM lore r op ()
warn loc related msg = warnings w
  where
    w = singleWarning' (srclocOf loc) (srclocOf <$> related) msg
-- | Add a named function to the emitted functions, placing it first.
-- No check is made for an existing function of the same name; see
-- 'hasFunction'.
emitFunction :: Name -> Imp.Function op -> ImpM lore r op ()
emitFunction fname fun = modify $ \st ->
  let Imp.Functions existing = stateFunctions st
   in st {stateFunctions = Imp.Functions ((fname, fun) : existing)}
-- | Whether a function with the given name has already been emitted in
-- the current state.
hasFunction :: Name -> ImpM lore r op Bool
hasFunction fname = do
  Imp.Functions fs <- gets stateFunctions
  pure $ any ((== fname) . fst) fs
-- | Build the symbol table for the top-level constants.
--
-- Every pattern element bound by one of the statements is mapped to a
-- 'VarEntry' produced by 'memBoundToVarEntry', which is handed the
-- statement's defining expression.  The per-element maps are combined with
-- the (left-biased) 'Map' monoid.
constsVTable :: Mem lore => Stms lore -> VTable lore
constsVTable = foldMap stmEntries
  where
    stmEntries (Let pat _ e) =
      mconcat
        [ M.singleton name $ memBoundToVarEntry (Just e) dec
          | PatElem name dec <- patternElements pat
        ]
-- | Compile a whole program into imperative code.
--
-- Every function is compiled separately (sparked with 'parMap' 'rpar'),
-- each starting from a new state built from the same name source and
-- carrying the symbol table of the top-level constants.  The per-function
-- states are then merged, and the constants are compiled in the merged
-- state.  Only constants occurring free in the generated functions are
-- turned into constant parameters (see 'compileConsts').  Returns the
-- accumulated warnings together with the definitions.
compileProg ::
  (Mem lore, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations lore r op ->
  Imp.Space ->
  Prog lore ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog consts funs) =
  modifyNameSource compileAll
  where
    compileAll src =
      let fun_states = map snd $ parMap rpar (compileFunDef' src) funs
          used_in_funs = freeIn $ mconcat $ map stateFunctions fun_states
          (consts', final_state) =
            runImpM (compileConsts used_in_funs consts) r ops space $
              combineStates fun_states
          defs = Imp.Definitions consts' $ stateFunctions final_state
       in ((stateWarnings final_state, defs), stateNameSource final_state)

    -- Compile a single function, starting from a new state whose symbol
    -- table already contains the top-level constants.
    compileFunDef' src fdef =
      runImpM (compileFunDef fdef) r ops space $
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the states of the separately compiled functions.  Functions
    -- are deduplicated by name ('M.fromList' keeps the last occurrence) and
    -- come out ordered by name; name sources and warnings are combined with
    -- their 'Monoid' instances.
    combineStates states =
      let Imp.Functions all_funs = mconcat $ map stateFunctions states
       in (newState $ mconcat $ map stateNameSource states)
            { stateFunctions = Imp.Functions $ M.toList $ M.fromList all_funs,
              stateWarnings = mconcat $ map stateWarnings states
            }
-- | Compile the top-level constant statements.
--
-- The generated code is scanned along its top-level '(Imp.:>>:)' sequence.
-- Memory-block and scalar declarations whose names are in @used_consts@
-- are lifted out of the code and become the parameters of the resulting
-- 'Imp.Constants'.  Everything else, including declarations nested inside
-- other constructs, stays in the initialisation code.
compileConsts :: Names -> Stms lore -> ImpM lore r op (Imp.Constants op)
compileConsts used_consts stms = do
  code <- collect $ compileStms used_consts stms $ pure ()
  let (lifted, init_code) = extract code
  pure $ Imp.Constants (DL.toList lifted) init_code
  where
    -- A 'DL.DList' keeps the repeated appends of the recursion cheap.
    extract = \case
      x Imp.:>>: y ->
        extract x <> extract y
      Imp.DeclareMem name space
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.MemParam name space, mempty)
      Imp.DeclareScalar name _ t
        | name `nameIn` used_consts ->
          (DL.singleton $ Imp.ScalarParam name t, mempty)
      s ->
        (mempty, s)
-- | Translate a single function parameter.
--
-- Scalar and memory-block parameters become Imp parameters ('Left').
-- Array parameters become an 'ArrayDecl' ('Right') recording their element
-- type, memory block, shape and index function.  The index function's
-- leaves are wrapped as 'Imp.ScalarVar'.
compileInParam ::
  Mem lore =>
  FParam lore ->
  ImpM lore r op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  pure $ case paramDec fparam of
    MemPrim bt ->
      Left $ Imp.ScalarParam name bt
    MemMem space ->
      Left $ Imp.MemParam name space
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      Right . ArrayDecl name bt . MemLocation mem (shapeDims shape) $
        fmap Imp.ScalarVar <$> ixfun
  where
    name = paramName fparam
-- | An array parameter: its name, its element type, and the memory
-- location (memory block, shape and index function) it is stored in.
-- Produced by 'compileInParam'.
data ArrayDecl = ArrayDecl VName PrimType MemLocation
-- | The variables used as array dimensions in the parameter's type.
-- Constant dimensions contribute nothing.
fparamSizes :: Typed dec => Param dec -> S.Set VName
fparamSizes param =
  S.fromList $ subExpVars $ arrayDims $ paramType param
-- | Compile the parameters of an entry point function.
--
-- The last @sum (map entryPointSize orig_epts)@ parameters are the value
-- parameters described by @orig_epts@; any parameters before them are
-- context parameters.  The result consists of:
--
-- * Imp parameters for every scalar and memory-block parameter,
-- * an 'ArrayDecl' for every array parameter, and
-- * the external values obtained by walking @orig_epts@ over the value
--   parameters.
--
-- Scalar parameters that are used as a dimension of some parameter's type
-- are not exposed.  An array is only described when its memory block is
-- itself one of the parameters.
compileInParams ::
  Mem lore =>
  [FParam lore] ->
  [EntryPointType] ->
  ImpM lore r op ([Imp.Param], [ArrayDecl], [Imp.ExternalValue])
compileInParams params orig_epts = do
  let num_val_params = sum $ map entryPointSize orig_epts
      (ctx_params, val_params) =
        splitAt (length params - num_val_params) params
  (inparams, arrayds) <-
    partitionEithers <$> mapM compileInParam (ctx_params ++ val_params)

  let -- Every variable used as a dimension of some parameter's type.
      sizes = mconcat $ map fparamSizes $ ctx_params ++ val_params

      -- The memory spaces of the memory-block parameters, keyed by name.
      summaries =
        M.fromList
          [ (paramName p, space)
            | p <- params,
              MemMem space <- [paramDec p]
          ]

      mkValueDesc fparam signedness
        | Just (ArrayDecl _ bt (MemLocation mem shape _)) <-
            find (\(ArrayDecl y _ _) -> y == pname) arrayds = do
          memspace <- M.lookup mem summaries
          Just $ Imp.ArrayValue mem memspace bt signedness shape
        | Prim bt <- paramType fparam,
          not $ pname `S.member` sizes =
          Just $ Imp.ScalarValue bt signedness pname
        | otherwise =
          Nothing
        where
          pname = paramName fparam

      transparent signedness fparam =
        maybeToList $ Imp.TransparentValue <$> mkValueDesc fparam signedness

      -- An opaque entry point type consumes @n@ parameters at once.
      mkExts (TypeOpaque desc n : epts) fparams =
        let (grouped, rest) = splitAt n fparams
            descs = mapMaybe (`mkValueDesc` Imp.TypeDirect) grouped
         in Imp.OpaqueValue desc descs : mkExts epts rest
      mkExts (TypeUnsigned : epts) (fparam : fparams) =
        transparent Imp.TypeUnsigned fparam ++ mkExts epts fparams
      mkExts (TypeDirect : epts) (fparam : fparams) =
        transparent Imp.TypeDirect fparam ++ mkExts epts fparams
      mkExts _ _ =
        []

  return (inparams, arrayds, mkExts orig_epts val_params)
compileOutParams ::
Mem lore =>
[RetType lore] ->
[EntryPointType] ->
ImpM lore r op ([Imp.ExternalValue], [Imp.Param], Destination)
compileOutParams :: [RetType lore]
-> [EntryPointType]
-> ImpM lore r op ([ExternalValue], [Param], Destination)
compileOutParams [RetType lore]
orig_rts [EntryPointType]
orig_epts = do
(([ExternalValue]
extvs, [ValueDestination]
dests), ([Param]
outparams, Map Int ValueDestination
ctx_dests)) <-
WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination))
forall w (m :: * -> *) a. WriterT w m a -> m (a, w)
runWriterT (WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination)))
-> WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
-> ImpM
lore
r
op
(([ExternalValue], [ValueDestination]),
([Param], Map Int ValueDestination))
forall a b. (a -> b) -> a -> b
$ StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
-> (Map Any Any, Map Int VName)
-> WriterT
([Param], Map Int ValueDestination)
(ImpM lore r op)
([ExternalValue], [ValueDestination])
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT ([EntryPointType]
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
orig_epts [RetType lore]
[MemInfo (Ext DimSize) Uniqueness MemReturn]
orig_rts) (Map Any Any
forall k a. Map k a
M.empty, Map Int VName
forall k a. Map k a
M.empty)
let ctx_dests' :: [ValueDestination]
ctx_dests' = ((Int, ValueDestination) -> ValueDestination)
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> [a] -> [b]
map (Int, ValueDestination) -> ValueDestination
forall a b. (a, b) -> b
snd ([(Int, ValueDestination)] -> [ValueDestination])
-> [(Int, ValueDestination)] -> [ValueDestination]
forall a b. (a -> b) -> a -> b
$ ((Int, ValueDestination) -> Int)
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall b a. Ord b => (a -> b) -> [a] -> [a]
sortOn (Int, ValueDestination) -> Int
forall a b. (a, b) -> a
fst ([(Int, ValueDestination)] -> [(Int, ValueDestination)])
-> [(Int, ValueDestination)] -> [(Int, ValueDestination)]
forall a b. (a -> b) -> a -> b
$ Map Int ValueDestination -> [(Int, ValueDestination)]
forall k a. Map k a -> [(k, a)]
M.toList Map Int ValueDestination
ctx_dests
([ExternalValue], [Param], Destination)
-> ImpM lore r op ([ExternalValue], [Param], Destination)
forall (m :: * -> *) a. Monad m => a -> m a
return ([ExternalValue]
extvs, [Param]
outparams, Maybe Int -> [ValueDestination] -> Destination
Destination Maybe Int
forall a. Maybe a
Nothing ([ValueDestination] -> Destination)
-> [ValueDestination] -> Destination
forall a b. (a -> b) -> a -> b
$ [ValueDestination]
ctx_dests' [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. Semigroup a => a -> a -> a
<> [ValueDestination]
dests)
where
imp :: ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp = WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a)
-> (ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a)
-> ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ImpM lore r op a
-> WriterT ([Param], Map Int ValueDestination) (ImpM lore r op) a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift
mkExts :: [EntryPointType]
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts (TypeOpaque String
desc Int
n : [EntryPointType]
epts) [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts = do
let ([MemInfo (Ext DimSize) Uniqueness MemReturn]
rts', [MemInfo (Ext DimSize) Uniqueness MemReturn]
rest) = Int
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> ([MemInfo (Ext DimSize) Uniqueness MemReturn],
[MemInfo (Ext DimSize) Uniqueness MemReturn])
forall a. Int -> [a] -> ([a], [a])
splitAt Int
n [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts
([ValueDesc]
evs, [ValueDestination]
dests) <- [(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(ValueDesc, ValueDestination)]
-> ([ValueDesc], [ValueDestination]))
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[(ValueDesc, ValueDestination)]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ValueDesc], [ValueDestination])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (MemInfo (Ext DimSize) Uniqueness MemReturn
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination))
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> [Signedness]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[(ValueDesc, ValueDestination)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM MemInfo (Ext DimSize) Uniqueness MemReturn
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts' (Signedness -> [Signedness]
forall a. a -> [a]
repeat Signedness
Imp.TypeDirect)
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [MemInfo (Ext DimSize) Uniqueness MemReturn]
rest
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( String -> [ValueDesc] -> ExternalValue
Imp.OpaqueValue String
desc [ValueDesc]
evs ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
[ValueDestination]
dests [ValueDestination] -> [ValueDestination] -> [ValueDestination]
forall a. [a] -> [a] -> [a]
++ [ValueDestination]
more_dests
)
mkExts (EntryPointType
TypeUnsigned : [EntryPointType]
epts) (MemInfo (Ext DimSize) Uniqueness MemReturn
rt : [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts) = do
(ValueDesc
ev, ValueDestination
dest) <- MemInfo (Ext DimSize) Uniqueness MemReturn
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam MemInfo (Ext DimSize) Uniqueness MemReturn
rt Signedness
Imp.TypeUnsigned
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
)
mkExts (EntryPointType
TypeDirect : [EntryPointType]
epts) (MemInfo (Ext DimSize) Uniqueness MemReturn
rt : [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts) = do
(ValueDesc
ev, ValueDestination
dest) <- MemInfo (Ext DimSize) Uniqueness MemReturn
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam MemInfo (Ext DimSize) Uniqueness MemReturn
rt Signedness
Imp.TypeDirect
([ExternalValue]
more_values, [ValueDestination]
more_dests) <- [EntryPointType]
-> [MemInfo (Ext DimSize) Uniqueness MemReturn]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
mkExts [EntryPointType]
epts [MemInfo (Ext DimSize) Uniqueness MemReturn]
rts
([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return
( ValueDesc -> ExternalValue
Imp.TransparentValue ValueDesc
ev ExternalValue -> [ExternalValue] -> [ExternalValue]
forall a. a -> [a] -> [a]
: [ExternalValue]
more_values,
ValueDestination
dest ValueDestination -> [ValueDestination] -> [ValueDestination]
forall a. a -> [a] -> [a]
: [ValueDestination]
more_dests
)
mkExts [EntryPointType]
_ [MemInfo (Ext DimSize) Uniqueness MemReturn]
_ = ([ExternalValue], [ValueDestination])
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
([ExternalValue], [ValueDestination])
forall (m :: * -> *) a. Monad m => a -> m a
return ([], [])
mkParam :: MemInfo (Ext DimSize) Uniqueness MemReturn
-> Signedness
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
mkParam MemMem {} Signedness
_ =
String
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall a. HasCallStack => String -> a
error String
"Functions may not explicitly return memory blocks."
mkParam (MemPrim PrimType
t) Signedness
ept = do
VName
out <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"scalar_out"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell ([VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
t], Map Int ValueDestination
forall a. Monoid a => a
mempty)
(ValueDesc, ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return (PrimType -> Signedness -> VName -> ValueDesc
Imp.ScalarValue PrimType
t Signedness
ept VName
out, VName -> ValueDestination
ScalarDestination VName
out)
mkParam (MemArray PrimType
t ShapeBase (Ext DimSize)
shape Uniqueness
_ MemReturn
dec) Signedness
ept = do
Space
space <- (Env lore r op -> Space)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
Space
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks Env lore r op -> Space
forall lore r op. Env lore r op -> Space
envDefaultSpace
VName
memout <- case MemReturn
dec of
ReturnsNewBlock Space
_ Int
x ExtIxFun
_ixfun -> do
VName
memout <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"out_mem"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [VName -> Space -> Param
Imp.MemParam VName
memout Space
space],
Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
MemoryDestination VName
memout
)
VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
ReturnsInBlock VName
memout ExtIxFun
_ ->
VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall (m :: * -> *) a. Monad m => a -> m a
return VName
memout
[DimSize]
resultshape <- (Ext DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize)
-> [Ext DimSize]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[DimSize]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Ext DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
inspectExtSize ([Ext DimSize]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[DimSize])
-> [Ext DimSize]
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
[DimSize]
forall a b. (a -> b) -> a -> b
$ ShapeBase (Ext DimSize) -> [Ext DimSize]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase (Ext DimSize)
shape
(ValueDesc, ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(ValueDesc, ValueDestination)
forall (m :: * -> *) a. Monad m => a -> m a
return
( VName -> Space -> PrimType -> Signedness -> [DimSize] -> ValueDesc
Imp.ArrayValue VName
memout Space
space PrimType
t Signedness
ept [DimSize]
resultshape,
Maybe MemLocation -> ValueDestination
ArrayDestination Maybe MemLocation
forall a. Maybe a
Nothing
)
inspectExtSize :: Ext DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
inspectExtSize (Ext Int
x) = do
(Map Any Any
memseen, Map Int VName
arrseen) <- StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
(Map Any Any, Map Int VName)
forall s (m :: * -> *). MonadState s m => m s
get
case Int -> Map Int VName -> Maybe VName
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup Int
x Map Int VName
arrseen of
Maybe VName
Nothing -> do
VName
out <- ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a.
ImpM lore r op a
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
a
imp (ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName)
-> ImpM lore r op VName
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
VName
forall a b. (a -> b) -> a -> b
$ String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"out_arrsize"
([Param], Map Int ValueDestination)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall w (m :: * -> *). MonadWriter w m => w -> m ()
tell
( [VName -> PrimType -> Param
Imp.ScalarParam VName
out PrimType
int64],
Int -> ValueDestination -> Map Int ValueDestination
forall k a. k -> a -> Map k a
M.singleton Int
x (ValueDestination -> Map Int ValueDestination)
-> ValueDestination -> Map Int ValueDestination
forall a b. (a -> b) -> a -> b
$ VName -> ValueDestination
ScalarDestination VName
out
)
(Map Any Any, Map Int VName)
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
()
forall s (m :: * -> *). MonadState s m => s -> m ()
put (Map Any Any
memseen, Int -> VName -> Map Int VName -> Map Int VName
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Int
x VName
out Map Int VName
arrseen)
DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return (DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize)
-> DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
forall a b. (a -> b) -> a -> b
$ VName -> DimSize
Var VName
out
Just VName
out ->
DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return (DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize)
-> DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
forall a b. (a -> b) -> a -> b
$ VName -> DimSize
Var VName
out
inspectExtSize (Free DimSize
se) =
DimSize
-> StateT
(Map Any Any, Map Int VName)
(WriterT ([Param], Map Int ValueDestination) (ImpM lore r op))
DimSize
forall (m :: * -> *) a. Monad m => a -> m a
return DimSize
se
-- | Compile a function definition and register the result via
-- 'emitFunction' under the function's name.
--
-- The whole compilation runs with 'envFunction' set to @Just fname@.
-- Parameter and result entry-point types come from the definition's entry
-- point when there is one (first and second component respectively);
-- otherwise every parameter and every result is treated as 'TypeDirect'.
-- The emitted 'Imp.Function' is flagged as an entry point exactly when the
-- definition has an entry point.
compileFunDef ::
  Mem lore =>
  FunDef lore ->
  ImpM lore r op ()
compileFunDef (FunDef entry _ fname rettype params body) =
  local withFunName $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    emitFunction fname $
      Imp.Function (isJust entry) outparams inparams body' results args
  where
    withFunName env = env {envFunction = Just fname}

    -- Entry-point types for the parameters and for the results.
    params_entry = maybe (replicate (length params) TypeDirect) fst entry
    ret_entry = maybe (replicate (length rettype) TypeDirect) snd entry

    -- Set up input/output parameters, compile the body statements, and copy
    -- each result subexpression into the output destination at the same
    -- position.  Returns the pieces needed to build the 'Imp.Function'.
    compile = do
      (inparams, arrayds, args) <- compileInParams params params_entry
      (results, outparams, Destination _ dests) <-
        compileOutParams rettype ret_entry
      addFParams params
      addArrays arrayds
      let Body _ stms ses = body
      compileStms (freeIn ses) stms $
        sequence_ $ zipWith (\d se -> copyDWIMDest d [] se []) dests ses
      return (outparams, inparams, results, args)
-- | Compile a body whose results are written to the destinations of the
-- given pattern.
--
-- The statements are compiled through 'compileStms', with the free
-- variables of the result subexpressions as its 'Names' argument.  Each
-- result is then copied into the pattern destination at the same position.
compileBody :: (Mem lore) => Pattern lore -> Body lore -> ImpM lore r op ()
compileBody pat (Body _ bnds ses) = do
  Destination _ dests <- destinationFromPattern pat
  let writeResult d se = copyDWIMDest d [] se []
  compileStms (freeIn ses) bnds $
    sequence_ $ zipWith writeResult dests ses
-- | Compile a body and copy its results, position by position, into the
-- variables named by the given parameters.
--
-- The free variables of the result subexpressions are passed to
-- 'compileStms' as its 'Names' argument.
compileBody' :: [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' params (Body _ bnds ses) =
  compileStms (freeIn ses) bnds $
    mapM_ copyResult (zip params ses)
  where
    copyResult (param, se) = copyDWIM (paramName param) [] se []
-- | Compile a loop body and assign its results to the loop's merge
-- parameters.
--
-- Assignment happens in two phases, both inside the continuation passed to
-- 'compileStms':
--
-- 1. Every handled result is stored in a fresh temporary.  Each temporary
--    is named after its merge parameter with the suffix @_tmp@.
-- 2. Only after all temporaries have been written are the merge parameters
--    set from them, in parameter order.
--
-- Because all reads happen before any merge parameter is overwritten, a
-- result that mentions a merge parameter sees that parameter's value from
-- before this round of assignments.
--
-- Only two kinds of parameter are assigned here: parameters of primitive
-- type, and memory parameters whose result is a variable.  Every other
-- parameter (for instance an array parameter, or a memory parameter whose
-- result is not a 'Var') gets no assignment from this function.
compileLoopBody :: Typed dec => [Param dec] -> Body lore -> ImpM lore r op ()
compileLoopBody mergeparams (Body _ bnds ses) = do
  tmpnames <- forM mergeparams $ \p ->
    newVName (baseString (paramName p) ++ "_tmp")
  compileStms (freeIn ses) bnds $ do
    -- Phase one: write each result to its temporary and keep the matching
    -- merge-parameter assignment for later.
    commits <- sequence $ zipWith3 stageResult mergeparams tmpnames ses
    -- Phase two: run the deferred assignments.
    sequence_ commits
  where
    -- Declare and fill the temporary for one parameter.  Returns the action
    -- that later copies the temporary into the parameter; that action is a
    -- no-op for parameter kinds this function does not handle.
    stageResult p tmp se =
      case (typeOf p, se) of
        (Prim pt, _) -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        (Mem space, Var v) -> do
          emit $ Imp.DeclareMem tmp space
          emit $ Imp.SetMem tmp v space
          pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure (pure ())
-- | Compile a sequence of statements followed by the action @m@.
--
-- The work is delegated to whichever statement compiler is installed as
-- 'envStmsCompiler' in the current environment.  All three arguments are
-- passed through unchanged.  Callers in this module pass, as the 'Names'
-- argument, the free variables of the results computed by the statements.
compileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \compiler ->
    compiler alive_after_stms all_stms m
defCompileStms ::
(Mem lore, FreeIn op) =>
Names ->
Stms lore ->
ImpM lore r op () ->
ImpM lore r op ()
defCompileStms :: Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
defCompileStms Names
alive_after_stms Stms lore
all_stms ImpM lore r op ()
m =
ImpM lore r op Names -> ImpM lore r op ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (ImpM lore r op Names -> ImpM lore r op ())
-> ImpM lore r op Names -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
forall a. Monoid a => a
mempty ([Stm lore] -> ImpM lore r op Names)
-> [Stm lore] -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Stms lore -> [Stm lore]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms lore
all_stms
where
compileStms' :: Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' Set (VName, Space)
allocs (Let Pattern lore
pat StmAux (ExpDec lore)
aux Exp lore
e : [Stm lore]
bs) = do
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
forall lore r op.
Mem lore =>
Maybe (Exp lore) -> [PatElem lore] -> ImpM lore r op ()
dVars (Exp lore -> Maybe (Exp lore)
forall a. a -> Maybe a
Just Exp lore
e) (PatternT (MemBound NoUniqueness)
-> [PatElemT (MemBound NoUniqueness)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern lore
PatternT (MemBound NoUniqueness)
pat)
Code op
e_code <-
Attrs -> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall lore r op a. Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs (StmAux (ExpDec lore) -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec lore)
aux) (ImpM lore r op (Code op) -> ImpM lore r op (Code op))
-> ImpM lore r op (Code op) -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$
ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ Pattern lore -> Exp lore -> ImpM lore r op ()
forall lore r op. Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp Pattern lore
pat Exp lore
e
(Names
live_after, Code op
bs_code) <- ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall lore r op a. ImpM lore r op a -> ImpM lore r op (a, Code op)
collect' (ImpM lore r op Names -> ImpM lore r op (Names, Code op))
-> ImpM lore r op Names -> ImpM lore r op (Names, Code op)
forall a b. (a -> b) -> a -> b
$ Set (VName, Space) -> [Stm lore] -> ImpM lore r op Names
compileStms' (PatternT (MemBound NoUniqueness) -> Set (VName, Space)
patternAllocs Pattern lore
PatternT (MemBound NoUniqueness)
pat Set (VName, Space) -> Set (VName, Space) -> Set (VName, Space)
forall a. Semigroup a => a -> a -> a
<> Set (VName, Space)
allocs) [Stm lore]
bs
let dies_here :: VName -> Bool
dies_here VName
v =
Bool -> Bool
not (VName
v VName -> Names -> Bool
`nameIn` Names
live_after)
Bool -> Bool -> Bool
&& VName
v VName -> Names -> Bool
`nameIn` Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code
to_free :: Set (VName, Space)
to_free = ((VName, Space) -> Bool)
-> Set (VName, Space) -> Set (VName, Space)
forall a. (a -> Bool) -> Set a -> Set a
S.filter (VName -> Bool
dies_here (VName -> Bool)
-> ((VName, Space) -> VName) -> (VName, Space) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, Space) -> VName
forall a b. (a, b) -> a
fst) Set (VName, Space)
allocs
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
e_code
((VName, Space) -> ImpM lore r op ())
-> Set (VName, Space) -> ImpM lore r op ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ())
-> ((VName, Space) -> Code op)
-> (VName, Space)
-> ImpM lore r op ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName -> Space -> Code op) -> (VName, Space) -> Code op
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry VName -> Space -> Code op
forall a. VName -> Space -> Code a
Imp.Free) Set (VName, Space)
to_free
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
bs_code
Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
e_code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
live_after
compileStms' Set (VName, Space)
_ [] = do
Code op
code <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect ImpM lore r op ()
m
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit Code op
code
Names -> ImpM lore r op Names
forall (m :: * -> *) a. Monad m => a -> m a
return (Names -> ImpM lore r op Names) -> Names -> ImpM lore r op Names
forall a b. (a -> b) -> a -> b
$ Code op -> Names
forall a. FreeIn a => a -> Names
freeIn Code op
code Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
alive_after_stms
patternAllocs :: PatternT (MemBound NoUniqueness) -> Set (VName, Space)
patternAllocs = [(VName, Space)] -> Set (VName, Space)
forall a. Ord a => [a] -> Set a
S.fromList ([(VName, Space)] -> Set (VName, Space))
-> (PatternT (MemBound NoUniqueness) -> [(VName, Space)])
-> PatternT (MemBound NoUniqueness)
-> Set (VName, Space)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (PatElemT (MemBound NoUniqueness) -> Maybe (VName, Space))
-> [PatElemT (MemBound NoUniqueness)] -> [(VName, Space)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe PatElemT (MemBound NoUniqueness) -> Maybe (VName, Space)
forall dec. Typed dec => PatElemT dec -> Maybe (VName, Space)
isMemPatElem ([PatElemT (MemBound NoUniqueness)] -> [(VName, Space)])
-> (PatternT (MemBound NoUniqueness)
-> [PatElemT (MemBound NoUniqueness)])
-> PatternT (MemBound NoUniqueness)
-> [(VName, Space)]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatternT (MemBound NoUniqueness)
-> [PatElemT (MemBound NoUniqueness)]
forall dec. PatternT dec -> [PatElemT dec]
patternElements
isMemPatElem :: PatElemT dec -> Maybe (VName, Space)
isMemPatElem PatElemT dec
pe = case PatElemT dec -> Type
forall dec. Typed dec => PatElemT dec -> Type
patElemType PatElemT dec
pe of
Mem Space
space -> (VName, Space) -> Maybe (VName, Space)
forall a. a -> Maybe a
Just (PatElemT dec -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT dec
pe, Space
space)
Type
_ -> Maybe (VName, Space)
forall a. Maybe a
Nothing
-- | Compile an expression whose result is bound to the given pattern.
--
-- The actual work is done by the expression compiler stored in the
-- environment ('envExpCompiler'), so the behaviour depends on which
-- 'ExpCompiler' the current 'Operations' installed.
compileExp :: Pattern lore -> Exp lore -> ImpM lore r op ()
compileExp pat e = do
  ec <- asks envExpCompiler
  ec pat e
-- | The default expression compiler.
--
-- * 'If' becomes an imperative conditional over the two branch bodies.
-- * 'Apply' becomes an 'Imp.Call' whose targets come from the pattern.
-- * 'BasicOp' is delegated to 'defCompileBasicOp'.
-- * 'DoLoop' is compiled to an imperative @for@ or @while@ loop.
-- * 'Op' is delegated to the operation compiler in the environment
--   ('envOpCompiler').
defCompileExp ::
  (Mem lore) =>
  Pattern lore ->
  Exp lore ->
  ImpM lore r op ()
defCompileExp pat (If cond tbranch fbranch _) =
  sIf (toBoolExp cond) (compileBody pat tbranch) (compileBody pat fbranch)
defCompileExp pat (Apply fname args _ _) = do
  dest <- destinationFromPattern pat
  targets <- funcallTargets dest
  args' <- catMaybes <$> mapM compileArg args
  emit $ Imp.Call targets fname args'
  where
    -- Primitive-typed arguments are passed as expressions and
    -- memory-block variables by name.  Any other argument yields no
    -- 'Imp.Arg' at all (presumably arrays, whose memory blocks are
    -- passed as separate arguments -- NOTE(review): confirm).
    compileArg (se, _) = do
      t <- subExpType se
      case (se, t) of
        (_, Prim pt) -> return $ Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> return $ Just $ Imp.MemArg v
        _ -> return Nothing
defCompileExp pat (BasicOp op) =
  defCompileBasicOp pat op
defCompileExp pat (DoLoop ctx val form body) = do
  attrs <- askAttrs
  -- This compiler never unrolls, so an #[unroll] request is only
  -- reported.  No source location is available here.
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations."

  -- Declare every merge parameter.  Only rank-0 (non-array) parameters
  -- receive an explicit initial copy.  The test is on the rank rather
  -- than on 'Prim', so memory-typed parameters are initialised too.
  dFParams mergepat
  forM_ merge $ \(p, se) ->
    when ((== 0) $ arrayRank $ paramType p) $
      copyDWIM (paramName p) [] se []

  let doBody = compileLoopBody mergepat body

  case form of
    ForLoop i _ bound loopvars -> do
      -- Each iteration reads element i of the source array into a
      -- primitive-typed loop parameter; other loop parameters get no
      -- copy here.
      let setLoopParam (p, a)
            | Prim _ <- paramType p =
              copyDWIM (paramName p) [] (Var a) [DimFix $ Imp.vi64 i]
            | otherwise =
              return ()
      bound' <- toExp bound
      dLParams $ map fst loopvars
      sFor' i bound' $ do
        mapM_ setLoopParam loopvars
        doBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) doBody

  -- After the loop, the final values of the merge parameters are
  -- copied into the destinations of the loop's pattern.
  Destination _ pat_dests <- destinationFromPattern pat
  forM_ (zip pat_dests $ map (Var . paramName . fst) merge) $ \(d, r) ->
    copyDWIMDest d [] r []
  where
    merge = ctx ++ val
    mergepat = map fst merge
defCompileExp pat (Op op) = do
  opc <- asks envOpCompiler
  opc pat op
defCompileBasicOp ::
Mem lore =>
Pattern lore ->
BasicOp ->
ImpM lore r op ()
defCompileBasicOp :: Pattern lore -> BasicOp -> ImpM lore r op ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (SubExp DimSize
se) =
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] DimSize
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Opaque DimSize
se) =
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] DimSize
se []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (UnOp UnOp
op DimSize
e) = do
Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (ConvOp ConvOp
conv DimSize
e) = do
Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ ConvOp -> Exp -> Exp
forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (BinOp BinOp
bop DimSize
x DimSize
y) = do
Exp
x' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
x
Exp
y' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
y
PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (CmpOp CmpOp
bop DimSize
x DimSize
y) = do
Exp
x' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
x
Exp
y' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
y
PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe VName -> Exp -> ImpM lore r op ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<~~ CmpOp -> Exp -> Exp -> Exp
forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp Pattern lore
_ (Assert DimSize
e ErrorMsg DimSize
msg (SrcLoc, [SrcLoc])
loc) = do
Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
ErrorMsg Exp
msg' <- (DimSize -> ImpM lore r op Exp)
-> ErrorMsg DimSize -> ImpM lore r op (ErrorMsg Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp ErrorMsg DimSize
msg
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code op
forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc
Attrs
attrs <- ImpM lore r op Attrs
forall lore r op. ImpM lore r op Attrs
askAttrs
Bool -> ImpM lore r op () -> ImpM lore r op ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) (ImpM lore r op () -> ImpM lore r op ())
-> ImpM lore r op () -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$
(SrcLoc -> [SrcLoc] -> String -> ImpM lore r op ())
-> (SrcLoc, [SrcLoc]) -> String -> ImpM lore r op ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry SrcLoc -> [SrcLoc] -> String -> ImpM lore r op ()
forall loc lore r op.
Located loc =>
loc -> [loc] -> String -> ImpM lore r op ()
warn (SrcLoc, [SrcLoc])
loc String
"Safety check required at run-time."
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Index VName
src Slice DimSize
slice)
| Just [DimSize]
idxs <- Slice DimSize -> Maybe [DimSize]
forall d. Slice d -> Maybe [d]
sliceIndices Slice DimSize
slice =
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) ([DimIndex (TExp Int64)] -> ImpM lore r op ())
-> [DimIndex (TExp Int64)] -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ (DimSize -> DimIndex (TExp Int64))
-> [DimSize] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (DimSize -> TExp Int64) -> DimSize -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) [DimSize]
idxs
defCompileBasicOp Pattern lore
_ Index {} =
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Update VName
_ Slice DimSize
slice DimSize
se) =
VName -> [DimIndex (TExp Int64)] -> DimSize -> ImpM lore r op ()
forall lore r op.
VName -> [DimIndex (TExp Int64)] -> DimSize -> ImpM lore r op ()
sUpdate (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) ((DimIndex DimSize -> DimIndex (TExp Int64))
-> Slice DimSize -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map ((DimSize -> TExp Int64)
-> DimIndex DimSize -> DimIndex (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp) Slice DimSize
slice) DimSize
se
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Replicate (Shape [DimSize]
ds) DimSize
se) = do
[Exp]
ds' <- (DimSize -> ImpM lore r op Exp)
-> [DimSize] -> ImpM lore r op [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp [DimSize]
ds
[VName]
is <- Int -> ImpM lore r op VName -> ImpM lore r op [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM ([DimSize] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimSize]
ds) (String -> ImpM lore r op VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"i")
Code op
copy_elem <- ImpM lore r op () -> ImpM lore r op (Code op)
forall lore r op. ImpM lore r op () -> ImpM lore r op (Code op)
collect (ImpM lore r op () -> ImpM lore r op (Code op))
-> ImpM lore r op () -> ImpM lore r op (Code op)
forall a b. (a -> b) -> a -> b
$ VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) ((VName -> DimIndex (TExp Int64))
-> [VName] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> (VName -> TExp Int64) -> VName -> DimIndex (TExp Int64)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> TExp Int64
Imp.vi64) [VName]
is) DimSize
se []
Code op -> ImpM lore r op ()
forall op lore r. Code op -> ImpM lore r op ()
emit (Code op -> ImpM lore r op ()) -> Code op -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ ((Code op -> Code op)
-> (Code op -> Code op) -> Code op -> Code op)
-> (Code op -> Code op)
-> [Code op -> Code op]
-> Code op
-> Code op
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl (Code op -> Code op) -> (Code op -> Code op) -> Code op -> Code op
forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) Code op -> Code op
forall a. a -> a
id ((VName -> Exp -> Code op -> Code op)
-> [VName] -> [Exp] -> [Code op -> Code op]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> Exp -> Code op -> Code op
forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is [Exp]
ds') Code op
copy_elem
defCompileBasicOp Pattern lore
_ Scratch {} =
() -> ImpM lore r op ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
defCompileBasicOp (Pattern [] [PatElemT (LetDec lore)
pe]) (Iota DimSize
n DimSize
e DimSize
s IntType
it) = do
Exp
e' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
e
Exp
s' <- DimSize -> ImpM lore r op Exp
forall a lore r op. ToExp a => a -> ImpM lore r op Exp
toExp DimSize
s
String
-> TExp Int64
-> (TExp Int64 -> ImpM lore r op ())
-> ImpM lore r op ()
forall t lore r op.
String
-> TExp t -> (TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor String
"i" (DimSize -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp DimSize
n) ((TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ())
-> (TExp Int64 -> ImpM lore r op ()) -> ImpM lore r op ()
forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
let i' :: Exp
i' = IntType -> Exp -> Exp
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
it (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
i
TV Any
x <-
String -> TExp Any -> ImpM lore r op (TV Any)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"x" (TExp Any -> ImpM lore r op (TV Any))
-> TExp Any -> ImpM lore r op (TV Any)
forall a b. (a -> b) -> a -> b
$
Exp -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (Exp -> TExp Any) -> Exp -> TExp Any
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$
BinOp -> Exp -> Exp -> Exp
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) Exp
i' Exp
s'
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
i] (VName -> DimSize
Var (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
x)) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Copy VName
src) =
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) []
defCompileBasicOp (Pattern [PatElemT (LetDec lore)]
_ [PatElemT (LetDec lore)
pe]) (Manifest [Int]
_ VName
src) =
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> DimSize
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT (MemBound NoUniqueness) -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec lore)
PatElemT (MemBound NoUniqueness)
pe) [] (VName -> DimSize
Var VName
src) []
-- Concatenation along dimension @i@. Each operand is copied, in order, into
-- the next slice of the destination along that dimension. The running
-- offset is kept in a fresh scalar ("tmp_offs") and advanced by the
-- operand's extent in dimension @i@ after each copy.
defCompileBasicOp (Pattern _ [pe]) (Concat i x ys _) = do
  offs_glb <- dPrimV "tmp_offs" 0

  forM_ (x : ys) $ \y -> do
    -- Dimensions before @i@ are copied whole; the first remaining one is
    -- the dimension being concatenated along.
    (outer_dims, concat_dims) <- splitAt i . arrayDims <$> lookupType y
    let rows = case concat_dims of
          r : _ -> toInt64Exp r
          [] -> error $ "defCompileBasicOp Concat: empty array shape for " ++ pretty y
        fullSlice d = DimSlice 0 (toInt64Exp d) 1
        destslice = map fullSlice outer_dims ++ [DimSlice (tvExp offs_glb) rows 1]
    copyDWIM (patElemName pe) destslice (Var y) []
    offs_glb <-- tvExp offs_glb + rows
-- Array literals. When every element is a constant (and there is at least
-- one), the literal is emitted as a static array declared in the memory
-- space of the destination, registered as a memory variable, and copied
-- into the destination in a single 'copy'. Otherwise each element is
-- copied individually to its index.
defCompileBasicOp (Pattern [] [pe]) (ArrayLit es _)
  | Just vs@(v : _) <- mapM literal es = do
    dest_mem <- entryArrayLocation <$> lookupArray dest
    dest_space <- entryMemSpace <$> lookupMemory (memLocationName dest_mem)
    static_array <- newVNameForFun "static_array"
    let elem_t = primValueType v
    emit . Imp.DeclareArray static_array dest_space elem_t $ Imp.ArrayValues vs
    addVar static_array . MemVar Nothing $ MemEntry dest_space
    let num_elems = length es
        static_src =
          MemLocation static_array [intConst Int64 $ fromIntegral num_elems] $
            IxFun.iota [fromIntegral num_elems]
        whole = [DimSlice 0 (genericLength es) 1]
    copy elem_t dest_mem whole static_src whole
  | otherwise =
    sequence_
      [ copyDWIM dest [DimFix $ fromInteger i] e []
        | (i, e) <- zip [0 ..] es
      ]
  where
    dest = patElemName pe

    -- Only literal constants can go into a static array.
    literal (Constant pv) = Just pv
    literal _ = Nothing
-- Rearrange, Rotate and Reshape emit no code at all here.
-- NOTE(review): presumably because the result is described by its index
-- function rather than by moving any data; confirm.
defCompileBasicOp _ Rearrange {} = pure ()
defCompileBasicOp _ Rotate {} = pure ()
defCompileBasicOp _ Reshape {} = pure ()
-- Any pattern/operation combination not matched above is treated as an
-- internal error. The pattern and expression are pretty-printed into the
-- message.
defCompileBasicOp pat e =
  error . concat $
    [ "ImpGen.defCompileBasicOp: Invalid pattern\n ",
      pretty pat,
      "\nfor expression\n ",
      pretty e
    ]
-- | Register each declared array in the variable table as an array entry
-- (location and element type taken from the declaration). There is no
-- originating expression, and no declaration code is emitted.
addArrays :: [ArrayDecl] -> ImpM lore r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name . ArrayVar Nothing $
      ArrayEntry
        { entryArrayLocation = location,
          entryArrayElemType = bt
        }
-- | Register function parameters in the variable table without emitting
-- any declarations for them. Uniqueness information is stripped from each
-- parameter's decoration before it is turned into a variable entry.
addFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
addFParams params =
  forM_ params $ \fparam ->
    addVar (paramName fparam)
      . memBoundToVarEntry Nothing
      . noUniquenessReturns
      $ paramDec fparam
-- | Register a loop variable as a scalar of the given integer type.
-- Only the variable table is updated; no declaration is emitted.
addLoopVar :: VName -> IntType -> ImpM lore r op ()
addLoopVar i it =
  let entry = ScalarEntry (IntType it)
   in addVar i (ScalarVar Nothing entry)
-- | Declare every pattern element, one at a time, through 'dScope'. The
-- optional expression is recorded as the origin of each declared variable.
dVars ::
  Mem lore =>
  Maybe (Exp lore) ->
  [PatElem lore] ->
  ImpM lore r op ()
dVars e pes = forM_ pes (dScope e . scopeOfPatElem)
-- | Declare function parameters via 'dScope', with no originating expression.
dFParams :: Mem lore => [FParam lore] -> ImpM lore r op ()
dFParams params = dScope Nothing (scopeOfFParams params)
-- | Declare lambda parameters via 'dScope', with no originating expression.
dLParams :: Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams params = dScope Nothing (scopeOfLParams params)
-- | Declare a fresh /volatile/ scalar of the given type, based on the given
-- name. The variable is registered in the variable table, the expression is
-- assigned to it, and a typed handle to it is returned.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit (Imp.DeclareScalar v Imp.Volatile t)
  addVar v (ScalarVar Nothing (ScalarEntry t))
  v <~~ untyped e
  pure (TV v t)
-- | Declare a non-volatile scalar with exactly the given name (no fresh name
-- is generated) and register it in the variable table.
dPrim_ :: VName -> PrimType -> ImpM lore r op ()
dPrim_ name t = do
  let entry = ScalarVar Nothing (ScalarEntry t)
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
  addVar name entry
-- | Declare a fresh non-volatile scalar of the given type, based on the
-- given name, and return a typed handle to it. The variable is left
-- unassigned.
dPrim :: String -> PrimType -> ImpM lore r op (TV t)
dPrim name t = do
  v <- newVName name
  dPrim_ v t
  pure (TV v t)
-- | Declare a scalar with exactly the given name and assign the expression
-- to it. The scalar's type is taken from the expression itself.
dPrimV_ :: VName -> Imp.TExp t -> ImpM lore r op ()
dPrimV_ name e = do
  let t = primExpType (untyped e)
  dPrim_ name t
  TV name t <-- e
-- | Declare a fresh scalar, based on the given name, whose type is that of
-- the expression. Assign the expression to it and return a typed handle.
dPrimV :: String -> Imp.TExp t -> ImpM lore r op (TV t)
dPrimV name e = do
  tv <- dPrim name (primExpType (untyped e))
  tv <-- e
  pure tv
-- | Like 'dPrimV', but return the new variable as a typed expression rather
-- than as a 'TV' handle.
dPrimVE :: String -> Imp.TExp t -> ImpM lore r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e
-- | Turn a memory-annotated type into a variable-table entry, recording the
-- optional originating expression.
--
-- * Primitive values become scalar entries.
-- * Memory blocks become memory entries in their space.
-- * Arrays become array entries. Their location combines the memory block,
--   the shape's dimensions, and the index function with its variables
--   lifted to 'Imp.ScalarVar' leaves.
memBoundToVarEntry ::
  Maybe (Exp lore) ->
  MemBound NoUniqueness ->
  VarEntry lore
memBoundToVarEntry e = \case
  MemPrim bt ->
    ScalarVar e ScalarEntry {entryScalarType = bt}
  MemMem space ->
    MemVar e (MemEntry space)
  MemArray bt shape _ (ArrayIn mem ixfun) ->
    ArrayVar
      e
      ArrayEntry
        { entryArrayLocation =
            MemLocation mem (shapeDims shape) (fmap (fmap Imp.ScalarVar) ixfun),
          entryArrayElemType = bt
        }
-- | Extract the uniqueness-free memory annotation from a scope entry.
-- Function parameters have their uniqueness dropped. 'IndexName's become
-- primitive integers of their own integer type.
infoDec ::
  Mem lore =>
  NameInfo lore ->
  MemInfo SubExp NoUniqueness MemBind
infoDec = \case
  LetName dec -> dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim (IntType it)
-- | Declare a single name from a scope, then register it in the variable
-- table.
--
-- * Memory blocks get a 'Imp.DeclareMem' in their space.
-- * Scalars get a non-volatile 'Imp.DeclareScalar'.
-- * Arrays emit no declaration and are only registered.
dInfo ::
  Mem lore =>
  Maybe (Exp lore) ->
  VName ->
  NameInfo lore ->
  ImpM lore r op ()
dInfo e name info = do
  case entry of
    MemVar _ mem ->
      emit $ Imp.DeclareMem name (entryMemSpace mem)
    ScalarVar _ scalar ->
      emit $ Imp.DeclareScalar name Imp.Nonvolatile (entryScalarType scalar)
    ArrayVar {} ->
      pure ()
  addVar name entry
  where
    entry = memBoundToVarEntry e (infoDec info)
-- | Declare every binding of a scope with 'dInfo', in the order produced by
-- 'M.toList' (ascending by name). The optional expression is recorded as
-- the origin of each binding.
dScope ::
  Mem lore =>
  Maybe (Exp lore) ->
  Scope lore ->
  ImpM lore r op ()
dScope e scope =
  forM_ (M.toList scope) $ \(name, info) ->
    dInfo e name info
-- | Register an array of the given element type and shape, bound to the
-- given memory. Only the variable table is updated; no code is emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op ()
dArray name bt shape membind =
  addVar name . memBoundToVarEntry Nothing $
    MemArray bt shape NoUniqueness membind
-- | Run an action with the environment's volatility set to
-- 'Imp.Volatile'. The change applies only within that action.
everythingVolatile :: ImpM lore r op a -> ImpM lore r op a
everythingVolatile = local makeVolatile
  where
    makeVolatile env = env {envVolatility = Imp.Volatile}
-- | The variable names a function call writes to, in destination order.
-- Scalar and memory destinations contribute their name. Array destinations
-- contribute nothing. The destination's leading 'Maybe Int' field is
-- ignored.
funcallTargets :: Destination -> ImpM lore r op [VName]
funcallTargets (Destination _ dests) =
  pure $ concatMap targetNames dests
  where
    targetNames (ScalarDestination name) = [name]
    targetNames (ArrayDestination _) = []
    targetNames (MemoryDestination name) = [name]
-- | A variable paired with its primitive type. The phantom parameter @t@
-- is not stored anywhere; it only tags the variable, so that 'tvExp' yields
-- an @Imp.TExp t@ and assignments through '(<--)' expect a matching
-- @TExp t@. Nothing checks that the stored 'PrimType' agrees with @t@.
data TV t = TV VName PrimType
-- | Wrap an existing variable and its type as a typed variable. The type tag
-- @t@ is chosen by the caller and is not checked.
mkTV :: VName -> PrimType -> TV t
mkTV v t = TV v t
-- | Refer to a typed variable as a 'SubExp' variable, for use as a size.
tvSize :: TV t -> Imp.DimSize
tvSize tv = Var (tvVar tv)
-- | Read a typed variable as a typed expression of its stored primitive
-- type.
tvExp :: TV t -> Imp.TExp t
tvExp (TV v t) = Imp.TPrimExp (Imp.var v t)
-- | The underlying variable name of a typed variable.
tvVar :: TV t -> VName
tvVar tv = case tv of
  TV v _ -> v
-- | Things that can be compiled to an Imp expression of primitive type.
class ToExp a where
  -- | Compile to an 'Imp.Exp', working out the primitive type inside
  -- 'ImpM' (instances may consult the symbol table to do so).
  toExp :: a -> ImpM lore r op Imp.Exp

  -- | Compile to an 'Imp.Exp' when the primitive type is already known.
  toExp' :: PrimType -> a -> Imp.Exp

  -- | Compile as a typed 32-bit integer expression.
  toInt32Exp :: a -> Imp.TExp Int32
  toInt32Exp x = TPrimExp (toExp' int32 x)

  -- | Compile as a typed 64-bit integer expression.
  toInt64Exp :: a -> Imp.TExp Int64
  toInt64Exp x = TPrimExp (toExp' int64 x)

  -- | Compile as a typed boolean expression.
  toBoolExp :: a -> Imp.TExp Bool
  toBoolExp x = TPrimExp (toExp' Bool x)
-- | Constants compile to literal values.  Variables compile to variable
-- reads; in the monadic variant the variable must be bound to a scalar
-- in the symbol table, otherwise compilation fails with 'error'.
instance ToExp SubExp where
  toExp se = case se of
    Constant v -> pure $ Imp.ValueExp v
    Var v -> do
      entry <- lookupVar v
      case entry of
        ScalarVar _ (ScalarEntry pt) -> pure $ Imp.var v pt
        _ -> error $ "toExp SubExp: SubExp is not a primitive type: " ++ pretty v

  toExp' t se = case se of
    Constant v -> Imp.ValueExp v
    Var v -> Imp.var v t
-- | A 'PrimExp' over names becomes an Imp expression by turning every
-- leaf name into an 'Imp.ScalarVar'.  The expression already carries
-- its own types, so the 'PrimType' argument to 'toExp'' is ignored.
instance ToExp (PrimExp VName) where
  toExp e = return $ Imp.ScalarVar <$> e
  toExp' _ e = Imp.ScalarVar <$> e
-- | Bind a name in the symbol table, replacing any previous binding of
-- the same name.
addVar :: VName -> VarEntry lore -> ImpM lore r op ()
addVar name entry = modify insertEntry
  where
    insertEntry st = st {stateVTable = M.insert name entry (stateVTable st)}
-- | Run an action with the given memory space as the default space.
localDefaultSpace :: Imp.Space -> ImpM lore r op a -> ImpM lore r op a
localDefaultSpace space = local setSpace
  where
    setSpace env = env {envDefaultSpace = space}
-- | The name of the function currently being compiled, if any.
askFunction :: ImpM lore r op (Maybe Name)
askFunction = envFunction <$> ask
-- | Generate a fresh 'VName'.  When a current function is known (see
-- 'askFunction'), its name followed by a dot is prepended to the given
-- base string.
newVNameForFun :: String -> ImpM lore r op VName
newVNameForFun s = do
  mfun <- askFunction
  let prefix = maybe "" ((++ ".") . nameToString) mfun
  newVName (prefix ++ s)
-- | Build a 'Name' from a string.  When a current function is known,
-- the result is prefixed with that function's name and a dot.  Unlike
-- 'newVNameForFun', no fresh tag is generated.
nameForFun :: String -> ImpM lore r op Name
nameForFun s = do
  mfun <- askFunction
  let prefix = case mfun of
        Nothing -> ""
        Just fun -> fun <> "."
  pure $ prefix <> nameFromString s
-- | The user-supplied environment of type @r@.
askEnv :: ImpM lore r op r
askEnv = envEnv <$> ask
-- | Run an action with the user-supplied environment transformed by
-- the given function.
localEnv :: (r -> r) -> ImpM lore r op a -> ImpM lore r op a
localEnv f = local adjust
  where
    adjust env = env {envEnv = f (envEnv env)}
-- | The attributes currently in effect.
askAttrs :: ImpM lore r op Attrs
askAttrs = envAttrs <$> ask
-- | Run an action with additional attributes in effect.  The new
-- attributes are combined with the existing ones using '<>', with the
-- new ones on the left.
localAttrs :: Attrs -> ImpM lore r op a -> ImpM lore r op a
localAttrs attrs = local addAttrs
  where
    addAttrs env =
      let existing = envAttrs env
       in env {envAttrs = attrs <> existing}
-- | Run an action with every compiler hook (expression, statements,
-- copy, op, and allocation compilers) replaced by those from the given
-- 'Operations'.
localOps :: Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps ops = local withOps
  where
    withOps env =
      env
        { envAllocCompilers = opsAllocCompilers ops,
          envOpCompiler = opsOpCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envExpCompiler = opsExpCompiler ops
        }
-- | The current symbol table.
getVTable :: ImpM lore r op (VTable lore)
getVTable = stateVTable <$> get
-- | Replace the entire symbol table.
putVTable :: VTable lore -> ImpM lore r op ()
putVTable vtable = modify replaceTable
  where
    replaceTable st = st {stateVTable = vtable}
-- | Run an action with a transformed symbol table, then restore the
-- table exactly as it was before.  Any bindings the action adds are
-- discarded.
localVTable :: (VTable lore -> VTable lore) -> ImpM lore r op a -> ImpM lore r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  m <* putVTable saved
-- | Look up a name in the symbol table.  Fails with 'error' if the name
-- is not bound.
lookupVar :: VName -> ImpM lore r op (VarEntry lore)
lookupVar name = do
  vtable <- gets stateVTable
  case M.lookup name vtable of
    Nothing -> error $ "Unknown variable: " ++ pretty name
    Just entry -> pure entry
-- | Look up the array bound to a name.  Fails with 'error' if the name
-- is unbound (via 'lookupVar') or is bound to something other than an
-- array.
lookupArray :: VName -> ImpM lore r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ pretty name
-- | Look up the memory block bound to a name.
--
-- Fails with 'error' if the name is not bound at all (reported by
-- 'lookupVar'), or if it is bound to something other than memory.
lookupMemory :: VName -> ImpM lore r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> return entry
    -- lookupVar has already rejected unbound names, so reaching this
    -- branch means the name is bound, just not to a memory block.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ pretty name
-- | Build a 'Destination' for a pattern.  The destination is tagged with
-- the 'baseTag' of the pattern's first name, if there is one.  Each
-- pattern element becomes a destination based on its symbol-table
-- entry:
--
-- * arrays become 'ArrayDestination' with no fixed location;
-- * memory blocks become 'MemoryDestination';
-- * scalars become 'ScalarDestination'.
destinationFromPattern :: Mem lore => Pattern lore -> ImpM lore r op Destination
destinationFromPattern pat = do
  dests <- mapM inspect $ patternElements pat
  pure $ Destination (baseTag <$> maybeHead (patternNames pat)) dests
  where
    inspect pe = do
      let name = patElemName pe
      lookupVar name >>= \case
        ArrayVar _ (ArrayEntry MemLocation {} _) ->
          pure $ ArrayDestination Nothing
        MemVar {} ->
          pure $ MemoryDestination name
        ScalarVar {} ->
          pure $ ScalarDestination name
-- | Index an array, looked up by name, with one index per dimension.
-- Returns the memory block, its space, and the element offset.  See
-- 'fullyIndexArray'' for the details.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  loc <- entryArrayLocation <$> lookupArray name
  fullyIndexArray' loc indices
-- | Fully index an array stored at the given memory location.  Returns
-- the memory block, its space, and the element offset obtained by
-- applying the location's index function to the indices.
--
-- For memory in a 'ScalarSpace' with dimensions @ds@, the indices are
-- first split with @'splitFromEnd' (length ds)@ and the leading part is
-- replaced by zeros before indexing.
fullyIndexArray' ::
  MemLocation ->
  [Imp.TExp Int64] ->
  ImpM lore r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLocation mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  let adjusted = adjustForSpace space
  pure (mem, space, elements (IxFun.index ixfun adjusted))
  where
    adjustForSpace (ScalarSpace ds _) =
      let (outer, inner) = splitFromEnd (length ds) indices
       in replicate (length outer) 0 ++ inner
    adjustForSpace _ = indices
-- | Copy using whatever 'CopyCompiler' is installed in the environment
-- (see 'localOps').
copy :: CopyCompiler lore r op
copy bt dest destslice src srcslice = do
  compiler <- envCopyCompiler <$> ask
  compiler bt dest destslice src srcslice
-- | Recognise a copy that can be done with a single map-transpose call.
--
-- Both index functions are sliced first.  A match requires one sliced
-- side to be linear ('IxFun.linearWithOffset') and the other to be a
-- rearrangement ('IxFun.rearrangeWithOffset') whose permutation is
-- accepted by 'isMapTranspose'.  Both offset functions are given the
-- element byte size, and 'defaultCopy' treats the resulting offsets as
-- byte counts.
--
-- On success the result is @(dest_offset, src_offset, num_arrays,
-- size_x, size_y)@.  The shape of the rearranged side is split at @r1@
-- into mapped dimensions (their product is @num_arrays@).  The rest is
-- split at @r2@, and the two parts, swapped when the destination is the
-- rearranged side, give @size_x@ and @size_y@ as products.
isMapTransposeCopy ::
  PrimType ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  MemLocation ->
  Slice (Imp.TExp Int64) ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy bt (MemLocation _ _ dest_ixfun) destslice (MemLocation _ _ src_ixfun) srcslice
  -- Rearranged destination, linear source.
  | Just (dest_offset, dest_perm_and_dims) <- IxFun.rearrangeWithOffset sliced_dest elem_size,
    (perm, dims) <- unzip dest_perm_and_dims,
    Just src_offset <- IxFun.linearWithOffset sliced_src elem_size,
    Just (r1, r2, _) <- isMapTranspose perm =
    transposeSizes dims flipPair r1 r2 dest_offset src_offset
  -- Linear destination, rearranged source.
  | Just dest_offset <- IxFun.linearWithOffset sliced_dest elem_size,
    Just (src_offset, src_perm_and_dims) <- IxFun.rearrangeWithOffset sliced_src elem_size,
    (perm, dims) <- unzip src_perm_and_dims,
    Just (r1, r2, _) <- isMapTranspose perm =
    transposeSizes dims id r1 r2 dest_offset src_offset
  | otherwise = Nothing
  where
    elem_size = primByteSize bt
    sliced_dest = IxFun.slice dest_ixfun destslice
    sliced_src = IxFun.slice src_ixfun srcslice

    flipPair (a, b) = (b, a)

    -- Split the shape into mapped / pre-transpose / post-transpose
    -- parts and package their products together with the offsets.
    transposeSizes shape reorder r1 r2 dest_offset src_offset =
      let (mapped, notmapped) = splitAt r1 shape
          (pretrans, posttrans) = reorder $ splitAt r2 notmapped
       in Just
            ( dest_offset,
              src_offset,
              product mapped,
              product pretrans,
              product posttrans
            )
-- | The base name used for the map-transpose routine specialised to
-- the given element type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("map_transpose_" ++) . pretty
-- | Return the name of the builtin map-transpose function for an
-- element type.  If no function of that name has been emitted yet, it
-- is emitted first, so each element type gets at most one definition.
mapTransposeForType :: PrimType -> ImpM lore r op Name
mapTransposeForType bt = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (mapTransposeFunction fname bt)
  pure fname
  where
    fname = nameFromString ("builtin#" ++ mapTransposeName bt)
-- | The default 'CopyCompiler'.  Strategies are tried in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, emit a call to the
--    builtin map-transpose function for the element type.
-- 2. Otherwise, if both sliced index functions are linear, emit a
--    single bulk 'Imp.Copy' of all source-slice elements.  The
--    exception is when either memory block lives in a 'ScalarSpace',
--    in which case the copy is done element by element.
-- 3. In all other cases, copy element by element with
--    'copyElementWise'.
defaultCopy :: CopyCompiler lore r op
defaultCopy pt dest destslice src srcslice
  | Just (dest_off, src_off, n_arrays, size_x, size_y) <-
      isMapTransposeCopy pt dest destslice src srcslice = do
    transpose_fn <- mapTransposeForType pt
    let args =
          transposeArgs
            pt
            dst_mem
            (bytes dest_off)
            src_mem
            (bytes src_off)
            n_arrays
            size_x
            size_y
    emit $ Imp.Call [] transpose_fn args
  | Just dest_off <- linearOffset dst_ixfun destslice,
    Just src_off <- linearOffset src_ixfun srcslice = do
    src_space <- entryMemSpace <$> lookupMemory src_mem
    dst_space <- entryMemSpace <$> lookupMemory dst_mem
    if any isScalarSpace [src_space, dst_space]
      then copyElementWise pt dest destslice src srcslice
      else
        emit $
          Imp.Copy
            dst_mem
            (bytes dest_off)
            dst_space
            src_mem
            (bytes src_off)
            src_space
            (elem_count `withElemType` pt)
  | otherwise =
    copyElementWise pt dest destslice src srcslice
  where
    elem_bytes = primByteSize pt
    elem_count = Imp.elements $ product $ sliceDims srcslice

    -- Byte offset of a sliced index function, if it is linear.
    linearOffset ixfun slc = IxFun.linearWithOffset (IxFun.slice ixfun slc) elem_bytes

    MemLocation dst_mem _ dst_ixfun = dest
    MemLocation src_mem _ src_ixfun = src

    isScalarSpace ScalarSpace {} = True
    isScalarSpace _ = False
-- | Copy a slice element by element.
--
-- One 'Imp.For' loop is emitted per dimension of the source slice
-- ('sliceDims' of @srcslice@).  The same fresh loop indices are substituted
-- into both the destination and the source slice with 'fixSlice'.  The
-- innermost body is a single 'Imp.Write' of an 'Imp.index'ed source element.
-- Both the read and the write use the environment's 'envVolatility'.
copyElementWise :: CopyCompiler lore r op
copyElementWise bt dest destslice src srcslice = do
  let bounds = sliceDims srcslice
  loop_vars <- replicateM (length bounds) (newVName "i")
  let loop_idxs = map Imp.vi64 loop_vars
  (destmem, destspace, destidx) <-
    fullyIndexArray' dest (fixSlice destslice loop_idxs)
  (srcmem, srcspace, srcidx) <-
    fullyIndexArray' src (fixSlice srcslice loop_idxs)
  vol <- asks envVolatility
  let element_copy =
        Imp.Write destmem destidx bt destspace vol $
          Imp.index srcmem srcidx bt srcspace vol
      -- The first loop variable becomes the outermost loop.
      wrapLoops =
        foldr (.) id $ zipWith Imp.For loop_vars $ map untyped bounds
  emit $ wrapLoops element_copy
-- | Generate, but do not emit, code that copies @srcslice@ of
-- @srclocation@ into @destslice@ of @destlocation@, for elements of type
-- @bt@.
--
-- Fast path: both slices consist entirely of 'DimFix' indices, one per
-- dimension of their array.  The result is then a single 'Imp.Write' of an
-- 'Imp.index'ed source element.
--
-- Otherwise both slices are completed with 'fullSliceNum', and then:
--
-- * it is an 'error' if the completed slices have different ranks;
-- * no code is generated if source and destination have the same location
--   and the same completed slice (the copy would be a no-op);
-- * otherwise the code produced by 'copy' is collected and returned.
copyArrayDWIM ::
  PrimType ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLocation ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLocation _ destshape _)
  destslice
  srclocation@(MemLocation _ srcshape _)
  srcslice
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape =
      copyOne destis srcis
    | otherwise =
      copySlices
        (fullSliceNum (map toInt64Exp destshape) destslice)
        (fullSliceNum (map toInt64Exp srcshape) srcslice)
    where
      -- Both sides are fully indexed: read one element and write it.
      copyOne destis srcis = do
        (targetmem, destspace, targetoffset) <-
          fullyIndexArray' destlocation destis
        (srcmem, srcspace, srcoffset) <-
          fullyIndexArray' srclocation srcis
        vol <- asks envVolatility
        pure $
          Imp.Write targetmem targetoffset bt destspace vol $
            Imp.index srcmem srcoffset bt srcspace vol

      -- General case, given the completed destination and source slices.
      copySlices destslice' srcslice'
        | destrank /= srcrank =
          error $
            concat
              [ "copyArrayDWIM: cannot copy to ",
                pretty (memLocationName destlocation),
                " from ",
                pretty (memLocationName srclocation),
                " because ranks do not match (",
                pretty destrank,
                " vs ",
                pretty srcrank,
                ")"
              ]
        | destlocation == srclocation && destslice' == srcslice' =
          pure mempty
        | otherwise =
          collect $ copy bt destlocation destslice' srclocation srcslice'
        where
          destrank = length $ sliceDims destslice'
          srcrank = length $ sliceDims srcslice'
-- | Copy the value denoted by a 'SubExp', optionally indexed by
-- @src_slice@, into a 'ValueDestination', optionally indexed by
-- @dest_slice@, and emit the resulting code.
--
-- Supported combinations:
--
-- * A constant source, which must not be sliced, into:
--     * an unsliced scalar destination, or
--     * an array destination indexed purely by 'DimFix'.
-- * A memory block into a memory destination ('Imp.SetMem').
-- * A scalar variable into:
--     * an unsliced scalar destination, or
--     * a 'DimFix'-indexed array destination.
-- * An array variable, fully indexed by 'DimFix', into a scalar
--   destination.
-- * An array variable into an array destination, via 'copyArrayDWIM'.
--
-- With a variable source, an @'ArrayDestination' 'Nothing'@ destination
-- generates no code.  Every other combination is rejected with 'error'.
-- In particular, a scalar destination with a non-empty slice is rejected
-- for both constant and variable sources.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", pretty v, "cannot be indexed."]
copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
        unwords ["copyDWIMDest: constant source", pretty v, "with slice destination."]
    Just dest_is ->
      case pat of
        ScalarDestination name
          | not $ null dest_is ->
            -- A scalar target cannot be indexed.  Reject the slice instead
            -- of silently dropping it, as the variable-source case does.
            error $
              unwords ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
          | otherwise ->
            emit $ Imp.SetScalar name $ Imp.ValueExp v
        MemoryDestination {} ->
          error $
            unwords ["copyDWIMDest: constant source", pretty v, "cannot be written to memory destination."]
        ArrayDestination (Just dest_loc) -> do
          (dest_mem, dest_space, dest_i) <-
            fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
  where
    bt = primValueType v
copyDWIMDest dest dest_slice (Var src) src_slice = do
  src_entry <- lookupVar src
  case (dest, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", pretty src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", pretty src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not $ null src_slice ->
        error $
          unwords ["copyDWIMDest: prim-typed source", pretty src, "with slice", pretty src_slice]
    (ScalarDestination name, _)
      | not $ null dest_slice ->
        error $
          unwords ["copyDWIMDest: prim-typed target", pretty name, "with slice", pretty dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
        let bt = entryArrayElemType arr
        (mem, space, i) <-
          fullyIndexArray' (entryArrayLocation arr) src_is
        vol <- asks envVolatility
        emit $ Imp.SetScalar name $ Imp.index mem i bt space vol
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: prim-typed target",
              pretty name,
              "and array-typed source",
              pretty src,
              "with slice",
              pretty src_slice
            ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let src_loc = entryArrayLocation src_arr
          bt = entryArrayElemType src_arr
      emit =<< copyArrayDWIM bt dest_loc dest_slice src_loc src_slice
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice -> do
        (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
        vol <- asks envVolatility
        emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | otherwise ->
        error $
          unwords
            [ "copyDWIMDest: array-typed target and prim-typed source",
              pretty src,
              "with slice",
              pretty dest_slice
            ]
    (ArrayDestination Nothing, _) ->
      -- No destination location: no code is generated.
      return ()
-- | Like 'copyDWIMDest', but the destination is given as a variable name.
--
-- The name is looked up with 'lookupVar'.  The resulting entry selects the
-- kind of destination:
--
-- * a scalar variable becomes a 'ScalarDestination';
-- * an array variable becomes an 'ArrayDestination' of its 'MemLocation';
-- * a memory block becomes a 'MemoryDestination'.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM lore r op ()
copyDWIM dest dest_slice src src_slice = do
  target <- toDestination <$> lookupVar dest
  copyDWIMDest target dest_slice src src_slice
  where
    toDestination = \case
      ScalarVar {} -> ScalarDestination dest
      ArrayVar _ entry -> ArrayDestination $ Just $ entryArrayLocation entry
      MemVar {} -> MemoryDestination dest
-- | As 'copyDWIM', but both the destination and the source are indexed
-- only by fixed indices.  Each index is wrapped in 'DimFix'.
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM lore r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (asSlice dest_is) src (asSlice src_is)
  where
    asSlice = map DimFix
-- | Compile an allocation of @e@ bytes in @space@.
--
-- The pattern must have an empty context and exactly one value element,
-- whose name is the memory block being allocated.  If 'envAllocCompilers'
-- has an 'AllocCompiler' for @space@, that compiler is used.  Otherwise a
-- plain 'Imp.Allocate' is emitted.  Any other pattern shape is an 'error'.
compileAlloc ::
  Mem lore =>
  Pattern lore ->
  SubExp ->
  Space ->
  ImpM lore r op ()
compileAlloc (Pattern [] [mem]) e space = do
  custom_alloc <- asks $ M.lookup space . envAllocCompilers
  let name = patElemName mem
      size = Imp.bytes $ toInt64Exp e
  maybe
    (emit $ Imp.Allocate name size space)
    (\alloc -> alloc name size)
    custom_alloc
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ pretty pat
-- | Symbolic size in bytes of a value of type @t@.
--
-- The size is the byte size of the element type, expressed as an
-- 'Imp.SizeOf' leaf, multiplied by the product of the array dimensions.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes $ elem_size * num_elems
  where
    elem_size = isInt64 $ Imp.LeafExp (Imp.SizeOf (elemType t)) int64
    num_elems = product $ map toInt64Exp $ arrayDims t
-- | Emit an 'Imp.For' loop with loop variable @i@ and bound @bound@.
--
-- The loop variable is passed to 'addLoopVar' with the integer type of
-- @bound@ before @body@ is collected.  A @bound@ of non-integer type is an
-- 'error'.
sFor' :: VName -> Imp.Exp -> ImpM lore r op () -> ImpM lore r op ()
sFor' i bound body = do
  let it = boundIntType $ primExpType bound
  addLoopVar i it
  emit . Imp.For i bound =<< collect body
  where
    boundIntType (IntType bound_t) = bound_t
    boundIntType t =
      error $ "sFor': bound " ++ pretty bound ++ " is of type " ++ pretty t
-- | Emit a loop whose loop variable gets a fresh name derived from @i@.
-- The continuation @body@ receives that loop variable as a typed
-- expression with the same primitive type as @bound@.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM lore r op ()) -> ImpM lore r op ()
sFor i bound body = do
  loopVar <- newVName i
  let rawBound = untyped bound
  sFor' loopVar rawBound . body . TPrimExp . Imp.var loopVar $ primExpType rawBound
-- | Emit an 'Imp.While' node with condition @cond@ around the code
-- collected from @body@.
sWhile :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile cond body = emit . Imp.While cond =<< collect body
-- | Emit the code generated by @code@ wrapped in an 'Imp.Comment' node
-- carrying the text @s@.
sComment :: String -> ImpM lore r op () -> ImpM lore r op ()
sComment s code = collect code >>= emit . Imp.Comment s
-- | Emit an 'Imp.If' node on @cond@.  The true branch is generated (and
-- collected) before the false branch.
sIf :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf cond tbranch fbranch =
  emit =<< Imp.If cond <$> collect tbranch <*> collect fbranch
-- | 'sIf' with an empty false branch.
sWhen :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen cond thenCode = sIf cond thenCode (pure ())
-- | 'sIf' with an empty true branch: the given code forms the false branch.
sUnless :: Imp.TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless cond elseCode = sIf cond (pure ()) elseCode
-- | Emit a backend-specific operation as an 'Imp.Op' statement.
sOp :: op -> ImpM lore r op ()
sOp o = emit (Imp.Op o)
-- | Declare a memory block in @space@ under a fresh name derived from
-- @name@, register it with 'addVar' as a memory variable, and return the
-- fresh name.
sDeclareMem :: String -> Space -> ImpM lore r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem
-- | Allocate @size'@ bytes into the already-declared memory block @name'@.
-- If 'envAllocCompilers' has an allocation compiler for @space@, it is
-- used.  Otherwise a plain 'Imp.Allocate' is emitted.
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Just customAlloc -> customAlloc name' size'
    Nothing -> emit $ Imp.Allocate name' size' space
-- | Declare a fresh memory block (via 'sDeclareMem') and allocate @size@
-- bytes for it (via 'sAlloc_').  Returns the name of the block.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM lore r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  mem <$ sAlloc_ mem size space
-- | Declare (via 'dArray') an array with a fresh name derived from @name@,
-- element type @bt@, the given shape and memory binding.  Returns the
-- array's fresh name.
sArray :: String -> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray name bt shape membind =
  newVName name >>= \arr -> arr <$ dArray arr bt shape membind
-- | Declare an array stored in the memory block @mem@.  Its index function
-- is 'IxFun.iota' over the dimensions of @shape@.
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM lore r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape . ArrayIn mem . IxFun.iota $
    [isInt64 (primExpFromSubExp int64 d) | d <- shapeDims shape]
-- | Allocate a memory block named @name ++ "_mem"@ in @space@, sized by
-- 'typeSize' for an array of @pt@ with @shape@, and declare the array in it.
-- The index function is 'IxFun.iota' over the dimensions rearranged by
-- @perm@, then permuted by the inverse of @perm@.
sAllocArrayPerm :: String -> PrimType -> ShapeBase SubExp -> Space -> [Int] -> ImpM lore r op VName
sAllocArrayPerm name pt shape space perm = do
  let memOrderDims = rearrangeShape perm (shapeDims shape)
  mem <- sAlloc (name ++ "_mem") (typeSize $ Array pt shape NoUniqueness) space
  let memOrderIxFun =
        IxFun.iota [isInt64 (primExpFromSubExp int64 d) | d <- memOrderDims]
  sArray name pt shape . ArrayIn mem . IxFun.permute memOrderIxFun $
    rearrangeInverse perm
-- | 'sAllocArrayPerm' with the identity permutation over all dimensions of
-- @shape@.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identityPerm
  where
    identityPerm = take (shapeRank shape) [0 ..]
-- | Declare a one-dimensional array holding the static contents @vs@.
-- A memory block named (via 'newVNameForFun') @name ++ "_mem"@ is declared
-- with 'Imp.DeclareArray' and registered as a memory variable.  The array is
-- then placed in that block with an 'IxFun.iota' index function over its
-- element count.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM lore r op VName
sStaticArray name space pt vs = do
  let count = staticCount vs
      shape = Shape [intConst Int64 (toInteger count)]
  mem <- newVNameForFun (name ++ "_mem")
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem . MemVar Nothing $ MemEntry space
  sArray name pt shape . ArrayIn mem $ IxFun.iota [fromIntegral count]
  where
    -- Number of elements described by the static contents.
    staticCount (Imp.ArrayValues vals) = length vals
    staticCount (Imp.ArrayZeros n) = n
-- | Write the scalar @v@ into array @arr@ at the full index @is@.  The
-- written element type is taken from @v@, and the volatility comes from the
-- environment.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM lore r op ()
sWrite arr is v = do
  volatility <- asks envVolatility
  (mem, space, offset) <- fullyIndexArray arr is
  emit $ Imp.Write mem offset (primExpType v) space volatility v
-- | Copy @v@ into the part of @arr@ selected by @slice@, using 'copyDWIM'
-- with an empty source slice.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM lore r op ()
sUpdate arr slice src = copyDWIM arr slice src mempty
-- | Generate a loop nest with one 'sFor' loop per dimension of the shape,
-- the first dimension outermost.  The continuation receives the loop
-- indices in dimension order.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM lore r op ()) ->
  ImpM lore r op ()
sLoopNest shape f = go [] (shapeDims shape)
  where
    -- Indices are accumulated innermost-first and reversed at the bottom.
    go is [] = f (reverse is)
    go is (d : ds) = sFor "nest_i" (toInt64Exp d) $ \i -> go (i : is) ds
infixl 3 <~~

-- | Untyped assignment: emit an 'Imp.SetScalar' of expression @e@ to
-- variable @x@.
(<~~) :: VName -> Imp.Exp -> ImpM lore r op ()
(<~~) x e = emit (Imp.SetScalar x e)
infixl 3 <--

-- | Typed assignment: emit an 'Imp.SetScalar' of the typed expression @e@
-- to the variable behind the 'TV'.
(<--) :: TV t -> Imp.TExp t -> ImpM lore r op ()
(<--) (TV x _) e = emit . Imp.SetScalar x $ untyped e
-- | Generate and emit a function named @fname@ with the given output and
-- input parameters, whose body is the code produced by @m@.  While the body
-- is generated, 'envFunction' is set to @fname@, and every parameter
-- (outputs first, then inputs) is registered with 'addVar' before @m@ runs.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM lore r op () ->
  ImpM lore r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect $ mapM_ addParam (outputs ++ inputs) >> m
    emitFunction fname $ Imp.Function False outputs inputs body [] []
  where
    -- Memory parameters become memory variables, scalar ones scalar variables.
    addParam = \case
      Imp.MemParam name space -> addVar name . MemVar Nothing $ MemEntry space
      Imp.ScalarParam name bt -> addVar name . ScalarVar Nothing $ ScalarEntry bt