{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen
(
compileProg,
OpCompiler,
ExpCompiler,
CopyCompiler,
StmsCompiler,
AllocCompiler,
Operations (..),
defaultOperations,
MemLoc (..),
sliceMemLoc,
MemEntry (..),
ScalarEntry (..),
ImpM,
localDefaultSpace,
askFunction,
newVNameForFun,
nameForFun,
askEnv,
localEnv,
localOps,
VTable,
getVTable,
localVTable,
subImpM,
subImpM_,
emit,
emitFunction,
hasFunction,
collect,
collect',
comment,
VarEntry (..),
ArrayEntry (..),
lookupVar,
lookupArray,
lookupMemory,
lookupAcc,
askAttrs,
TV,
mkTV,
tvSize,
tvExp,
tvVar,
ToExp (..),
compileAlloc,
everythingVolatile,
compileBody,
compileBody',
compileLoopBody,
defCompileStms,
compileStms,
compileExp,
defCompileExp,
fullyIndexArray,
fullyIndexArray',
copy,
copyDWIM,
copyDWIMFix,
copyElementWise,
typeSize,
inBounds,
isMapTransposeCopy,
caseMatch,
dLParams,
dFParams,
addLoopVar,
dScope,
dArray,
dPrim,
dPrimVol,
dPrim_,
dPrimV_,
dPrimV,
dPrimVE,
dIndexSpace,
dIndexSpace',
rotateIndex,
sFor,
sWhile,
sComment,
sIf,
sWhen,
sUnless,
sOp,
sDeclareMem,
sAlloc,
sAlloc_,
sArray,
sArrayInMem,
sAllocArray,
sAllocArrayPerm,
sStaticArray,
sWrite,
sUpdate,
sLoopNest,
sCopy,
sLoopSpace,
(<--),
(<~~),
function,
warn,
module Language.Futhark.Warnings,
)
where
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer
import Control.Parallel.Strategies
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.Either
import Data.List (find)
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.String
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
( Bytes,
Count,
Elements,
bytes,
elements,
withElemType,
)
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.CodeGen.ImpGen.Transpose
import Futhark.Construct hiding (ToExp (..))
import Futhark.IR.Mem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.SOACS (SOACS)
import Futhark.Util
import Futhark.Util.IntegralExp
import Futhark.Util.Loc (noLoc)
import Futhark.Util.Pretty hiding (nest, space)
import Language.Futhark.Warnings
import Prelude hiding (mod, quot)
-- | Compiles an 'Op' bound to the given pattern.
type OpCompiler rep r op =
  Pat (LetDec rep) -> Op rep -> ImpM rep r op ()

-- | Compiles a sequence of statements.  The 'Names' argument and the
-- trailing action are supplied by the caller; e.g. 'compileConsts' passes
-- the names free in the generated functions and @pure ()@.
-- NOTE(review): presumably the names that must stay available and a
-- continuation run after the statements — confirm against 'defCompileStms'.
type StmsCompiler rep r op =
  Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()

-- | Compiles an expression bound to the given pattern.
type ExpCompiler rep r op =
  Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()

-- | Compiles a copy of elements of the given primitive type between two
-- memory locations.
-- NOTE(review): which 'MemLoc' is the destination and which the source is
-- not visible here — confirm against 'defaultCopy'.
type CopyCompiler rep r op =
  PrimType ->
  MemLoc ->
  MemLoc ->
  ImpM rep r op ()

-- | Compiles an allocation of the given byte count for a memory block
-- name.  These are installed per 'Space' (see 'opsAllocCompilers').
type AllocCompiler rep r op =
  VName -> Count Bytes (Imp.TExp Int64) -> ImpM rep r op ()
-- | The set of compilers that code generation dispatches to.  Start from
-- 'defaultOperations' and override individual fields as needed.
data Operations rep r op = Operations
  { -- | Compiler used for expressions.
    opsExpCompiler :: ExpCompiler rep r op,
    -- | Compiler used for 'Op' constructs.
    opsOpCompiler :: OpCompiler rep r op,
    -- | Compiler used for statement sequences.
    opsStmsCompiler :: StmsCompiler rep r op,
    -- | Compiler used for copies between memory locations.
    opsCopyCompiler :: CopyCompiler rep r op,
    -- | Per-'Space' allocation compilers.
    opsAllocCompilers :: M.Map Space (AllocCompiler rep r op)
  }
-- | An 'Operations' that uses 'defCompileExp', 'defCompileStms' and
-- 'defaultCopy', with no per-space allocation compilers.  Only the
-- 'OpCompiler' has to be supplied.
defaultOperations ::
  (Mem rep inner, FreeIn op) =>
  OpCompiler rep r op ->
  Operations rep r op
defaultOperations opCompiler =
  Operations
    { opsOpCompiler = opCompiler,
      opsExpCompiler = defCompileExp,
      opsStmsCompiler = defCompileStms,
      opsCopyCompiler = defaultCopy,
      opsAllocCompilers = mempty
    }
-- | A location in memory: the memory block, a shape, and an index
-- function over 64-bit index expressions.
data MemLoc = MemLoc
  { -- | Name of the memory block.
    memLocName :: VName,
    -- | Dimension sizes recorded for this location.
    memLocShape :: [Imp.DimSize],
    -- | Index function into the memory block.
    memLocIxFun :: IxFun.IxFun (Imp.TExp Int64)
  }
  deriving (Eq, Show)
-- | Slice the index function of a 'MemLoc'.  The memory block name and the
-- recorded 'memLocShape' are kept unchanged.
sliceMemLoc :: MemLoc -> Slice (Imp.TExp Int64) -> MemLoc
sliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.slice (memLocIxFun loc) slice}
-- | Apply a flat slice to the index function of a 'MemLoc'.  The memory
-- block name and the recorded 'memLocShape' are kept unchanged.
flatSliceMemLoc :: MemLoc -> FlatSlice (Imp.TExp Int64) -> MemLoc
flatSliceMemLoc loc slice =
  loc {memLocIxFun = IxFun.flatSlice (memLocIxFun loc) slice}
-- | What code generation knows about an array variable.
data ArrayEntry = ArrayEntry
  { -- | Where the array is stored.
    entryArrayLoc :: MemLoc,
    -- | Primitive element type.
    entryArrayElemType :: PrimType
  }
  deriving (Show)
-- | The shape of an array entry, taken from its memory location.
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape entry = memLocShape (entryArrayLoc entry)
-- | What code generation knows about a memory block: its space.
newtype MemEntry = MemEntry
  { entryMemSpace :: Imp.Space
  }
  deriving (Show)
-- | What code generation knows about a scalar: its primitive type.
newtype ScalarEntry = ScalarEntry {entryScalarType :: PrimType}
  deriving (Show)
-- | A symbol-table entry.  Every variant carries an optional expression;
-- 'constsVTable' fills it with the defining expression of a constant.
data VarEntry rep
  = ArrayVar (Maybe (Exp rep)) ArrayEntry
  | ScalarVar (Maybe (Exp rep)) ScalarEntry
  | MemVar (Maybe (Exp rep)) MemEntry
  | -- | An accumulator: a name, a 'Shape' and element types.
    AccVar (Maybe (Exp rep)) (VName, Shape, [Type])
  deriving (Show)
-- | A destination for a compiled value.
data ValueDestination
  = ScalarDestination VName
  | MemoryDestination VName
  | -- | NOTE(review): the meaning of 'Nothing' is not visible in this
    -- chunk — presumably that no explicit copy is required; confirm
    -- against users of this type.
    ArrayDestination (Maybe MemLoc)
  deriving (Show)
-- | The read-only environment of 'ImpM': the active compilers plus
-- contextual information.
data Env rep r op = Env
  { envExpCompiler :: ExpCompiler rep r op,
    envStmsCompiler :: StmsCompiler rep r op,
    envOpCompiler :: OpCompiler rep r op,
    envCopyCompiler :: CopyCompiler rep r op,
    envAllocCompilers :: M.Map Space (AllocCompiler rep r op),
    -- | The default memory space.
    envDefaultSpace :: Imp.Space,
    -- | Volatility setting; starts as 'Imp.Nonvolatile' in 'newEnv'.
    envVolatility :: Imp.Volatility,
    -- | User-supplied environment payload.
    envEnv :: r,
    -- | Presumably the function currently being compiled ('Nothing' at
    -- the start, see 'newEnv') — confirm against 'askFunction'.
    envFunction :: Maybe Name,
    -- | Active attributes; empty at the start.
    envAttrs :: Attrs
  }
-- | Build the initial environment from a user payload, a set of
-- operations and the default memory space.
--
-- All compilers, including the per-space allocation compilers, are taken
-- from the given 'Operations'.  This matches 'subImpM', which also
-- installs 'opsAllocCompilers'.  Previously the allocation compilers were
-- replaced by 'mempty' here, so any supplied at top level were silently
-- ignored.
newEnv :: r -> Operations rep r op -> Imp.Space -> Env rep r op
newEnv r ops ds =
  Env
    { envExpCompiler = opsExpCompiler ops,
      envStmsCompiler = opsStmsCompiler ops,
      envOpCompiler = opsOpCompiler ops,
      envCopyCompiler = opsCopyCompiler ops,
      envAllocCompilers = opsAllocCompilers ops,
      envDefaultSpace = ds,
      envVolatility = Imp.Nonvolatile,
      envEnv = r,
      envFunction = Nothing,
      envAttrs = mempty
    }
-- | The symbol table: what code generation knows about each variable.
type VTable rep =
  M.Map VName (VarEntry rep)
-- | The mutable state threaded through 'ImpM'.  Field order matters:
-- 'newState' may construct this positionally.
data ImpState rep r op = ImpState
  { -- | The symbol table.
    stateVTable :: VTable rep,
    -- | Functions emitted so far (see 'emitFunction').
    stateFunctions :: Imp.Functions op,
    -- | Code emitted so far (see 'emit').
    stateCode :: Imp.Code op,
    -- | Warnings accumulated so far.
    stateWarnings :: Warnings,
    -- | Per-accumulator information.
    -- NOTE(review): presumably the associated arrays and an optional
    -- combining lambda with neutral elements — confirm.
    stateAccs :: M.Map VName ([VName], Maybe (Lambda rep, [SubExp])),
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | An empty state that draws fresh names from the given source.
newState :: VNameSource -> ImpState rep r op
newState src =
  ImpState
    { stateVTable = mempty,
      stateFunctions = mempty,
      stateCode = mempty,
      stateWarnings = mempty,
      stateAccs = mempty,
      stateNameSource = src
    }
-- | The code generation monad: a reader over 'Env' stacked on a state over
-- 'ImpState'.
newtype ImpM rep r op a
  = ImpM (ReaderT (Env rep r op) (State (ImpState rep r op)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (ImpState rep r op),
      MonadReader (Env rep r op)
    )
-- | Fresh names are drawn from 'stateNameSource'.
instance MonadFreshNames (ImpM rep r op) where
  getNameSource = gets stateNameSource
  putNameSource newSrc = modify (\st -> st {stateNameSource = newSrc})
-- | Expose the symbol table as a 'SOACS' scope by turning each entry into
-- the corresponding type.
instance HasScope SOACS (ImpM rep r op) where
  askScope = gets $ M.map (LetName . typeOfEntry) . stateVTable
    where
      typeOfEntry = \case
        MemVar _ memEntry ->
          Mem (entryMemSpace memEntry)
        ArrayVar _ arrEntry ->
          Array
            (entryArrayElemType arrEntry)
            (Shape (entryArrayShape arrEntry))
            NoUniqueness
        ScalarVar _ scalarEntry ->
          Prim (entryScalarType scalarEntry)
        AccVar _ (acc, ispace, ts) ->
          Acc acc ispace ts NoUniqueness
-- | Run an 'ImpM' action from the given initial state.  The environment
-- is built with 'newEnv' from the payload, operations and default space.
-- Returns the result together with the final state.
runImpM ::
  ImpM rep r op a ->
  r ->
  Operations rep r op ->
  Imp.Space ->
  ImpState rep r op ->
  (a, ImpState rep r op)
runImpM (ImpM action) payload ops defaultSpace =
  runState . runReaderT action $ newEnv payload ops defaultSpace
-- | Like 'subImpM', but discards the action's result and keeps only the
-- code it emitted.
subImpM_ ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (Imp.Code op')
subImpM_ payload ops action = fmap snd (subImpM payload ops action)
-- | Run an action with a different payload and operations (and thus a
-- possibly different @op@ type), returning its result and emitted code.
--
-- The inner action starts with the current environment (with compilers and
-- payload replaced), plus the current symbol table, accumulators and name
-- source.  Its code, functions and warnings start empty.  Afterwards the
-- name source and warnings are propagated back.  The functions it emitted
-- and any symbol-table or accumulator changes are discarded.
subImpM ::
  r' ->
  Operations rep r' op' ->
  ImpM rep r' op' a ->
  ImpM rep r op (a, Imp.Code op')
subImpM payload ops (ImpM inner) = do
  outerEnv <- ask
  outer <- get

  let innerEnv =
        outerEnv
          { envEnv = payload,
            envOpCompiler = opsOpCompiler ops,
            envExpCompiler = opsExpCompiler ops,
            envStmsCompiler = opsStmsCompiler ops,
            envCopyCompiler = opsCopyCompiler ops,
            envAllocCompilers = opsAllocCompilers ops
          }
      innerStart =
        ImpState
          { stateVTable = stateVTable outer,
            stateAccs = stateAccs outer,
            stateNameSource = stateNameSource outer,
            stateFunctions = mempty,
            stateCode = mempty,
            stateWarnings = mempty
          }
      (result, innerEnd) = runState (runReaderT inner innerEnv) innerStart

  putNameSource (stateNameSource innerEnd)
  warnings (stateWarnings innerEnd)
  pure (result, stateCode innerEnd)
-- | Run an action and return the code it emitted, instead of appending that
-- code to the current code.
collect :: ImpM rep r op () -> ImpM rep r op (Imp.Code op)
collect action = snd <$> collect' action
-- | Run an action with an empty code buffer.  Returns its result and the
-- code it emitted, then restores the code that was present before.
collect' :: ImpM rep r op a -> ImpM rep r op (a, Imp.Code op)
collect' action = do
  saved <- state $ \st -> (stateCode st, st {stateCode = mempty})
  result <- action
  produced <- state $ \st -> (stateCode st, st {stateCode = saved})
  pure (result, produced)
-- | Emit the code produced by an action wrapped in an 'Imp.Comment' with
-- the given description.
comment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
comment desc action =
  emit . Imp.Comment desc =<< collect action
-- | Append code after everything emitted so far.
emit :: Imp.Code op -> ImpM rep r op ()
emit newCode =
  modify (\st -> st {stateCode = stateCode st <> newCode})
-- | Record warnings.  The new warnings are placed in front of those already
-- accumulated.
warnings :: Warnings -> ImpM rep r op ()
warnings new =
  modify (\st -> st {stateWarnings = new <> stateWarnings st})
-- | Record a single warning.  It has a primary location, related
-- locations and a textual description.
warn :: (Located loc) => loc -> [loc] -> T.Text -> ImpM rep r op ()
warn loc locs problem = warnings warning
  where
    warning =
      singleWarning' (srclocOf loc) (srclocOf <$> locs) (pretty problem)
-- | Add a function to the state, in front of those emitted so far.  No
-- check is made for an existing function with the same name.
emitFunction :: Name -> Imp.Function op -> ImpM rep r op ()
emitFunction fname fun =
  modify $ \st ->
    let Imp.Functions existing = stateFunctions st
     in st {stateFunctions = Imp.Functions ((fname, fun) : existing)}
-- | Whether a function with this name has been emitted into the current
-- state.
hasFunction :: Name -> ImpM rep r op Bool
hasFunction fname = gets (definesIt . stateFunctions)
  where
    definesIt (Imp.Functions fs) = any ((== fname) . fst) fs
-- | A symbol table with one entry per pattern element bound by the given
-- statements.  Each entry records the statement's expression as its
-- defining expression.  On duplicate names the earliest binding wins.
constsVTable :: (Mem rep inner) => Stms rep -> VTable rep
constsVTable = foldMap entriesOf
  where
    entriesOf (Let pat _ e) =
      M.unions
        [ M.singleton name $ memBoundToVarEntry (Just e) $ letDecMem dec
          | PatElem name dec <- patElems pat
        ]
-- | Compile a whole program.
--
-- Function definitions are compiled in parallel ('parMap' 'rpar').  Each
-- starts from the incoming name source, with the constants' symbol table
-- pre-loaded.  The resulting states are merged: functions are deduplicated
-- by name (last occurrence wins) and warnings are concatenated.  The
-- constants are then compiled on top of the merged state, given the names
-- free in the generated functions.  Returns the warnings and definitions;
-- the final name source is written back.
compileProg ::
  (Mem rep inner, FreeIn op, MonadFreshNames m) =>
  r ->
  Operations rep r op ->
  Imp.Space ->
  Prog rep ->
  m (Warnings, Imp.Definitions op)
compileProg r ops space (Prog types consts funs) =
  modifyNameSource $ \src ->
    let (_, ss) =
          unzip $ parMap rpar (compileFunDef' src) funs
        free_in_funs =
          freeIn $ mconcat $ map stateFunctions ss
        (consts', s') =
          runImpM (compileConsts free_in_funs consts) r ops space $
            combineStates src ss
     in ( (stateWarnings s', Imp.Definitions types consts' (stateFunctions s')),
          stateNameSource s'
        )
  where
    compileFunDef' src fdef =
      runImpM
        (compileFunDef types fdef)
        r
        ops
        space
        (newState src) {stateVTable = constsVTable consts}

    -- Merge the per-function states.  The outer name source is included so
    -- that it is not lost when there are no functions: previously an empty
    -- 'funs' gave 'mconcat []', discarding 'src'.  Assumes VNameSource's
    -- '<>' keeps the larger source (TODO confirm).  Each per-function
    -- source starts from 'src', so including it is then a no-op whenever
    -- functions exist.
    combineStates src ss =
      let Imp.Functions funs' = mconcat $ map stateFunctions ss
          src' = mconcat $ src : map stateNameSource ss
       in (newState src')
            { stateFunctions =
                Imp.Functions $ M.toList $ M.fromList funs',
              stateWarnings =
                mconcat $ map stateWarnings ss
            }
-- | Compile the top-level constant statements.  Declarations of memory
-- blocks and scalars whose names are in the given set are lifted out of
-- the generated code and become the parameters of the resulting
-- 'Imp.Constants'; all other code is kept, in order, as its body.
compileConsts :: Names -> Stms rep -> ImpM rep r op (Imp.Constants op)
compileConsts used_consts stms = do
  constCode <- collect $ compileStms used_consts stms (pure ())
  pure $ case splitDecls constCode of
    (hoisted, remaining) -> Imp.Constants (DL.toList hoisted) remaining
  where
    -- Walk the ':>>:' tree.  A declaration of a used constant is turned
    -- into a parameter and removed from the code; anything else is kept.
    splitDecls code = case code of
      lhs Imp.:>>: rhs -> splitDecls lhs <> splitDecls rhs
      Imp.DeclareMem v sp
        | v `nameIn` used_consts -> (DL.singleton (Imp.MemParam v sp), mempty)
      Imp.DeclareScalar v _ pt
        | v `nameIn` used_consts -> (DL.singleton (Imp.ScalarParam v pt), mempty)
      other -> (mempty, other)
-- | Look up the definition of an opaque type by name.  An unknown name is
-- an internal error and is reported with 'error'.
lookupOpaqueType :: String -> OpaqueTypes -> OpaqueType
lookupOpaqueType desc (OpaqueTypes known) =
  fromMaybe (error $ "Unknown opaque type: " ++ show desc) $ lookup desc known
-- | The signedness stored in a 'ValueType'; its rank and primitive type
-- are ignored.
valueTypeSign :: ValueType -> Signedness
valueTypeSign vt =
  case vt of
    ValueType signedness _ _ -> signedness
-- | The signedness of every primitive value making up an entry point type.
-- A transparent type contributes one entry.  An opaque type contributes one
-- per value type, and records are flattened field by field in list order.
entryPointSignedness :: OpaqueTypes -> EntryPointType -> [Signedness]
entryPointSignedness types = \case
  TypeTransparent vt -> [valueTypeSign vt]
  TypeOpaque desc ->
    case lookupOpaqueType desc types of
      OpaqueType vts -> valueTypeSign <$> vts
      OpaqueRecord fields ->
        concatMap (entryPointSignedness types . snd) fields
-- | How many primitive values (function parameters or return values) an
-- entry point type is represented by.  A transparent type counts as one.
-- Opaque records are summed over their fields.
entryPointSize :: OpaqueTypes -> EntryPointType -> Int
entryPointSize types = \case
  TypeTransparent _ -> 1
  TypeOpaque desc ->
    case lookupOpaqueType desc types of
      OpaqueType vts -> length vts
      OpaqueRecord fields ->
        sum [entryPointSize types ept | (_, ept) <- fields]
-- | Translate a single function parameter.  Scalars and memory blocks
-- become an 'Imp.Param' ('Left').  Arrays become an 'ArrayDecl' ('Right')
-- recording the memory block, dimensions and index function they use.
-- Accumulator parameters are rejected with 'error'.
compileInParam ::
  Mem rep inner =>
  FParam rep ->
  ImpM rep r op (Either Imp.Param ArrayDecl)
compileInParam fparam =
  case paramDec fparam of
    MemPrim pt ->
      pure . Left $ Imp.ScalarParam pname pt
    MemMem sp ->
      pure . Left $ Imp.MemParam pname sp
    MemArray pt shp _ (ArrayIn mem ixfun) ->
      pure . Right . ArrayDecl pname pt $ MemLoc mem (shapeDims shp) ixfun
    MemAcc {} ->
      error "Functions may not have accumulator parameters."
  where
    pname = paramName fparam
-- | The declaration of an array function parameter: the array's name, its
-- element type, and the memory location ('MemLoc') that holds it.
data ArrayDecl = ArrayDecl VName PrimType MemLoc
-- | Translate the parameters of a function.  The result has three parts:
--
-- * the imperative parameters (scalars and memory blocks);
-- * the declarations of the array parameters;
-- * when entry point parameter information is given, the external
--   description of each entry point argument.
--
-- Entry point arguments describe only the trailing value parameters.  The
-- leading parameters are skipped when building the descriptions.
compileInParams ::
  Mem rep inner =>
  OpaqueTypes ->
  [FParam rep] ->
  Maybe [EntryParam] ->
  ImpM rep r op ([Imp.Param], [ArrayDecl], Maybe [((Name, Uniqueness), Imp.ExternalValue)])
compileInParams types params eparams = do
  (inparams, arrayds) <- partitionEithers <$> mapM compileInParam params
  let memSpaces =
        -- Memory space of every memory-block parameter, keyed by its name.
        M.fromList [(paramName p, sp) | p <- params, MemMem sp <- [paramDec p]]

      arrayNamed v = find (\(ArrayDecl v' _ _) -> v' == v) arrayds

      valueDesc fparam signedness =
        case arrayNamed (paramName fparam) of
          Just (ArrayDecl _ pt (MemLoc mem dims _)) ->
            -- The array can only be described if its memory block is
            -- itself one of the parameters.
            (\space -> Imp.ArrayValue mem space pt signedness dims)
              <$> M.lookup mem memSpaces
          Nothing ->
            case paramType fparam of
              Prim pt -> Just $ Imp.ScalarValue pt signedness $ paramName fparam
              _ -> Nothing

      -- Pair each entry point argument with the parameters it is made of.
      externals (EntryParam v u et@(TypeOpaque desc) : more) fparams =
        let signs = entryPointSignedness types et
            n = entryPointSize types et
            (mine, others) = splitAt n fparams
         in ((v, u), Imp.OpaqueValue desc (catMaybes $ zipWith valueDesc mine signs))
              : externals more others
      externals (EntryParam v u (TypeTransparent (ValueType s _ _)) : more) (fparam : fparams) =
        case valueDesc fparam s of
          Just vd -> ((v, u), Imp.TransparentValue vd) : externals more fparams
          Nothing -> externals more fparams
      externals _ _ = []

  pure
    ( inparams,
      arrayds,
      case eparams of
        Nothing -> Nothing
        Just eparams' ->
          let num_val_params = sum [entryPointSize types (entryParamType ep) | ep <- eparams']
              val_params = drop (length params - num_val_params) params
           in Just $ externals eparams' val_params
    )
-- | Create the output parameter and destination for one return value.
-- Scalars and memory blocks get a freshly named output parameter.  Arrays
-- get no parameter and an 'ArrayDestination' without a 'MemLoc'.
-- Returning an accumulator is rejected with 'error'.
compileOutParam ::
  FunReturns -> ImpM rep r op (Maybe Imp.Param, ValueDestination)
compileOutParam = \case
  MemPrim pt ->
    (\out -> (Just (Imp.ScalarParam out pt), ScalarDestination out))
      <$> newVName "prim_out"
  MemMem space ->
    (\out -> (Just (Imp.MemParam out space), MemoryDestination out))
      <$> newVName "mem_out"
  MemArray {} ->
    pure (Nothing, ArrayDestination Nothing)
  MemAcc {} ->
    error "Functions may not return accumulators."
-- | Build the external (entry point) description of a function's results.
-- The trailing return types correspond to the entry point results; the
-- leading ones are context.  The optional output parameters are parallel
-- to the return types.  They supply the variable holding each scalar
-- result, each existential dimension, and each newly returned memory block.
compileExternalValues ::
  Mem rep inner =>
  OpaqueTypes ->
  [RetType rep] ->
  [EntryResult] ->
  [Maybe Imp.Param] ->
  ImpM rep r op [(Uniqueness, Imp.ExternalValue)]
compileExternalValues types orig_rts orig_epts maybe_params =
  describe (length ctx_rts) orig_epts val_rts
  where
    numValRets = sum [entryPointSize types (entryResultType ep) | ep <- orig_epts]
    (ctx_rts, val_rts) = splitAt (length orig_rts - numValRets) orig_rts

    -- Name of the output parameter at position i.  Both failure cases
    -- are internal errors.
    outParamName i =
      case maybeNth i maybe_params of
        Nothing -> error $ "Param " ++ show i ++ " does not exist."
        Just Nothing -> error $ "Output " ++ show i ++ " not a param."
        Just (Just p) -> Imp.paramName p

    valueDesc i signedness rt =
      case rt of
        MemPrim pt ->
          pure $ Imp.ScalarValue pt signedness $ outParamName i
        MemArray pt shape _ ret -> do
          (mem, space) <- case ret of
            ReturnsNewBlock sp j _ -> pure (outParamName j, sp)
            ReturnsInBlock m _ -> do
              sp <- entryMemSpace <$> lookupMemory m
              pure (m, sp)
          pure . Imp.ArrayValue mem space pt signedness $
            flip map (shapeDims shape) $ \case
              Free se -> se
              Ext k -> Var $ outParamName k
        MemAcc {} ->
          error "mkValueDesc: unexpected MemAcc output."
        MemMem {} ->
          error "mkValueDesc: unexpected MemMem output."

    -- Walk the entry point results alongside the value return types.
    -- 'i' is the position of the current return type among all of them.
    describe i (EntryResult u et@(TypeOpaque desc) : more) rets = do
      let signs = entryPointSignedness types et
          n = entryPointSize types et
          (mine, others) = splitAt n rets
      vds <- sequence $ zipWith3 valueDesc [i ..] signs mine
      rest <- describe (i + n) more others
      pure $ (u, Imp.OpaqueValue desc vds) : rest
    describe i (EntryResult u (TypeTransparent (ValueType s _ _)) : more) (rt : rts) = do
      vd <- valueDesc i s rt
      rest <- describe (i + 1) more rts
      pure $ (u, Imp.TransparentValue vd) : rest
    describe _ _ _ = pure []
-- | Create the output parameters and value destinations for a function's
-- return types.  When entry point results are given, also build their
-- external descriptions.  Only return values that received an output
-- parameter appear in the returned parameter list.
compileOutParams ::
  Mem rep inner =>
  OpaqueTypes ->
  [RetType rep] ->
  Maybe [EntryResult] ->
  ImpM rep r op (Maybe [(Uniqueness, Imp.ExternalValue)], [Imp.Param], [ValueDestination])
compileOutParams types orig_rts maybe_orig_epts = do
  outs <- traverse compileOutParam orig_rts
  let (maybe_params, dests) = unzip outs
  evs <-
    traverse
      (\epts -> compileExternalValues types orig_rts epts maybe_params)
      maybe_orig_epts
  pure (evs, catMaybes maybe_params, dests)
-- | Compile a function definition and emit it as an imperative function.
-- For an entry point, the entry point name (not the internal name) is the
-- current function name during compilation.  Entry point metadata is
-- attached only when the name, results and arguments are all available.
compileFunDef ::
  Mem rep inner =>
  OpaqueTypes ->
  FunDef rep ->
  ImpM rep r op ()
compileFunDef types (FunDef entry _ fname rettype params body) =
  local setFunction $ do
    ((outparams, inparams, results, args), body') <- collect' compile
    let entry' = Imp.EntryPoint <$> entryName <*> results <*> args
    emitFunction fname $ Imp.Function entry' outparams inparams body'
  where
    setFunction env = env {envFunction = Just $ fromMaybe fname entryName}

    entryName = (\(n, _, _) -> n) <$> entry
    entryParams = (\(_, ps, _) -> ps) <$> entry
    entryResults = (\(_, _, rs) -> rs) <$> entry

    -- Set up parameters, then compile the body so that its results end
    -- up in the output destinations.
    compile = do
      (inparams, arrayds, args) <- compileInParams types params entryParams
      (results, outparams, dests) <- compileOutParams types rettype entryResults
      addFParams params
      addArrays arrayds
      case body of
        Body _ stms ses ->
          compileStms (freeIn ses) stms $
            mapM_ (\(d, SubExpRes _ se) -> copyDWIMDest d [] se []) (zip dests ses)
      pure (outparams, inparams, results, args)
-- | Compile a body, copying each of its results into the corresponding
-- destination of the given pattern.
compileBody :: Pat (LetDec rep) -> Body rep -> ImpM rep r op ()
compileBody pat (Body _ stms ses) =
  destinationFromPat pat >>= \dests ->
    compileStms (freeIn ses) stms . sequence_ $
      zipWith (\d (SubExpRes _ se) -> copyDWIMDest d [] se []) dests ses
-- | Compile a body, copying each of its results into the corresponding
-- parameter.
compileBody' :: [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' params (Body _ stms ses) =
  compileStms (freeIn ses) stms . sequence_ $
    zipWith (\p (SubExpRes _ se) -> copyDWIM (paramName p) [] se []) params ses
-- | Compile the body of a loop, writing its results back into the merge
-- parameters.
--
-- The results cannot be written to the merge parameters directly: a
-- result may itself be a merge parameter, which would then be clobbered.
-- So scalar results, and memory results that are variables, are first
-- copied into fresh temporaries.  Only after all of those copies have been
-- emitted are the temporaries copied into the merge parameters.  Results
-- of any other kind are not copied here.
compileLoopBody :: Typed dec => [Param dec] -> Body rep -> ImpM rep r op ()
compileLoopBody mergeparams (Body _ stms ses) = do
  tmpnames <- traverse (\p -> newVName $ baseString (paramName p) ++ "_tmp") mergeparams
  compileStms (freeIn ses) stms $
    sequence (zipWith3 stageResult mergeparams tmpnames ses) >>= sequence_
  where
    -- Emit the copy of one result into its temporary.  Return the action
    -- that later moves the temporary into the merge parameter.
    stageResult p tmp (SubExpRes _ se) =
      case typeOf p of
        Prim pt -> do
          emit $ Imp.DeclareScalar tmp Imp.Nonvolatile pt
          emit $ Imp.SetScalar tmp $ toExp' pt se
          pure $ emit $ Imp.SetScalar (paramName p) $ Imp.var tmp pt
        Mem space
          | Var v <- se -> do
              emit $ Imp.DeclareMem tmp space
              emit $ Imp.SetMem tmp v space
              pure $ emit $ Imp.SetMem (paramName p) tmp space
        _ -> pure (pure ())
-- | Compile statements followed by the given action, using the statement
-- compiler installed in the environment ('envStmsCompiler').  The 'Names'
-- argument is forwarded unchanged.  Judging by its name, it is the set of
-- variables still alive after the statements.
compileStms :: Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms alive_after_stms all_stms m =
  asks envStmsCompiler >>= \compiler -> compiler alive_after_stms all_stms m
-- | The default statement compiler.
--
-- Statements are compiled in order, followed by the action @m@.  Memory
-- blocks bound by a statement's pattern are remembered.  Right after the
-- code of a later statement that references such a block, an
-- 'Imp.Free' is emitted for it if nothing afterwards still refers to
-- it.  "Afterwards" means the code of the remaining statements, the
-- code of @m@, and the names in @alive_after_stms@.
--
-- This is conservative.  A block that no later statement references is
-- never freed here, and neither is a block whose last use is inside @m@.
defCompileStms ::
  (Mem rep inner, FreeIn op) =>
  Names ->
  Stms rep ->
  ImpM rep r op () ->
  ImpM rep r op ()
defCompileStms alive_after_stms all_stms m =
  void $ go mempty $ stmsToList all_stms
  where
    -- The first argument holds the memory blocks (with their spaces)
    -- bound by the statements already visited.  The result is every
    -- name referenced from the current position onwards, together with
    -- @alive_after_stms@.
    go _ [] = do
      m_code <- collect m
      emit m_code
      pure $ freeIn m_code <> alive_after_stms
    go earlier_mems (Let pat aux e : rest) = do
      dVars (Just e) (patElems pat)
      stm_code <- localAttrs (stmAuxAttrs aux) $ collect $ compileExp pat e
      -- The remainder is compiled (but not yet emitted) first, so that
      -- we know which names are still needed after this statement.
      (used_later, rest_code) <-
        collect' $ go (boundMems pat <> earlier_mems) rest
      let lastUsedHere mem =
            mem `nameIn` freeIn stm_code && mem `notNameIn` used_later
          dead = S.filter (lastUsedHere . fst) earlier_mems
      emit stm_code
      forM_ dead $ \(mem, space) -> emit $ Imp.Free mem space
      emit rest_code
      pure $ freeIn stm_code <> used_later

    -- Memory blocks (and their spaces) bound by the elements of a pattern.
    boundMems p =
      S.fromList
        [ (patElemName pe, space)
          | pe <- patElems p,
            Mem space <- [patElemType pe]
        ]
-- | Compile the expression @e@ into the pattern @pat@, using the
-- expression compiler installed in the environment ('envExpCompiler').
compileExp :: Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
compileExp pat e =
  asks envExpCompiler >>= \compileWith -> compileWith pat e
-- | The condition under which a 'Case' applies.
--
-- Each scrutinee paired with a @Just v@ must compare equal to @v@
-- (at @v@'s primitive type).  A @Nothing@ pattern matches anything.
-- The pairs are formed with 'zipWith', so surplus elements of the
-- longer list are ignored.
caseMatch :: [SubExp] -> [Maybe PrimValue] -> Imp.TExp Bool
caseMatch ses vs =
  foldl (.&&.) true $ zipWith matchesPattern ses vs
  where
    matchesPattern _ Nothing = true
    matchesPattern se (Just v) =
      isBool $ toExp' (primValueType v) se ~==~ ValueExp v
-- | The default expression compiler.
--
-- * 'Match' becomes a chain of nested 'sIf's, tested in case order,
--   with the default body in the innermost else-branch.
-- * 'Apply' becomes an 'Imp.Call'.  Only arguments of primitive type
--   (passed as expressions) and memory blocks (passed by name) are
--   forwarded; every other argument is dropped.
-- * 'BasicOp' is handed to 'defCompileBasicOp'.
-- * 'DoLoop' declares the merge parameters and initialises every
--   merge parameter of rank 0.  It then compiles the loop and copies
--   the final parameter values into the pattern.
-- * 'WithAcc' records each accumulator parameter in 'stateAccs'.  It
--   then compiles the lambda body and copies the non-accumulator
--   results into the trailing pattern elements.
-- * 'Op' is handed to the operation compiler in the environment.
defCompileExp ::
  (Mem rep inner) =>
  Pat (LetDec rep) ->
  Exp rep ->
  ImpM rep r op ()
defCompileExp pat (Match ses cases defbody _) =
  foldr branch (compileBody pat defbody) cases
  where
    branch (Case vs body) orElse =
      sIf (caseMatch ses vs) (compileBody pat body) orElse
defCompileExp pat (Apply fname args _ _) = do
  targets <- funcallTargets =<< destinationFromPat pat
  args' <- catMaybes <$> traverse argument args
  emit $ Imp.Call targets fname args'
  where
    argument (se, _) = do
      se_t <- subExpType se
      pure $ case (se, se_t) of
        (_, Prim pt) -> Just $ Imp.ExpArg $ toExp' pt se
        (Var v, Mem {}) -> Just $ Imp.MemArg v
        _ -> Nothing
defCompileExp pat (BasicOp op) =
  defCompileBasicOp pat op
defCompileExp pat (DoLoop merge form body) = do
  attrs <- askAttrs
  when ("unroll" `inAttrs` attrs) $
    warn (noLoc :: SrcLoc) [] "#[unroll] on loop with unknown number of iterations."

  dFParams params
  -- Only rank-0 merge parameters (anything that is not an array) are
  -- initialised by a copy here.
  forM_ merge $ \(p, initial) ->
    when (arrayRank (paramType p) == 0) $
      copyDWIM (paramName p) [] initial []

  let loopBody = compileLoopBody params body

  case form of
    ForLoop i _ bound loopvars -> do
      -- On each iteration, a primitive-typed loop parameter receives
      -- element @i@ of its array; other loop parameters are left alone.
      let bindElem (p, arr) =
            case paramType p of
              Prim _ -> copyDWIM (paramName p) [] (Var arr) [DimFix $ Imp.le64 i]
              _ -> pure ()
      bound' <- toExp bound
      dLParams $ map fst loopvars
      sFor' i bound' $ do
        mapM_ bindElem loopvars
        loopBody
    WhileLoop cond ->
      sWhile (TPrimExp $ Imp.var cond Bool) loopBody

  dests <- destinationFromPat pat
  forM_ (zip dests params) $ \(dest, p) ->
    copyDWIMDest dest [] (Var $ paramName p) []
  where
    params = map fst merge
defCompileExp pat (WithAcc inputs lam) = do
  dLParams lam_params
  forM_ (zip inputs lam_params) $ \((_, arrs, combine), p) ->
    modify $ \s ->
      s {stateAccs = M.insert (paramName p) (arrs, combine) (stateAccs s)}
  compileStms mempty (bodyStms lam_body) $ do
    -- The first (length inputs) results belong to the accumulators;
    -- the rest are copied into the last pattern elements.
    let nonacc_res = drop (length inputs) (bodyResult lam_body)
        nonacc_pat_names = takeLast (length nonacc_res) (patNames pat)
    forM_ (zip nonacc_pat_names nonacc_res) $ \(v, SubExpRes _ se) ->
      copyDWIM v [] se []
  where
    lam_params = lambdaParams lam
    lam_body = lambdaBody lam
defCompileExp pat (Op op) =
  asks envOpCompiler >>= \compileOp -> compileOp pat op
-- | Emit a single trace print.  It consists of the label @s@ followed
-- by @": "@, the scalar @se@ rendered at primitive type @t@, and a
-- trailing newline.
tracePrim :: T.Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim s t se =
  emit $ Imp.TracePrint $ ErrorMsg parts
  where
    label = s <> ": "
    parts = [ErrorString label, ErrorVal t (toExp' t se), ErrorString "\n"]
-- | Emit trace prints for the array @se@, whose element type is @t@ and
-- whose shape is @shape@.
--
-- First the label @s@ and @": "@ are printed.  Then every element is
-- visited with 'sLoopNest'.  Each element is copied into a freshly
-- declared scalar and printed followed by a space.  A newline comes
-- last.
traceArray :: T.Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray s t shape se = do
  tracePrint [ErrorString (s <> ": ")]
  sLoopNest shape $ \is -> do
    elem_var <- dPrim "arr_elem" t
    copyDWIMFix (tvVar elem_var) [] se is
    tracePrint [ErrorVal t (untyped (tvExp elem_var)), " "]
  tracePrint ["\n"]
  where
    tracePrint = emit . Imp.TracePrint . ErrorMsg
defCompileBasicOp ::
Mem rep inner =>
Pat (LetDec rep) ->
BasicOp ->
ImpM rep r op ()
defCompileBasicOp :: forall {k} (rep :: k) inner r op.
Mem rep inner =>
Pat (LetDec rep) -> BasicOp -> ImpM rep r op ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (SubExp SubExp
se) =
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Opaque OpaqueOp
op SubExp
se) = do
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] SubExp
se []
case OpaqueOp
op of
OpaqueOp
OpaqueNil -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
OpaqueTrace Text
s -> forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
comment (Text
"Trace: " forall a. Semigroup a => a -> a -> a
<> Text
s) forall a b. (a -> b) -> a -> b
$ do
Type
se_t <- forall {k} (t :: k) (m :: * -> *). HasScope t m => SubExp -> m Type
subExpType SubExp
se
case Type
se_t of
Prim PrimType
t -> forall {k} (rep :: k) r op.
Text -> PrimType -> SubExp -> ImpM rep r op ()
tracePrim Text
s PrimType
t SubExp
se
Array PrimType
t Shape
shape NoUniqueness
_ -> forall {k} (rep :: k) r op.
Text -> PrimType -> Shape -> SubExp -> ImpM rep r op ()
traceArray Text
s PrimType
t Shape
shape SubExp
se
Type
_ ->
forall {k} loc (rep :: k) r op.
Located loc =>
loc -> [loc] -> Text -> ImpM rep r op ()
warn [forall a. Monoid a => a
mempty :: SrcLoc] forall a. Monoid a => a
mempty forall a b. (a -> b) -> a -> b
$
Text
s forall a. Semigroup a => a -> a -> a
<> Text
": cannot trace value of this (core) type: " forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText Type
se_t
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (UnOp UnOp
op SubExp
e) = do
Exp
e' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
op Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (ConvOp ConvOp
conv SubExp
e) = do
Exp
e' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. ConvOp -> PrimExp v -> PrimExp v
Imp.ConvOpExp ConvOp
conv Exp
e'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (BinOp BinOp
bop SubExp
x SubExp
y) = do
Exp
x' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
Exp
y' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
bop Exp
x' Exp
y'
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (CmpOp CmpOp
bop SubExp
x SubExp
y) = do
Exp
x' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
x
Exp
y' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
y
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe forall {k} (rep :: k) r op. VName -> Exp -> ImpM rep r op ()
<~~ forall v. CmpOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.CmpOpExp CmpOp
bop Exp
x' Exp
y'
defCompileBasicOp Pat (LetDec rep)
_ (Assert SubExp
e ErrorMsg SubExp
msg (SrcLoc, [SrcLoc])
loc) = do
Exp
e' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
ErrorMsg Exp
msg' <- forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp ErrorMsg SubExp
msg
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. Exp -> ErrorMsg Exp -> (SrcLoc, [SrcLoc]) -> Code a
Imp.Assert Exp
e' ErrorMsg Exp
msg' (SrcLoc, [SrcLoc])
loc
Attrs
attrs <- forall {k} (rep :: k) r op. ImpM rep r op Attrs
askAttrs
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Name -> [Attr] -> Attr
AttrComp Name
"warn" [Attr
"safety_checks"] Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs) forall a b. (a -> b) -> a -> b
$
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry forall {k} loc (rep :: k) r op.
Located loc =>
loc -> [loc] -> Text -> ImpM rep r op ()
warn (SrcLoc, [SrcLoc])
loc Text
"Safety check required at run-time."
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Index VName
src Slice SubExp
slice)
| Just [SubExp]
idxs <- forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
slice =
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64) [SubExp]
idxs
defCompileBasicOp Pat (LetDec rep)
_ Index {} =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Update Safety
safety VName
_ Slice SubExp
slice SubExp
se) =
case Safety
safety of
Safety
Unsafe -> ImpM rep r op ()
write
Safety
Safe -> forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds Slice (TExp Int64)
slice' [TExp Int64]
dims) ImpM rep r op ()
write
where
slice' :: Slice (TExp Int64)
slice' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 Slice SubExp
slice
dims :: [TExp Int64]
dims = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> [SubExp]
arrayDims forall a b. (a -> b) -> a -> b
$ forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe
write :: ImpM rep r op ()
write = forall {k} (rep :: k) r op.
VName -> Slice (TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) Slice (TExp Int64)
slice' SubExp
se
defCompileBasicOp Pat (LetDec rep)
_ FlatIndex {} =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (FlatUpdate VName
_ FlatSlice SubExp
slice VName
v) = do
MemLoc
pe_loc <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op ArrayEntry
lookupArray (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe)
MemLoc
v_loc <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op ArrayEntry
lookupArray VName
v
forall {k} (rep :: k) r op. CopyCompiler rep r op
copy (forall shape u. TypeBase shape u -> PrimType
elemType (forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe)) (MemLoc -> FlatSlice (TExp Int64) -> MemLoc
flatSliceMemLoc MemLoc
pe_loc FlatSlice (TExp Int64)
slice') MemLoc
v_loc
where
slice' :: FlatSlice (TExp Int64)
slice' = forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TExp Int64
pe64 FlatSlice SubExp
slice
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Replicate (Shape [SubExp]
ds) SubExp
se)
| Acc {} <- forall dec. Typed dec => PatElem dec -> Type
patElemType PatElem (LetDec rep)
pe = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
| Bool
otherwise = do
[Exp]
ds' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp [SubExp]
ds
[VName]
is <- forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) (forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"i")
Code op
copy_elem <- forall {k} (rep :: k) r op.
ImpM rep r op () -> ImpM rep r op (Code op)
collect forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) (forall a b. (a -> b) -> [a] -> [b]
map (forall d. d -> DimIndex d
DimFix forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. a -> TPrimExp Int64 a
Imp.le64) [VName]
is) SubExp
se []
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl forall b c a. (b -> c) -> (a -> b) -> a -> c
(.) forall a. a -> a
id (forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith forall a. VName -> Exp -> Code a -> Code a
Imp.For [VName]
is [Exp]
ds') Code op
copy_elem
defCompileBasicOp Pat (LetDec rep)
_ Scratch {} =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Iota SubExp
n SubExp
e SubExp
s IntType
it) = do
Exp
e' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
e
Exp
s' <- forall a {k} (rep :: k) r op. ToExp a => a -> ImpM rep r op Exp
toExp SubExp
s
forall {k} {k} (t :: k) (rep :: k) r op.
[Char]
-> TExp t -> (TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor [Char]
"i" (SubExp -> TExp Int64
pe64 SubExp
n) forall a b. (a -> b) -> a -> b
$ \TExp Int64
i -> do
let i' :: Exp
i' = forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
it forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
i
TV Any
x <-
forall {k} {k} (t :: k) (rep :: k) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"x" forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (t :: k) v. PrimExp v -> TPrimExp t v
TPrimExp forall a b. (a -> b) -> a -> b
$
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) Exp
e' forall a b. (a -> b) -> a -> b
$
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) Exp
i' Exp
s'
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [forall d. d -> DimIndex d
DimFix TExp Int64
i] (VName -> SubExp
Var (forall {k} (t :: k). TV t -> VName
tvVar TV Any
x)) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Copy VName
src) =
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Manifest [Int]
_ VName
src) =
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [] (VName -> SubExp
Var VName
src) []
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Concat Int
i (VName
x :| [VName]
ys) SubExp
_) = do
TV Int64
offs_glb <- forall {k} {k} (t :: k) (rep :: k) r op.
[Char] -> TExp t -> ImpM rep r op (TV t)
dPrimV [Char]
"tmp_offs" TExp Int64
0
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (VName
x forall a. a -> [a] -> [a]
: [VName]
ys) forall a b. (a -> b) -> a -> b
$ \VName
y -> do
[SubExp]
y_dims <- forall u. TypeBase Shape u -> [SubExp]
arrayDims forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
y
let rows :: TExp Int64
rows = case forall a. Int -> [a] -> [a]
drop Int
i [SubExp]
y_dims of
[] -> forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"defCompileBasicOp Concat: empty array shape for " forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString VName
y
SubExp
r : [SubExp]
_ -> SubExp -> TExp Int64
pe64 SubExp
r
skip_dims :: [SubExp]
skip_dims = forall a. Int -> [a] -> [a]
take Int
i [SubExp]
y_dims
sliceAllDim :: d -> DimIndex d
sliceAllDim d
d = forall d. d -> d -> d -> DimIndex d
DimSlice d
0 d
d d
1
skip_slices :: [DimIndex (TExp Int64)]
skip_slices = forall a b. (a -> b) -> [a] -> [b]
map (forall {d}. Num d => d -> DimIndex d
sliceAllDim forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> TExp Int64
pe64) [SubExp]
skip_dims
destslice :: [DimIndex (TExp Int64)]
destslice = [DimIndex (TExp Int64)]
skip_slices forall a. [a] -> [a] -> [a]
++ [forall d. d -> d -> d -> DimIndex d
DimSlice (forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offs_glb) TExp Int64
rows TExp Int64
1]
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [DimIndex (TExp Int64)]
destslice (VName -> SubExp
Var VName
y) []
TV Int64
offs_glb forall {k} {k} (t :: k) (rep :: k) r op.
TV t -> TExp t -> ImpM rep r op ()
<-- forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
offs_glb forall a. Num a => a -> a -> a
+ TExp Int64
rows
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (ArrayLit [SubExp]
es Type
_)
| Just vs :: [PrimValue]
vs@(PrimValue
v : [PrimValue]
_) <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe PrimValue
isLiteral [SubExp]
es = do
MemLoc
dest_mem <- ArrayEntry -> MemLoc
entryArrayLoc forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op ArrayEntry
lookupArray (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe)
Space
dest_space <- MemEntry -> Space
entryMemSpace forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) r op. VName -> ImpM rep r op MemEntry
lookupMemory (MemLoc -> VName
memLocName MemLoc
dest_mem)
let t :: PrimType
t = PrimValue -> PrimType
primValueType PrimValue
v
VName
static_array <- forall {k} (rep :: k) r op. [Char] -> ImpM rep r op VName
newVNameForFun [Char]
"static_array"
forall {k} op (rep :: k) r. Code op -> ImpM rep r op ()
emit forall a b. (a -> b) -> a -> b
$ forall a. VName -> Space -> PrimType -> ArrayContents -> Code a
Imp.DeclareArray VName
static_array Space
dest_space PrimType
t forall a b. (a -> b) -> a -> b
$ [PrimValue] -> ArrayContents
Imp.ArrayValues [PrimValue]
vs
let static_src :: MemLoc
static_src =
VName -> [SubExp] -> IxFun (TExp Int64) -> MemLoc
MemLoc VName
static_array [IntType -> Integer -> SubExp
intConst IntType
Int64 forall a b. (a -> b) -> a -> b
$ forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es] forall a b. (a -> b) -> a -> b
$
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota [forall a b. (Integral a, Num b) => a -> b
fromIntegral forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
es]
entry :: VarEntry rep
entry = forall {k} (rep :: k). Maybe (Exp rep) -> MemEntry -> VarEntry rep
MemVar forall a. Maybe a
Nothing forall a b. (a -> b) -> a -> b
$ Space -> MemEntry
MemEntry Space
dest_space
forall {k} (rep :: k) r op.
VName -> VarEntry rep -> ImpM rep r op ()
addVar VName
static_array VarEntry rep
entry
forall {k} (rep :: k) r op. CopyCompiler rep r op
copy PrimType
t MemLoc
dest_mem MemLoc
static_src
| Bool
otherwise =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [Integer
0 ..] [SubExp]
es) forall a b. (a -> b) -> a -> b
$ \(Integer
i, SubExp
e) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [forall d. d -> DimIndex d
DimFix forall a b. (a -> b) -> a -> b
$ forall a. Num a => Integer -> a
fromInteger Integer
i] SubExp
e []
where
isLiteral :: SubExp -> Maybe PrimValue
isLiteral (Constant PrimValue
v) = forall a. a -> Maybe a
Just PrimValue
v
isLiteral SubExp
_ = forall a. Maybe a
Nothing
defCompileBasicOp Pat (LetDec rep)
_ Rearrange {} =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp (Pat [PatElem (LetDec rep)
pe]) (Rotate [SubExp]
rs VName
arr) = do
Shape
shape <- forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
arr
forall {k} (rep :: k) r op.
Shape -> ([TExp Int64] -> ImpM rep r op ()) -> ImpM rep r op ()
sLoopNest Shape
shape forall a b. (a -> b) -> a -> b
$ \[TExp Int64]
is -> do
[TExp Int64]
is' <- forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence forall a b. (a -> b) -> a -> b
$ forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 forall {k} {rep :: k} {r} {op}.
SubExp -> SubExp -> TExp Int64 -> ImpM rep r op (TExp Int64)
rotate (forall d. ShapeBase d -> [d]
shapeDims Shape
shape) [SubExp]
rs [TExp Int64]
is
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix (forall dec. PatElem dec -> VName
patElemName PatElem (LetDec rep)
pe) [TExp Int64]
is (VName -> SubExp
Var VName
arr) [TExp Int64]
is'
where
rotate :: SubExp -> SubExp -> TExp Int64 -> ImpM rep r op (TExp Int64)
rotate SubExp
d SubExp
r TExp Int64
i = forall {k} {k} (t :: k) (rep :: k) r op.
[Char] -> TExp t -> ImpM rep r op (TExp t)
dPrimVE [Char]
"rot_i" forall a b. (a -> b) -> a -> b
$ TExp Int64 -> TExp Int64 -> TExp Int64 -> TExp Int64
rotateIndex (SubExp -> TExp Int64
pe64 SubExp
d) (SubExp -> TExp Int64
pe64 SubExp
r) TExp Int64
i
defCompileBasicOp Pat (LetDec rep)
_ Reshape {} =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
defCompileBasicOp Pat (LetDec rep)
_ (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs) = forall {k} (rep :: k) r op.
Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"UpdateAcc" forall a b. (a -> b) -> a -> b
$ do
let is' :: [TExp Int64]
is' = forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
pe64 [SubExp]
is
(VName
_, Space
_, [VName]
arrs, [TExp Int64]
dims, Maybe (Lambda rep)
op) <- forall {k} (rep :: k) r op.
VName
-> [TExp Int64]
-> ImpM
rep r op (VName, Space, [VName], [TExp Int64], Maybe (Lambda rep))
lookupAcc VName
acc [TExp Int64]
is'
forall {k} (rep :: k) r op.
TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TExp Int64) -> [TExp Int64] -> TExp Bool
inBounds (forall d. [DimIndex d] -> Slice d
Slice (forall a b. (a -> b) -> [a] -> [b]
map forall d. d -> DimIndex d
DimFix [TExp Int64]
is')) [TExp Int64]
dims) forall a b. (a -> b) -> a -> b
$
case Maybe (Lambda rep)
op of
Maybe (Lambda rep)
Nothing ->
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExp
v) -> forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64]
is' SubExp
v []
Just Lambda rep
lam -> do
forall {k} (rep :: k) inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
let ([VName]
x_params, [VName]
y_params) =
forall a. Int -> [a] -> ([a], [a])
splitAt (forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall dec. Param dec -> VName
paramName forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> [LParam rep]
lambdaParams Lambda rep
lam
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
x_params [VName]
arrs) forall a b. (a -> b) -> a -> b
$ \(VName
xp, VName
arr) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
xp [] (VName -> SubExp
Var VName
arr) [TExp Int64]
is'
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
y_params [SubExp]
vs) forall a b. (a -> b) -> a -> b
$ \(VName
yp, SubExp
v) ->
forall {k} (rep :: k) r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM rep r op ()
copyDWIM VName
yp [] SubExp
v []
forall {k} (rep :: k) r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms forall a. Monoid a => a
mempty (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam) forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs (forall {k} (rep :: k). Body rep -> Result
bodyResult (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda rep
lam))) forall a b. (a -> b) -> a -> b
$ \(VName
arr, SubExpRes Certs
_ SubExp
se) ->
forall {k} (rep :: k) r op.
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int64]
is' SubExp
se []
defCompileBasicOp Pat (LetDec rep)
pat BasicOp
e =
forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$
[Char]
"ImpGen.defCompileBasicOp: Invalid pattern\n "
forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString Pat (LetDec rep)
pat
forall a. [a] -> [a] -> [a]
++ [Char]
"\nfor expression\n "
forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString BasicOp
e
-- | Register each array declaration in the symbol table as an
-- 'ArrayVar' with no defining expression.  No code is emitted.
addArrays :: [ArrayDecl] -> ImpM rep r op ()
addArrays decls =
  forM_ decls $ \(ArrayDecl name bt location) ->
    addVar name $ ArrayVar Nothing $ ArrayEntry location bt
-- | Register function parameters in the symbol table without emitting
-- any declarations for them (unlike 'dFParams').  Uniqueness
-- information is stripped before the entry is built.
addFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
addFParams fparams =
  forM_ fparams $ \fparam ->
    addVar (paramName fparam)
      . memBoundToVarEntry Nothing
      . noUniquenessReturns
      $ paramDec fparam
-- | Register an integer loop variable of the given integer type in the
-- symbol table.  No declaration is emitted.
addLoopVar :: VName -> IntType -> ImpM rep r op ()
addLoopVar i it =
  let entry = ScalarEntry (IntType it)
   in addVar i (ScalarVar Nothing entry)
-- | Declare every pattern element in turn, associating each with the
-- (optional) expression that defines it.
dVars ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  [PatElem (LetDec rep)] ->
  ImpM rep r op ()
dVars e pes =
  forM_ pes $ \pe -> dScope e (scopeOfPatElem pe)
-- | Declare function parameters through 'dScope' with no defining
-- expression.  Declarations are emitted for scalar and memory
-- parameters.
dFParams :: Mem rep inner => [FParam rep] -> ImpM rep r op ()
dFParams params = dScope Nothing (scopeOfFParams params)
-- | Declare lambda parameters through 'dScope' with no defining
-- expression.  Declarations are emitted for scalar and memory
-- parameters.
dLParams :: Mem rep inner => [LParam rep] -> ImpM rep r op ()
dLParams params = dScope Nothing (scopeOfLParams params)
-- | Declare a fresh /volatile/ scalar of type @t@ (named after @name@),
-- register it, assign @e@ to it, and return it as a typed variable.
dPrimVol :: String -> PrimType -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimVol name t e = do
  v <- newVName name
  emit (Imp.DeclareScalar v Imp.Volatile t)
  addVar v (ScalarVar Nothing (ScalarEntry t))
  v <~~ untyped e
  pure (TV v t)
-- | Emit a non-volatile scalar declaration for an existing name and
-- register it in the symbol table.
dPrim_ :: VName -> PrimType -> ImpM rep r op ()
dPrim_ name t =
  emit (Imp.DeclareScalar name Imp.Nonvolatile t)
    >> addVar name (ScalarVar Nothing (ScalarEntry t))
-- | Declare a fresh non-volatile scalar of a dynamically provided type.
-- The static type @t@ of the result is chosen by the caller and is
-- not checked against the 'PrimType'.
dPrim :: String -> PrimType -> ImpM rep r op (TV t)
dPrim name t = do
  v <- newVName name
  dPrim_ v t
  pure (TV v t)
-- | Declare an existing name as a scalar whose type is taken from @e@,
-- then assign @e@ to it.
dPrimV_ :: VName -> Imp.TExp t -> ImpM rep r op ()
dPrimV_ name e = do
  let t = primExpType (untyped e)
  dPrim_ name t
  TV name t <-- e
-- | Declare a fresh scalar whose type is taken from @e@, assign @e@ to
-- it, and return the typed variable.
dPrimV :: String -> Imp.TExp t -> ImpM rep r op (TV t)
dPrimV name e = do
  v <- dPrim name (primExpType (untyped e))
  v <-- e
  pure v
-- | Like 'dPrimV', but returns the new variable as a typed expression
-- rather than a 'TV'.
dPrimVE :: String -> Imp.TExp t -> ImpM rep r op (Imp.TExp t)
dPrimVE name e = tvExp <$> dPrimV name e
-- | Convert memory-annotated type information into a symbol-table
-- entry, attaching the (optional) defining expression.
memBoundToVarEntry ::
  Maybe (Exp rep) ->
  MemBound NoUniqueness ->
  VarEntry rep
memBoundToVarEntry e info =
  case info of
    MemPrim bt ->
      ScalarVar e (ScalarEntry bt)
    MemMem space ->
      MemVar e (MemEntry space)
    MemAcc acc ispace ts _ ->
      AccVar e (acc, ispace, ts)
    MemArray bt shape _ (ArrayIn mem ixfun) ->
      ArrayVar e (ArrayEntry (MemLoc mem (shapeDims shape) ixfun) bt)
-- | Extract the uniqueness-free memory information from any kind of
-- name binding.  Index names become integer scalars of their type.
infoDec ::
  Mem rep inner =>
  NameInfo rep ->
  MemInfo SubExp NoUniqueness MemBind
infoDec = \case
  LetName dec -> letDecMem dec
  FParamName dec -> noUniquenessReturns dec
  LParamName dec -> dec
  IndexName it -> MemPrim (IntType it)
-- | Declare a single name.  A declaration is emitted for memory blocks
-- and (non-volatile) scalars.  Arrays and accumulators are only added
-- to the symbol table.
dInfo ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  VName ->
  NameInfo rep ->
  ImpM rep r op ()
dInfo e name info = do
  declare entry
  addVar name entry
  where
    entry = memBoundToVarEntry e (infoDec info)

    declare (MemVar _ mem) =
      emit $ Imp.DeclareMem name (entryMemSpace mem)
    declare (ScalarVar _ scalar) =
      emit $ Imp.DeclareScalar name Imp.Nonvolatile (entryScalarType scalar)
    declare ArrayVar {} = pure ()
    declare AccVar {} = pure ()
-- | Declare every name in a scope (in ascending key order) with
-- 'dInfo', all sharing the same optional defining expression.
dScope ::
  Mem rep inner =>
  Maybe (Exp rep) ->
  Scope rep ->
  ImpM rep r op ()
dScope e scope =
  forM_ (M.toList scope) $ \(name, info) -> dInfo e name info
-- | Register an array of element type @pt@ and shape @shape@ that lives
-- in memory block @mem@ with index function @ixfun@.  No code is
-- emitted.
dArray :: VName -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op ()
dArray name pt shape mem ixfun =
  let loc = MemLoc mem (shapeDims shape) ixfun
      entry = ArrayEntry {entryArrayLoc = loc, entryArrayElemType = pt}
   in addVar name (ArrayVar Nothing entry)
-- | Run an action with 'envVolatility' set to 'Imp.Volatile'.  How that
-- setting is consumed is decided elsewhere (presumably by declaration
-- code; confirm at the use sites).
everythingVolatile :: ImpM rep r op a -> ImpM rep r op a
everythingVolatile = local setVolatile
  where
    setVolatile env = env {envVolatility = Imp.Volatile}
-- | The names a function call writes its results into.  Scalar and
-- memory destinations contribute their name; array destinations
-- contribute nothing.
funcallTargets :: [ValueDestination] -> ImpM rep r op [VName]
funcallTargets dests = pure $ concatMap target dests
  where
    target (ScalarDestination name) = [name]
    target (MemoryDestination name) = [name]
    target ArrayDestination {} = []
-- | A typed variable: a name paired with its dynamic primitive type.
-- The phantom parameter @t@ carries the expected static type, so the
-- variable can be turned into a typed expression ('tvExp') or used as
-- an assignment target.  Nothing checks that @t@ and the 'PrimType'
-- agree (see 'mkTV').
data TV t = TV VName PrimType
-- | Build a typed variable from a name and a dynamic type.  The static
-- type @t@ is chosen by the caller and is not checked against the
-- given 'PrimType'.
mkTV :: VName -> PrimType -> TV t
mkTV v t = TV v t
-- | Use a typed variable as a size (a variable 'SubExp').
tvSize :: TV t -> Imp.DimSize
tvSize tv = Var (tvVar tv)
-- | Turn a typed variable into an expression of the same static type.
tvExp :: TV t -> Imp.TExp t
tvExp tv = case tv of
  TV name pt -> Imp.TPrimExp $ Imp.var name pt
-- | The underlying variable name of a typed variable.
tvVar :: TV t -> VName
tvVar tv = case tv of
  TV name _ -> name
-- | Things that can be compiled to an 'Imp.Exp'.
class ToExp a where
  -- | Compile to an 'Imp.Exp'.  The primitive type may be determined
  -- monadically, e.g. by consulting the symbol table.
  toExp :: a -> ImpM rep r op Imp.Exp

  -- | Compile when the 'PrimType' is already known by the caller.
  toExp' :: PrimType -> a -> Imp.Exp
-- | Constants become value expressions.  Variables must be bound as
-- scalars in the symbol table for 'toExp'.  'toExp'' trusts the type
-- it is given.
instance ToExp SubExp where
  toExp se = case se of
    Constant v -> pure (Imp.ValueExp v)
    Var v -> do
      entry <- lookupVar v
      case entry of
        ScalarVar _ (ScalarEntry pt) -> pure (Imp.var v pt)
        _ -> error ("toExp SubExp: SubExp is not a primitive type: " ++ prettyString v)

  toExp' t se = case se of
    Constant v -> Imp.ValueExp v
    Var v -> Imp.var v t
-- | A primitive expression over names is already an 'Imp.Exp'.
instance ToExp (PrimExp VName) where
  toExp e = pure e
  toExp' _ e = e
-- | Insert (or overwrite) a symbol-table entry for a name.
addVar :: VName -> VarEntry rep -> ImpM rep r op ()
addVar name entry = modify insertEntry
  where
    insertEntry s = s {stateVTable = M.insert name entry $ stateVTable s}
-- | Run an action with a different default memory space.
localDefaultSpace :: Imp.Space -> ImpM rep r op a -> ImpM rep r op a
localDefaultSpace space = local withSpace
  where
    withSpace env = env {envDefaultSpace = space}
-- | The name of the function currently being compiled, if any.
askFunction :: ImpM rep r op (Maybe Name)
askFunction = envFunction <$> ask
-- | Generate a fresh 'VName'.  If a current function exists (see
-- 'askFunction'), its name followed by @.@ is used as a prefix.
newVNameForFun :: String -> ImpM rep r op VName
newVNameForFun s = do
  fname <- askFunction
  let prefix = case fname of
        Nothing -> ""
        Just f -> nameToString f ++ "."
  newVName (prefix ++ s)
-- | Build a 'Name'.  If a current function exists (see 'askFunction'),
-- its name followed by @.@ is used as a prefix.
nameForFun :: String -> ImpM rep r op Name
nameForFun s = withPrefix <$> askFunction
  where
    suffix = nameFromString s
    withPrefix Nothing = suffix
    withPrefix (Just f) = f <> "." <> suffix
-- | The backend-specific environment.
askEnv :: ImpM rep r op r
askEnv = envEnv <$> ask
-- | Run an action with a modified backend-specific environment.
localEnv :: (r -> r) -> ImpM rep r op a -> ImpM rep r op a
localEnv f = local modifyEnv
  where
    modifyEnv env = env {envEnv = f (envEnv env)}
-- | The attributes currently in effect.
askAttrs :: ImpM rep r op Attrs
askAttrs = envAttrs <$> ask
-- | Run an action with extra attributes combined (on the left) with
-- those already in effect.
localAttrs :: Attrs -> ImpM rep r op a -> ImpM rep r op a
localAttrs attrs = local addAttrs
  where
    addAttrs env = env {envAttrs = attrs <> envAttrs env}
-- | Run an action with all compiler hooks (expressions, statements,
-- copies, operations, allocations) replaced by those in @ops@.
localOps :: Operations rep r op -> ImpM rep r op a -> ImpM rep r op a
localOps ops = local useOps
  where
    useOps env =
      env
        { envOpCompiler = opsOpCompiler ops,
          envExpCompiler = opsExpCompiler ops,
          envStmsCompiler = opsStmsCompiler ops,
          envCopyCompiler = opsCopyCompiler ops,
          envAllocCompilers = opsAllocCompilers ops
        }
-- | The current symbol table.
getVTable :: ImpM rep r op (VTable rep)
getVTable = stateVTable <$> get
-- | Replace the current symbol table.
putVTable :: VTable rep -> ImpM rep r op ()
putVTable vtable = modify setVTable
  where
    setVTable s = s {stateVTable = vtable}
-- | Run an action with a modified symbol table.  Once the action
-- finishes, the symbol table as it was before the call is restored, so
-- any changes the action made to it are discarded.
localVTable :: (VTable rep -> VTable rep) -> ImpM rep r op a -> ImpM rep r op a
localVTable f m = do
  saved <- getVTable
  putVTable (f saved)
  m <* putVTable saved
-- | Look up a name in the symbol table.  Fails with 'error' if the name
-- is unbound.
lookupVar :: VName -> ImpM rep r op (VarEntry rep)
lookupVar name = do
  vtable <- getVTable
  case M.lookup name vtable of
    Nothing -> error $ "Unknown variable: " ++ prettyString name
    Just entry -> pure entry
-- | Look up the array entry for a name.  Fails with 'error' if the name
-- is not bound to an array.
lookupArray :: VName -> ImpM rep r op ArrayEntry
lookupArray name =
  lookupVar name >>= \case
    ArrayVar _ entry -> pure entry
    _ -> error $ "ImpGen.lookupArray: not an array: " ++ prettyString name
-- | Look up the memory-block entry for a name.
--
-- Fails with 'error' if the name is bound but is not a memory block.
-- Unbound names are already rejected by 'lookupVar'.
lookupMemory :: VName -> ImpM rep r op MemEntry
lookupMemory name = do
  res <- lookupVar name
  case res of
    MemVar _ entry -> pure entry
    -- 'lookupVar' has already failed for unbound names, so reaching this
    -- branch means the name is known but denotes something else.  Say so
    -- rather than claiming the block is unknown.
    _ -> error $ "ImpGen.lookupMemory: not a memory block: " ++ prettyString name
-- | The memory space of the block that backs the given array.
lookupArraySpace :: VName -> ImpM rep r op Space
lookupArraySpace name = do
  arr <- lookupArray name
  mem <- lookupMemory (memLocName (entryArrayLoc arr))
  pure (entryMemSpace mem)
-- | Look up an accumulator.  Returns, in order:
--
-- * its accumulator name,
-- * the memory space of its first array,
-- * its arrays,
-- * its index-space dimensions,
-- * its combining operator, if any.
--
-- When there is an operator, its first @length is@ lambda parameters
-- are declared and bound to @is@, and the returned lambda has those
-- parameters removed.  Fails with 'error' if the name is not an
-- accumulator, is not registered in 'stateAccs', or has no arrays.
lookupAcc ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Space, [VName], [Imp.TExp Int64], Maybe (Lambda rep))
lookupAcc name is =
  lookupVar name >>= \case
    AccVar _ (acc, ispace, _) ->
      gets (M.lookup acc . stateAccs) >>= \case
        Just ([], _) ->
          error $ "Accumulator with no arrays: " ++ prettyString name
        Just (arrs@(arr : _), op_info) -> do
          space <- lookupArraySpace arr
          op' <- traverse (bindIndexParams . fst) op_info
          pure (acc, space, arrs, map pe64 (shapeDims ispace), op')
        Nothing ->
          error $ "ImpGen.lookupAcc: unlisted accumulator: " ++ prettyString name
    _ -> error $ "ImpGen.lookupAcc: not an accumulator: " ++ prettyString name
  where
    -- Bind the leading (index) parameters of the operator to the given
    -- indices and drop them from the lambda that is handed back.
    bindIndexParams op = do
      let (i_params, ps) = splitAt (length is) $ lambdaParams op
      zipWithM_ dPrimV_ (map paramName i_params) is
      pure op {lambdaParams = ps}
-- | Compute a 'ValueDestination' for every element of a pattern,
-- depending on what kind of variable each element is bound to.
-- Array and accumulator variables get an 'ArrayDestination' with no
-- location.
destinationFromPat :: Pat (LetDec rep) -> ImpM rep r op [ValueDestination]
destinationFromPat = mapM inspect . patElems
  where
    inspect pe = do
      let name = patElemName pe
      entry <- lookupVar name
      pure $ case entry of
        ArrayVar _ (ArrayEntry MemLoc {} _) -> ArrayDestination Nothing
        MemVar {} -> MemoryDestination name
        ScalarVar {} -> ScalarDestination name
        AccVar {} -> ArrayDestination Nothing
-- | Look up the named array and index it with 'fullyIndexArray''.
-- Returns the memory block, its space, and the element offset of the
-- indexed element.  The index list is presumably expected to cover
-- every dimension; that is not checked here.
fullyIndexArray ::
  VName ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray name indices = do
  arr <- lookupArray name
  fullyIndexArray' (entryArrayLoc arr) indices
-- | Index a memory location.  Returns the memory block, the block's
-- space (looked up from its 'MemEntry'), and the element offset
-- computed by applying the location's index function to the indexes.
fullyIndexArray' ::
  MemLoc ->
  [Imp.TExp Int64] ->
  ImpM rep r op (VName, Imp.Space, Count Elements (Imp.TExp Int64))
fullyIndexArray' (MemLoc mem _ ixfun) indices = do
  space <- entryMemSpace <$> lookupMemory mem
  pure (mem, space, elements $ IxFun.index ixfun indices)
-- | Copy elements of the given type from the source location to the
-- destination location.
--
-- Nothing is emitted when both locations use the same memory and
-- their index functions are statically 'IxFun.equivalent'.  Otherwise
-- the copy is wrapped in a runtime check that skips it when both
-- sides use the same memory, each has a single LMAD, and those LMADs
-- are dynamically equal.  The copy itself is delegated to the
-- environment's 'envCopyCompiler'.
copy :: CopyCompiler rep r op
copy
  bt
  dst@(MemLoc dst_name _ dst_ixfn@(IxFun.IxFun dst_lmads@(dst_lmad :| _) _ _))
  src@(MemLoc src_name _ src_ixfn@(IxFun.IxFun src_lmads@(src_lmad :| _) _ _)) =
    -- Statically identical: nothing to do.
    unless (dst_name == src_name && dst_ixfn `IxFun.equivalent` src_ixfn) $
      -- Possibly identical at runtime: guard the copy with a check.
      sUnless
        ( fromBool (dst_name == src_name && length dst_lmads == 1 && length src_lmads == 1)
            .&&. IxFun.dynamicEqualsLMAD dst_lmad src_lmad
        )
        $ do
          cc <- asks envCopyCompiler
          cc bt dst src
-- | Recognise a copy that amounts to a batch of 2D transpositions.
--
-- One side must have a permuted layout (per 'IxFun.rearrangeWithOffset')
-- whose permutation is accepted by 'isMapTranspose'.  The other side
-- must be linear (per 'IxFun.linearWithOffset').  On success, returns
-- @(dest_offset, src_offset, num_arrays, size_x, size_y)@.
isMapTransposeCopy ::
  PrimType ->
  MemLoc ->
  MemLoc ->
  Maybe
    ( Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64,
      Imp.TExp Int64
    )
isMapTransposeCopy bt (MemLoc _ _ destIxFun) (MemLoc _ _ srcIxFun)
  | Just (dest_offset, perm_and_destshape) <-
      IxFun.rearrangeWithOffset destIxFun bt_size,
    (perm, destshape) <- unzip perm_and_destshape,
    Just src_offset <- IxFun.linearWithOffset srcIxFun bt_size,
    Just (r1, r2, _) <- isMapTranspose perm =
      isOk destshape swap r1 r2 dest_offset src_offset
  | Just dest_offset <- IxFun.linearWithOffset destIxFun bt_size,
    Just (src_offset, perm_and_srcshape) <-
      IxFun.rearrangeWithOffset srcIxFun bt_size,
    (perm, srcshape) <- unzip perm_and_srcshape,
    Just (r1, r2, _) <- isMapTranspose perm =
      isOk srcshape id r1 r2 dest_offset src_offset
  | otherwise =
      Nothing
  where
    bt_size = primByteSize bt
    swap (x, y) = (y, x)

    -- The first r1 dimensions are the mapped (batch) dimensions; the
    -- rest are split at r2 into two groups, which f may swap.  Each
    -- group is collapsed into a single size by taking its product.
    isOk shape f r1 r2 dest_offset src_offset = do
      let (mapped, notmapped) = splitAt r1 shape
          (pretrans, posttrans) = f $ splitAt r2 notmapped
      pure
        ( dest_offset,
          src_offset,
          product mapped,
          product pretrans,
          product posttrans
        )
-- | The base name of the map-transpose function for an element type,
-- e.g. @map_transpose_@ followed by the pretty-printed type.
mapTransposeName :: PrimType -> String
mapTransposeName bt = "map_transpose_" ++ prettyString bt
-- | Return the name of the builtin map-transpose function for the
-- given element type.  The function definition, built by
-- 'mapTransposeFunction', is emitted only if no function with that
-- name exists yet.
mapTransposeForType :: PrimType -> ImpM rep r op Name
mapTransposeForType bt = do
  let fname = nameFromString $ "builtin#" <> mapTransposeName bt
  exists <- hasFunction fname
  unless exists $ emitFunction fname $ mapTransposeFunction fname bt
  pure fname
-- | The default 'CopyCompiler'.  Strategies, tried in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy, call the builtin
--    map-transpose function for the element type.
--
-- 2. If both index functions are linear ('IxFun.linearWithOffset'),
--    emit one bulk 'sCopy' of all elements.  If either memory block
--    is in a 'ScalarSpace', copy element-wise instead.
--
-- 3. Otherwise, copy element-wise ('copyElementWise').
defaultCopy :: CopyCompiler rep r op
defaultCopy pt dest src
  | Just (destoffset, srcoffset, num_arrays, size_x, size_y) <-
      isMapTransposeCopy pt dest src = do
      fname <- mapTransposeForType pt
      emit . Imp.Call [] fname $
        transposeArgs
          pt
          destmem
          (bytes destoffset)
          srcmem
          (bytes srcoffset)
          num_arrays
          size_x
          size_y
  | Just destoffset <- IxFun.linearWithOffset dest_ixfun pt_size,
    Just srcoffset <- IxFun.linearWithOffset src_ixfun pt_size = do
      srcspace <- entryMemSpace <$> lookupMemory srcmem
      destspace <- entryMemSpace <$> lookupMemory destmem
      if isScalarSpace srcspace || isScalarSpace destspace
        then copyElementWise pt dest src
        else
          sCopy
            destmem
            destoffset
            destspace
            srcmem
            srcoffset
            srcspace
            num_elems
            pt
  | otherwise =
      copyElementWise pt dest src
  where
    pt_size = primByteSize pt
    -- Total number of elements, taken from the source's shape.
    num_elems = Imp.elements $ product $ IxFun.shape $ memLocIxFun src

    MemLoc destmem _ dest_ixfun = dest
    MemLoc srcmem _ src_ixfun = src

    isScalarSpace ScalarSpace {} = True
    isScalarSpace _ = False
-- | Copy element by element.  Emits a loop nest with one loop per
-- dimension of the source's index-function shape.  The innermost body
-- reads one element from the source into a fresh scalar, then writes
-- it to the destination.  Uses the environment's volatility.
copyElementWise :: CopyCompiler rep r op
copyElementWise bt dest src = do
  let bounds = IxFun.shape $ memLocIxFun src
  is <- replicateM (length bounds) (newVName "i")
  let ivars = map Imp.le64 is
  (destmem, destspace, destidx) <- fullyIndexArray' dest ivars
  (srcmem, srcspace, srcidx) <- fullyIndexArray' src ivars
  vol <- asks envVolatility
  tmp <- newVName "tmp"
  let body =
        mconcat
          [ Imp.DeclareScalar tmp vol bt,
            Imp.Read tmp srcmem srcidx bt srcspace vol,
            Imp.Write destmem destidx bt destspace vol $ Imp.var tmp bt
          ]
      -- One loop per dimension; the first index is the outermost loop.
      loops = zipWith Imp.For is $ map untyped bounds
  emit $ foldr ($) body loops
-- | Generate (but do not emit) code copying between two possibly
-- sliced array locations.
--
-- When both slices fix every dimension, the result is a single
-- element read into a temporary, followed by a write.  Otherwise both
-- slices are completed against their shapes and applied to the
-- locations.  The result is then empty if the sliced locations are
-- identical, and a 'copy' otherwise.  Calls 'error' if the sliced
-- ranks differ.
copyArrayDWIM ::
  PrimType ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  MemLoc ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op (Imp.Code op)
copyArrayDWIM
  bt
  destlocation@(MemLoc _ destshape _)
  destslice
  srclocation@(MemLoc _ srcshape _)
  srcslice
    | Just destis <- mapM dimFix destslice,
      Just srcis <- mapM dimFix srcslice,
      length srcis == length srcshape,
      length destis == length destshape = do
        -- Single-element copy.
        (targetmem, destspace, targetoffset) <-
          fullyIndexArray' destlocation destis
        (srcmem, srcspace, srcoffset) <-
          fullyIndexArray' srclocation srcis
        vol <- asks envVolatility
        collect $ do
          tmp <- tvVar <$> dPrim "tmp" bt
          emit $ Imp.Read tmp srcmem srcoffset bt srcspace vol
          emit $ Imp.Write targetmem targetoffset bt destspace vol $ Imp.var tmp bt
    | otherwise = do
        let destslice' = fullSliceNum (map pe64 destshape) destslice
            srcslice' = fullSliceNum (map pe64 srcshape) srcslice
            destrank = length $ sliceDims destslice'
            srcrank = length $ sliceDims srcslice'
            destlocation' = sliceMemLoc destlocation destslice'
            srclocation' = sliceMemLoc srclocation srcslice'
        if destrank /= srcrank
          then
            error $
              "copyArrayDWIM: cannot copy to "
                ++ prettyString (memLocName destlocation)
                ++ " from "
                ++ prettyString (memLocName srclocation)
                ++ " because ranks do not match ("
                ++ prettyString destrank
                ++ " vs "
                ++ prettyString srcrank
                ++ ")"
          else
            if destlocation' == srclocation'
              then pure mempty -- The copy would be a no-op.
              else collect $ copy bt destlocation' srclocation'
-- | Copy a value, possibly sliced, into a destination, possibly
-- sliced.  The code emitted depends on the kinds of destination and
-- source:
--
-- * constant source: a scalar assignment or a single element write;
--
-- * memory to memory: a 'Imp.SetMem';
--
-- * scalar to scalar: a 'Imp.SetScalar';
--
-- * fully indexed array to scalar: a single element read;
--
-- * scalar to fully indexed array: a single element write;
--
-- * array to array: the code generated by 'copyArrayDWIM'.
--
-- Nothing is emitted for an 'ArrayDestination' without a location,
-- or for an accumulator source.  Every other combination calls
-- 'error'.
copyDWIMDest ::
  ValueDestination ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIMDest _ _ (Constant v) (_ : _) =
  error $
    unwords ["copyDWIMDest: constant source", prettyString v, "cannot be indexed."]
copyDWIMDest pat dest_slice (Constant v) [] =
  case mapM dimFix dest_slice of
    Nothing ->
      error $
        unwords ["copyDWIMDest: constant source", prettyString v, "with slice destination."]
    Just dest_is ->
      case pat of
        ScalarDestination name ->
          emit $ Imp.SetScalar name $ Imp.ValueExp v
        MemoryDestination {} ->
          error $
            unwords ["copyDWIMDest: constant source", prettyString v, "cannot be written to memory destination."]
        ArrayDestination (Just dest_loc)
          -- Like the scalar-variable case below, require one index per
          -- dimension; a shorter list would address only part of an
          -- element's offset.
          | length dest_is == length (memLocShape dest_loc) -> do
              (dest_mem, dest_space, dest_i) <-
                fullyIndexArray' dest_loc dest_is
              vol <- asks envVolatility
              emit $ Imp.Write dest_mem dest_i bt dest_space vol $ Imp.ValueExp v
          | otherwise ->
              error $
                unwords
                  [ "copyDWIMDest: constant source",
                    prettyString v,
                    "with partially indexed array destination."
                  ]
        ArrayDestination Nothing ->
          error "copyDWIMDest: ArrayDestination Nothing"
  where
    bt = primValueType v
copyDWIMDest dest dest_slice (Var src) src_slice = do
  src_entry <- lookupVar src
  case (dest, src_entry) of
    (MemoryDestination mem, MemVar _ (MemEntry space)) ->
      emit $ Imp.SetMem mem src space
    (MemoryDestination {}, _) ->
      error $
        unwords ["copyDWIMDest: cannot write", prettyString src, "to memory destination."]
    (_, MemVar {}) ->
      error $
        unwords ["copyDWIMDest: source", prettyString src, "is a memory block."]
    (_, ScalarVar _ (ScalarEntry _))
      | not $ null src_slice ->
          error $
            unwords ["copyDWIMDest: prim-typed source", prettyString src, "with slice", prettyString src_slice]
    (ScalarDestination name, _)
      | not $ null dest_slice ->
          error $
            unwords ["copyDWIMDest: prim-typed target", prettyString name, "with slice", prettyString dest_slice]
    (ScalarDestination name, ScalarVar _ (ScalarEntry pt)) ->
      emit $ Imp.SetScalar name $ Imp.var src pt
    (ScalarDestination name, ArrayVar _ arr)
      | Just src_is <- mapM dimFix src_slice,
        length src_slice == length (entryArrayShape arr) -> do
          let bt = entryArrayElemType arr
          (mem, space, i) <-
            fullyIndexArray' (entryArrayLoc arr) src_is
          vol <- asks envVolatility
          emit $ Imp.Read name mem i bt space vol
      | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: prim-typed target",
                prettyString name,
                "and array-typed source",
                prettyString src,
                "of shape",
                prettyString (entryArrayShape arr),
                "sliced with",
                prettyString src_slice
              ]
    (ArrayDestination (Just dest_loc), ArrayVar _ src_arr) -> do
      let src_loc = entryArrayLoc src_arr
          bt = entryArrayElemType src_arr
      emit =<< copyArrayDWIM bt dest_loc dest_slice src_loc src_slice
    (ArrayDestination (Just dest_loc), ScalarVar _ (ScalarEntry bt))
      | Just dest_is <- mapM dimFix dest_slice,
        length dest_is == length (memLocShape dest_loc) -> do
          (dest_mem, dest_space, dest_i) <- fullyIndexArray' dest_loc dest_is
          vol <- asks envVolatility
          emit $ Imp.Write dest_mem dest_i bt dest_space vol (Imp.var src bt)
      | otherwise ->
          error $
            unwords
              [ "copyDWIMDest: array-typed target and prim-typed source",
                prettyString src,
                "with slice",
                prettyString dest_slice
              ]
    (ArrayDestination Nothing, _) ->
      pure () -- No location to write to; nothing is emitted.
    (_, AccVar {}) ->
      pure () -- Accumulator sources emit nothing here.
-- | Copy a (possibly sliced) 'SubExp' into a (possibly sliced)
-- variable.  The destination kind is chosen from the variable's entry
-- and the work is delegated to 'copyDWIMDest'.  An accumulator
-- destination becomes an 'ArrayDestination' without a location, so
-- nothing is emitted for it.
copyDWIM ::
  VName ->
  [DimIndex (Imp.TExp Int64)] ->
  SubExp ->
  [DimIndex (Imp.TExp Int64)] ->
  ImpM rep r op ()
copyDWIM dest dest_slice src src_slice = do
  dest_entry <- lookupVar dest
  let dest_target =
        case dest_entry of
          ScalarVar _ _ -> ScalarDestination dest
          ArrayVar _ entry -> ArrayDestination $ Just $ entryArrayLoc entry
          MemVar _ _ -> MemoryDestination dest
          AccVar {} -> ArrayDestination Nothing
  copyDWIMDest dest_target dest_slice src src_slice
-- | Like 'copyDWIM', but every destination and source index is a single
-- fixed index (each one is wrapped in 'DimFix').
copyDWIMFix ::
  VName ->
  [Imp.TExp Int64] ->
  SubExp ->
  [Imp.TExp Int64] ->
  ImpM rep r op ()
copyDWIMFix dest dest_is src src_is =
  copyDWIM dest (fixAll dest_is) src (fixAll src_is)
  where
    fixAll = map DimFix
-- | Compile an allocation statement.  The pattern must bind exactly one
-- memory block; @e@ gives its size in bytes and @space@ the memory space.
-- Any other pattern shape is an internal error.
--
-- The allocation itself is delegated to 'sAlloc_', which uses the
-- space-specific 'AllocCompiler' when one is registered and otherwise
-- emits a plain 'Imp.Allocate'.  Delegating keeps this dispatch logic in
-- one place instead of duplicating it here.
compileAlloc ::
  Mem rep inner => Pat (LetDec rep) -> SubExp -> Space -> ImpM rep r op ()
compileAlloc (Pat [mem]) e space =
  sAlloc_ (patElemName mem) (Imp.bytes $ pe64 e) space
compileAlloc pat _ _ =
  error $ "compileAlloc: Invalid pattern: " ++ prettyString pat
-- | The number of bytes occupied by a value of the given type.  This is the
-- byte size of its element type times the product of its dimensions; an
-- empty dimension list gives a product of 1.
typeSize :: Type -> Count Bytes (Imp.TExp Int64)
typeSize t = Imp.bytes $ elemBytes * numElems
  where
    elemBytes = primByteSize (elemType t)
    numElems = product $ map pe64 $ arrayDims t
-- | Build a boolean expression that holds when every index of the slice lies
-- within the corresponding dimension of @dims@:
--
-- * a 'DimFix' index @i@ requires @0 <= i < d@;
-- * a 'DimSlice' @i n s@ requires @0 <= i@ and @i + (n-1)*s < d@.
--
-- The per-dimension conditions are conjoined left to right.  An empty slice
-- (or empty @dims@) yields 'true' instead of failing in 'foldl1'.
--
-- NOTE(review): for a 'DimSlice' with a negative stride, the last element
-- @i + (n-1)*s@ is not checked against 0.  Such a slice may therefore be
-- reported in bounds while reaching below index 0.  Confirm whether
-- negative strides can reach this function.
inBounds :: Slice (Imp.TExp Int64) -> [Imp.TExp Int64] -> Imp.TExp Bool
inBounds (Slice slice) dims =
  case zipWith condInBounds slice dims of
    [] -> true
    -- Same association as the former 'foldl1', so non-empty input produces
    -- the same expression as before.
    c : cs -> foldl (.&&.) c cs
  where
    condInBounds (DimFix i) d =
      0 .<=. i .&&. i .<. d
    condInBounds (DimSlice i n s) d =
      0 .<=. i .&&. i + (n - 1) * s .<. d
-- | @rotateIndex d r i@ shifts index @i@ by @r@ within a dimension of size
-- @d@, wrapping around, i.e. @(i + r) `mod` d@.
rotateIndex ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Imp.TExp Int64
rotateIndex d r i = mod (i + r) d
-- | Emit an 'Imp.For' loop using the given loop variable and @bound@, with
-- the code generated by @body@ as its body.  The bound must have an integer
-- type, which also becomes the type of the loop variable; any other type
-- is an internal error.
sFor' :: VName -> Imp.Exp -> ImpM rep r op () -> ImpM rep r op ()
sFor' i bound body = do
  addLoopVar i loopType
  emit . Imp.For i bound =<< collect body
  where
    loopType = case primExpType bound of
      IntType t -> t
      t ->
        error $
          "sFor': bound "
            ++ prettyString bound
            ++ " is of type "
            ++ prettyString t
-- | Emit a loop bounded by @bound@ with a fresh loop variable named after
-- the given string.  The body receives the loop variable as a typed
-- expression whose primitive type matches that of the bound.
sFor :: String -> Imp.TExp t -> (Imp.TExp t -> ImpM rep r op ()) -> ImpM rep r op ()
sFor i bound body = do
  i' <- newVName i
  let bound' = untyped bound
      iVar = TPrimExp $ Imp.var i' $ primExpType bound'
  sFor' i' bound' $ body iVar
-- | Emit an 'Imp.While' loop with condition @cond@ around the code
-- generated by @body@.
sWhile :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile cond body = emit . Imp.While cond =<< collect body
-- | Wrap the code generated by the given action in an 'Imp.Comment'
-- carrying the given text.
sComment :: T.Text -> ImpM rep r op () -> ImpM rep r op ()
sComment s code = emit . Imp.Comment s =<< collect code
-- | Emit a conditional.  Both branch actions are always run; if the
-- condition is literally 'true' or 'false', only the selected branch's
-- code is emitted, without a surrounding 'Imp.If'.
sIf :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op () -> ImpM rep r op ()
sIf cond tbranch fbranch = do
  tcode <- collect tbranch
  fcode <- collect fbranch
  emit $ pick tcode fcode
  where
    pick tcode fcode
      | cond == true = tcode
      | cond == false = fcode
      | otherwise = Imp.If cond tcode fcode
-- | Emit the given code guarded by @cond@, with an empty else branch.
sWhen :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen cond tbranch = sIf cond tbranch $ pure ()
-- | Emit the given code as the else branch of a conditional on @cond@,
-- i.e. it runs only when @cond@ is false.
sUnless :: Imp.TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless cond fbranch = sIf cond (pure ()) fbranch
-- | Emit the given operation wrapped in 'Imp.Op'.
sOp :: op -> ImpM rep r op ()
sOp op = emit (Imp.Op op)
-- | Declare a fresh memory block, named after the given string, in the given
-- space.  The block is registered in the variable table and its name is
-- returned.  Only an 'Imp.DeclareMem' is emitted; nothing is allocated.
sDeclareMem :: String -> Space -> ImpM rep r op VName
sDeclareMem name space = do
  mem <- newVName name
  emit $ Imp.DeclareMem mem space
  addVar mem . MemVar Nothing $ MemEntry space
  pure mem
-- | Allocate @size'@ bytes for the memory block @name'@ in @space@.  If an
-- 'AllocCompiler' is registered for the space it is used; otherwise a plain
-- 'Imp.Allocate' is emitted.  ('sAlloc' declares the block before calling
-- this.)
sAlloc_ :: VName -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op ()
sAlloc_ name' size' space =
  asks (M.lookup space . envAllocCompilers) >>= \case
    Just allocator -> allocator name' size'
    Nothing -> emit $ Imp.Allocate name' size' space
-- | Declare a fresh memory block in @space@, allocate @size@ bytes for it,
-- and return its name.
sAlloc :: String -> Count Bytes (Imp.TExp Int64) -> Space -> ImpM rep r op VName
sAlloc name size space = do
  mem <- sDeclareMem name space
  sAlloc_ mem size space
  pure mem
-- | Bind a fresh array variable, named after the given string, with element
-- type @bt@ and shape @shape@, located in memory block @mem@ with index
-- function @ixfun@.  Returns the new array's name.
sArray :: String -> PrimType -> ShapeBase SubExp -> VName -> IxFun -> ImpM rep r op VName
sArray name bt shape mem ixfun = do
  arr <- newVName name
  dArray arr bt shape mem ixfun
  pure arr
-- | Like 'sArray', but the index function is 'IxFun.iota' over the
-- dimensions of @shape@ (converted to 64-bit integer expressions).
sArrayInMem :: String -> PrimType -> ShapeBase SubExp -> VName -> ImpM rep r op VName
sArrayInMem name pt shape mem =
  sArray name pt shape mem (IxFun.iota dims)
  where
    dims = map (isInt64 . primExpFromSubExp int64) (shapeDims shape)
-- | Allocate memory (named @name ++ "_mem"@) in @space@ large enough for an
-- array of element type @pt@ and shape @shape@, then bind a fresh array of
-- that shape in it.  The memory layout is 'IxFun.iota' over the dimensions
-- reordered by @rearrangeShape perm@.  The array's index function is that
-- layout permuted by @rearrangeInverse perm@.
sAllocArrayPerm ::
  String ->
  PrimType ->
  ShapeBase SubExp ->
  Space ->
  [Int] ->
  ImpM rep r op VName
sAllocArrayPerm name pt shape space perm = do
  mem <- sAlloc (name ++ "_mem") (typeSize $ Array pt shape NoUniqueness) space
  let layout =
        IxFun.iota . map (isInt64 . primExpFromSubExp int64) $
          rearrangeShape perm (shapeDims shape)
  sArray name pt shape mem . IxFun.permute layout $ rearrangeInverse perm
-- | Like 'sAllocArrayPerm' with the identity permutation, so the dimensions
-- are laid out in their given order.
sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray name pt shape space =
  sAllocArrayPerm name pt shape space identityPerm
  where
    identityPerm = take (shapeRank shape) [0 ..]
-- | Declare an array with the given contents in @space@ and bind a fresh
-- one-dimensional array variable to it.  The length is the number of values
-- for 'Imp.ArrayValues', or the given count for 'Imp.ArrayZeros'.  The
-- backing memory is named after @name ++ "_mem"@ and registered in the
-- variable table.
sStaticArray :: String -> Space -> PrimType -> Imp.ArrayContents -> ImpM rep r op VName
sStaticArray name space pt vs = do
  mem <- newVNameForFun $ name ++ "_mem"
  emit $ Imp.DeclareArray mem space pt vs
  addVar mem . MemVar Nothing $ MemEntry space
  sArray name pt (Shape [intConst Int64 $ toInteger n]) mem $
    IxFun.iota [fromIntegral n]
  where
    n = case vs of
      Imp.ArrayValues vs' -> length vs'
      Imp.ArrayZeros zeros -> fromIntegral zeros
-- | Write the scalar expression @v@ into array @arr@ at the fully fixed
-- indices @is@, using the current volatility from the environment.  The
-- stored element type is taken from @v@.
sWrite :: VName -> [Imp.TExp Int64] -> Imp.Exp -> ImpM rep r op ()
sWrite arr is v = do
  (mem, space, offset) <- fullyIndexArray arr is
  let write vol = Imp.Write mem offset (primExpType v) space vol v
  emit . write =<< asks envVolatility
-- | Copy @v@ into the part of array @arr@ selected by the slice.
sUpdate :: VName -> Slice (Imp.TExp Int64) -> SubExp -> ImpM rep r op ()
sUpdate arr (Slice dest_slice) v = copyDWIM arr dest_slice v []
-- | Generate a loop nest with one 'sFor' per dimension; the first dimension
-- is the outermost loop.  The body receives the loop indices in the same
-- order as the dimensions.
sLoopSpace ::
  [Imp.TExp t] ->
  ([Imp.TExp t] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopSpace dims f = go [] dims
  where
    -- Indices are accumulated innermost-first, hence the final reverse.
    go acc [] = f $ reverse acc
    go acc (d : ds) = sFor "nest_i" d $ \i -> go (i : acc) ds
-- | Like 'sLoopSpace', iterating over the dimensions of a 'Shape'.
sLoopNest ::
  Shape ->
  ([Imp.TExp Int64] -> ImpM rep r op ()) ->
  ImpM rep r op ()
sLoopNest shape = sLoopSpace (pe64 <$> shapeDims shape)
-- | Copy @num_elems@ elements of type @pt@ from @srcmem@ (byte offset
-- @srcoffset@) to @destmem@ (byte offset @destoffset@).  When source and
-- destination are the same memory block, the copy is wrapped in a run-time
-- check that skips it if the two offsets are equal.
sCopy ::
  VName ->
  Imp.TExp Int64 ->
  Space ->
  VName ->
  Imp.TExp Int64 ->
  Space ->
  Count Elements (Imp.TExp Int64) ->
  PrimType ->
  ImpM rep r op ()
sCopy destmem destoffset destspace srcmem srcoffset srcspace num_elems pt
  | destmem == srcmem = sUnless (destoffset .==. srcoffset) doCopy
  | otherwise = doCopy
  where
    doCopy =
      emit
        . Imp.Copy
          pt
          destmem
          (bytes destoffset)
          destspace
          srcmem
          (bytes srcoffset)
          srcspace
        $ num_elems `withElemType` pt
-- | Emit an assignment of the untyped expression @e@ to the scalar
-- variable @x@.
(<~~) :: VName -> Imp.Exp -> ImpM rep r op ()
(<~~) x e = emit (Imp.SetScalar x e)

infixl 3 <~~
-- | Emit an assignment of the typed expression @e@ to the variable held by
-- the 'TV'.  The type stored in the 'TV' is not consulted.
(<--) :: TV t -> Imp.TExp t -> ImpM rep r op ()
(<--) (TV x _) e = x <~~ untyped e

infixl 3 <--
-- | Generate a function named @fname@ with the given output and input
-- parameters.  While its body is compiled, 'envFunction' is set to @fname@,
-- and all parameters (outputs first) are added to the variable table.  The
-- resulting code is registered with 'emitFunction' as an 'Imp.Function'
-- with no entry point.
function ::
  Name ->
  [Imp.Param] ->
  [Imp.Param] ->
  ImpM rep r op () ->
  ImpM rep r op ()
function fname outputs inputs m =
  local (\env -> env {envFunction = Just fname}) $ do
    body <- collect $ mapM_ addParam (outputs ++ inputs) >> m
    emitFunction fname $ Imp.Function Nothing outputs inputs body
  where
    addParam :: Imp.Param -> ImpM rep r op ()
    addParam = \case
      Imp.MemParam name space ->
        addVar name . MemVar Nothing $ MemEntry space
      Imp.ScalarParam name bt ->
        addVar name . ScalarVar Nothing $ ScalarEntry bt
-- | Given dimension sizes @[n1, n2, ..., nk]@, return the suffix products
-- @[n2*...*nk, ..., nk, 1]@: for each dimension, the product of the sizes of
-- all dimensions after it.  Each product is bound to a variable via
-- 'dPrimVE'.  An empty input yields an empty list.
--
-- The product of /all/ sizes is not part of the result.  It is therefore
-- not computed, which avoids emitting a dead variable declaration.
dSlices :: [Imp.TExp Int64] -> ImpM rep r op [Imp.TExp Int64]
dSlices [] = pure []
dSlices (_ : ns) = snd <$> suffixProducts ns
  where
    -- Returns the product of the given sizes together with the list of
    -- their suffix products (ending in the literal 1).
    suffixProducts [] = pure (1, [1])
    suffixProducts (n : ns') = do
      (prod, ps) <- suffixProducts ns'
      n' <- dPrimVE "slice" $ n * prod
      pure (n', n' : ps)
-- | Decompose the flat index @j@ into one index per dimension and bind each
-- to the corresponding variable of @vs_ds@ (pairs of variable and dimension
-- size).  The first pair is treated as the outermost dimension.  Each index
-- is the running remainder divided by the product of the later dimension
-- sizes (see 'dSlices').  The new remainder is bound via 'dPrimVE'.
dIndexSpace ::
  [(VName, Imp.TExp Int64)] ->
  Imp.TExp Int64 ->
  ImpM rep r op ()
dIndexSpace vs_ds j = do
  strides <- dSlices $ map snd vs_ds
  go (zip (map fst vs_ds) strides) j
  where
    go [] _ = pure ()
    go ((v, stride) : rest) i = do
      dPrimV_ v $ i `quot` stride
      remnant <- dPrimVE "remnant" $ i - Imp.le64 v * stride
      go rest remnant
-- | Like 'dIndexSpace', but creates one fresh index variable per dimension
-- in @ds@ (each named after @desc@) and returns them as expressions.
dIndexSpace' ::
  String ->
  [Imp.TExp Int64] ->
  Imp.TExp Int64 ->
  ImpM rep r op [Imp.TExp Int64]
dIndexSpace' desc ds j = do
  ivs <- mapM (const $ newVName desc) ds
  dIndexSpace (zip ivs ds) j
  pure $ Imp.le64 <$> ivs