{-# LANGUAGE QuasiQuotes #-}

-- | C code generator framework.
module Futhark.CodeGen.Backends.GenericC.Monad
  ( -- * Pluggable compiler
    Operations (..),
    Publicness (..),
    OpCompiler,
    ErrorCompiler,
    CallCompiler,
    PointerQuals,
    MemoryType,
    WriteScalar,
    writeScalarPointerWithQuals,
    ReadScalar,
    readScalarPointerWithQuals,
    Allocate,
    Deallocate,
    CopyBarrier (..),
    Copy,

    -- * Monadic compiler interface
    CompilerM,
    CompilerState (..),
    CompilerEnv (..),
    getUserState,
    modifyUserState,
    generateProgramStruct,
    runCompilerM,
    inNewFunction,
    cachingMemory,
    volQuals,
    rawMem,
    item,
    items,
    stm,
    stms,
    decl,
    headerDecl,
    publicDef,
    publicDef_,
    profileReport,
    onClear,
    HeaderSection (..),
    libDecl,
    earlyDecl,
    publicName,
    contextField,
    contextFieldDyn,
    memToCType,
    cacheMem,
    fatMemory,
    rawMemCType,
    freeRawMem,
    allocRawMem,
    fatMemType,
    declAllocatedMem,
    freeAllocatedMem,
    collect,
    collect',
    contextType,
    configType,

    -- * Building Blocks
    copyMemoryDefaultSpace,
    derefPointer,
    setMem,
    allocMem,
    unRefMem,
    declMem,
    resetMem,
    fatMemAlloc,
    fatMemSet,
    fatMemUnRef,
    criticalSection,
    module Futhark.CodeGen.Backends.SimpleRep,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import Data.DList qualified as DL
import Data.List (unzip4)
import Data.Loc
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | How public an array type definition should be.  Public types show
-- up in the generated API, while private types are used only to
-- implement the members of opaques.
--
-- The derived 'Ord' follows constructor order, so @'Private' < 'Public'@.
data Publicness = Private | Public
  deriving (Eq, Ord, Show)

-- | Key identifying an array type: the signedness and primitive type
-- of its elements, plus an 'Int' (presumably the rank — confirm at the
-- use sites).
type ArrayType =
  ( Signedness,
    PrimType,
    Int
  )

-- | The state threaded through 'CompilerM'.  Several of its fields
-- accumulate generated C definitions and block items.
data CompilerState s = CompilerState
  { -- | Array types, each tagged with its 'Publicness'.
    compArrayTypes :: M.Map ArrayType Publicness,
    -- | Definitions added by 'earlyDecl' (and thereby 'publicDef').
    compEarlyDecls :: DL.DList C.Definition,
    -- | Source of fresh names; see the 'MonadFreshNames' instance.
    compNameSrc :: VNameSource,
    -- | Backend-specific state; see 'getUserState' and
    -- 'modifyUserState'.
    compUserState :: s,
    -- | Header declarations added by 'headerDecl', grouped by section.
    compHeaderDecls :: M.Map HeaderSection (DL.DList C.Definition),
    -- | Definitions added by 'libDecl'.
    compLibDecls :: DL.DList C.Definition,
    -- | Fields of the generated @struct program@: name, type, optional
    -- initialiser expression, and optional (setup, teardown)
    -- statements.  Populated by 'contextField' and 'contextFieldDyn'.
    compCtxFields :: DL.DList (C.Id, C.Type, Maybe C.Exp, Maybe (C.Stm, C.Stm)),
    -- | Block items presumably added by 'profileReport'.
    compProfileItems :: DL.DList C.BlockItem,
    -- | Block items presumably added by 'onClear'.
    compClearItems :: DL.DList C.BlockItem,
    -- | Memory blocks declared in the current function; consumed by
    -- 'declAllocatedMem' and 'freeAllocatedMem', and reset by
    -- 'inNewFunction'.
    compDeclaredMem :: [(VName, Space)],
    -- | Block items emitted by 'item' and 'items'; captured by
    -- 'collect'.
    compItems :: DL.DList C.BlockItem
  }

-- | An initial 'CompilerState' that uses the given name source and user
-- state; every other field starts out empty.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compArrayTypes = mempty,
      compEarlyDecls = mempty,
      compNameSrc = src,
      compUserState = s,
      compHeaderDecls = mempty,
      compLibDecls = mempty,
      compCtxFields = mempty,
      compProfileItems = mempty,
      compClearItems = mempty,
      compDeclaredMem = mempty,
      compItems = mempty
    }

-- | In which part of the header file we put the declaration.  This is
-- to ensure that the header file remains structured and readable.
--
-- Declarations are stored in a 'M.Map' keyed by this type
-- ('compHeaderDecls'), so the derived 'Ord' (constructor order) is the
-- order in which the sections are traversed.
data HeaderSection
  = ArrayDecl Name
  | OpaqueTypeDecl Name
  | OpaqueDecl Name
  | EntryDecl
  | MiscDecl
  | InitDecl
  deriving (Eq, Ord)

-- | Compiler for one backend-specific operation @op@, run in
-- 'CompilerM' to emit its code.  Supplied through 'opsCompiler'; it is
-- tried before the main compilation function.
type OpCompiler op s =
  op -> CompilerM op s ()

-- | Compiler for a runtime error, supplied through 'opsError'.  It
-- receives the error message (with 'Exp' arguments) and a 'String'
-- (presumably a stack trace or location — confirm against the
-- backends).
type ErrorCompiler op s =
  ErrorMsg Exp ->
  String ->
  CompilerM op s ()

-- | Given an annotation string, produce the address-space qualifiers
-- to attach to a pointer declaration.
type PointerQuals =
  String -> [C.TypeQual]

-- | Compute the C type used for memory blocks living in a particular
-- memory space.
type MemoryType op s =
  SpaceId -> CompilerM op s C.Type

-- | Emit a store of a scalar into a memory block.  The arguments are
-- (presumably in this order) the memory block, the element index, the
-- element's C type, the memory space, the volatility of the access,
-- and the value being written.
type WriteScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  C.Exp ->
  CompilerM op s ()

-- | Produce an expression that loads a scalar from a memory block.  The
-- arguments mirror 'WriteScalar' minus the value: the memory block,
-- the element index, the element's C type, the memory space, and the
-- volatility of the access.
type ReadScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  CompilerM op s C.Exp

-- | Allocate a memory block in the given memory space, with the given
-- size and tag, storing a reference to it in the given variable.  The
-- three expression arguments are presumably (block, size, tag) —
-- confirm at the call sites.
type Allocate op s =
  C.Exp -> C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | Release a memory block in the given memory space.  The expression
-- arguments describe the block, its tag, and its size (order to be
-- confirmed at the call sites).
type Deallocate op s =
  C.Exp ->
  C.Exp ->
  C.Exp ->
  SpaceId ->
  CompilerM op s ()

-- | Whether a copying operation should implicitly function as a
-- barrier regarding further operations on the source.  This is a
-- rather subtle detail and is mostly useful for letting some
-- device/GPU copies be asynchronous (#1664).
data CopyBarrier
  = -- | The copy acts as a barrier.
    CopyBarrier
  | -- | Explicit context synchronisation should be done
    -- before the source or target is used.
    CopyNoBarrier
  deriving (Eq, Show)

-- | Copy from one memory block to another.  After the 'CopyBarrier'
-- flag come two (expression, expression, 'Space') groups, presumably
-- describing destination and source (block, offset, space), followed
-- by a final expression, presumably the number of bytes — confirm
-- against the backends' 'opsCopy'.
type Copy op s =
  CopyBarrier ->
  C.Exp -> C.Exp -> Space ->
  C.Exp -> C.Exp -> Space ->
  C.Exp ->
  CompilerM op s ()

-- | Emit a function call.  Takes a list of variable names (presumably
-- the result destinations), the function's 'Name', and the argument
-- expressions.
type CallCompiler op s =
  [VName] -> Name -> [C.Exp] -> CompilerM op s ()

-- | The pluggable operations through which a backend customises the
-- generic C code generator.
data Operations op s = Operations
  { -- | Write a scalar to memory.
    opsWriteScalar :: WriteScalar op s,
    -- | Read a scalar from memory.
    opsReadScalar :: ReadScalar op s,
    -- | Allocate a memory block.
    opsAllocate :: Allocate op s,
    -- | Deallocate a memory block.
    opsDeallocate :: Deallocate op s,
    -- | Copy between memory blocks.
    opsCopy :: Copy op s,
    -- | The C type of memory blocks in a given space.
    opsMemoryType :: MemoryType op s,
    -- | Compile a backend-specific operation.
    opsCompiler :: OpCompiler op s,
    -- | Compile an error.
    opsError :: ErrorCompiler op s,
    -- | Compile a function call.
    opsCall :: CallCompiler op s,
    -- | If true, use reference counting.  Otherwise, bare
    -- pointers.
    opsFatMemory :: Bool,
    -- | Code to bracket critical sections.
    opsCritical :: ([C.BlockItem], [C.BlockItem])
  }

-- | Generate code that calls 'unRefMem' on every memory block recorded
-- in 'compDeclaredMem'.  The items are returned (via 'collect') rather
-- than added to the current block.
freeAllocatedMem :: CompilerM op s [C.BlockItem]
freeAllocatedMem =
  collect $ mapM_ (uncurry unRefMem) =<< gets compDeclaredMem

-- | Generate a declaration (followed by 'resetMem') for every memory
-- block recorded in 'compDeclaredMem'.  The items are returned (via
-- 'collect') rather than added to the current block.
declAllocatedMem :: CompilerM op s [C.BlockItem]
declAllocatedMem = collect $ mapM_ f =<< gets compDeclaredMem
  where
    -- Declare one block with its space-specific C type, then reset it.
    f :: (VName, Space) -> CompilerM op s ()
    f (name, space) = do
      ty <- memToCType name space
      decl [C.cdecl|$ty:ty $id:name;|]
      resetMem name space

-- | The read-only environment of 'CompilerM'.
data CompilerEnv op s = CompilerEnv
  { -- | The backend's pluggable operations.
    envOperations :: Operations op s,
    -- | Mapping memory blocks to sizes.  These memory blocks are CPU
    -- memory that we know are used in particularly simple ways (no
    -- reference counting necessary).  To cut down on allocator
    -- pressure, we keep these allocations around for a long time, and
    -- record their sizes so we can reuse them if possible (and
    -- realloc() when needed).
    envCachedMem :: M.Map C.Exp VName
  }

-- | Turn the registered context fields ('compCtxFields') into three
-- pieces of @struct program@ code:
--
-- 1. one struct field declaration per registered field;
-- 2. initialisation statements: assignments for fields that have an
--    initial expression, followed by the setup statement of every
--    field registered with a (setup, teardown) pair;
-- 3. the teardown statements of those same fields.
contextContents :: CompilerM op s ([C.FieldGroup], [C.Stm], [C.Stm])
contextContents = do
  (field_names, field_types, field_values, field_frees) <-
    gets $ unzip4 . DL.toList . compCtxFields
  let fields =
        [ [C.csdecl|$ty:ty $id:name;|]
          | (name, ty) <- zip field_names field_types
        ]
      -- Only fields registered with an initial expression get one.
      init_fields =
        [ [C.cstm|ctx->program->$id:name = $exp:e;|]
          | (name, Just e) <- zip field_names field_values
        ]
      (setup, free) = unzip $ catMaybes field_frees
  pure (fields, init_fields <> setup, free)

-- | Emit, via 'earlyDecl', the definition of @struct program@, which
-- holds all registered context fields.  It also emits
-- @setup_program@, which allocates the struct and runs the
-- initialisation statements from 'contextContents', and
-- @teardown_program@, which runs the teardown statements and frees
-- the struct.
generateProgramStruct :: CompilerM op s ()
generateProgramStruct = do
  (fields, init_fields, free_fields) <- contextContents
  -- ISO C does not allow a struct with no members, so the dummy field
  -- keeps the definition valid when no context fields are registered.
  --
  -- NOTE(review): the result of malloc() is not checked before
  -- init_fields dereference ctx->program.
  mapM_
    earlyDecl
    [C.cunit|struct program {
               int dummy;
               $sdecls:fields
             };
             static void setup_program(struct futhark_context* ctx) {
               (void)ctx;
               int error = 0;
               (void)error;
               ctx->program = malloc(sizeof(struct program));
               $stms:init_fields
             }
             static void teardown_program(struct futhark_context *ctx) {
               (void)ctx;
               int error = 0;
               (void)error;
               $stms:free_fields
               free(ctx->program);
             }|]

-- | The code generation monad: a read-only 'CompilerEnv' layered over a
-- 'State' carrying the 'CompilerState'.  All instances are derived
-- through the newtype.
newtype CompilerM op s a
  = CompilerM (ReaderT (CompilerEnv op s) (State (CompilerState s)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (CompilerState s),
      MonadReader (CompilerEnv op s)
    )

-- | Fresh names are drawn from (and written back to) 'compNameSrc'.
instance MonadFreshNames (CompilerM op s) where
  getNameSource = gets compNameSrc
  putNameSource src = modify $ \s -> s {compNameSrc = src}

-- | Run a 'CompilerM' action with the given backend operations, name
-- source and initial user state.  The memory cache starts empty.
-- Returns the action's result together with the final state.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  (a, CompilerState s)
runCompilerM ops src userstate (CompilerM m) =
  runState
    (runReaderT m (CompilerEnv ops mempty))
    (newCompilerState src userstate)

-- | Read the backend-specific user state.
getUserState :: CompilerM op s s
getUserState = gets compUserState

-- | Apply a function to the backend-specific user state.
modifyUserState :: (s -> s) -> CompilerM op s ()
modifyUserState f = modify $ \compstate ->
  compstate {compUserState = f $ compUserState compstate}

-- | Run an action and return the block items it emitted, instead of
-- adding them to the current items.
collect :: CompilerM op s () -> CompilerM op s [C.BlockItem]
collect m = snd <$> collect' m

-- | Like 'collect', but also return the action's result.  The items
-- emitted before the call are restored afterwards, so the collected
-- items do not leak into the enclosing block.
collect' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
collect' m = do
  old <- gets compItems
  modify $ \s -> s {compItems = mempty}
  x <- m
  new <- gets compItems
  modify $ \s -> s {compItems = old}
  pure (x, DL.toList new)

-- | Used when we, inside an existing 'CompilerM' action, want to
-- generate code for a new function.  Use this so that the compiler
-- understands that previously declared memory doesn't need to be
-- freed inside this action.
--
-- The action runs with an empty 'compDeclaredMem' (restored
-- afterwards) and an empty 'envCachedMem'.
inNewFunction :: CompilerM op s a -> CompilerM op s a
inNewFunction m = do
  old_mem <- gets compDeclaredMem
  modify $ \s -> s {compDeclaredMem = mempty}
  x <- local noCached m
  modify $ \s -> s {compDeclaredMem = old_mem}
  pure x
  where
    -- Cached memory of the enclosing function is not visible here.
    noCached :: CompilerEnv op s -> CompilerEnv op s
    noCached env = env {envCachedMem = mempty}

-- | Append one block item to the current items.
item :: C.BlockItem -> CompilerM op s ()
item x = modify $ \s -> s {compItems = DL.snoc (compItems s) x}

-- | Append several block items, in order, to the current items.
items :: [C.BlockItem] -> CompilerM op s ()
items xs = modify $ \s ->
  s {compItems = DL.append (compItems s) (DL.fromList xs)}

-- | Whether memory in the given space is reference counted.  This is
-- never the case for 'ScalarSpace'; otherwise 'opsFatMemory' decides.
fatMemory :: Space -> CompilerM op s Bool
fatMemory ScalarSpace {} = pure False
fatMemory _ = asks $ opsFatMemory . envOperations

-- | Look up a memory block (converted to a C expression) in
-- 'envCachedMem', returning the recorded name if it is cached.
cacheMem :: (C.ToExp a) => a -> CompilerM op s (Maybe VName)
cacheMem a = asks $ M.lookup (C.toExp a noLoc) . envCachedMem

-- | Construct a publicly visible definition using the specified name
-- as the template.  The first returned definition is put in the
-- header file (in the given section), and the second is the
-- implementation, added via 'earlyDecl'.  Returns the public name.
publicDef ::
  T.Text ->
  HeaderSection ->
  (T.Text -> (C.Definition, C.Definition)) ->
  CompilerM op s T.Text
publicDef s h f = do
  s' <- publicName s
  let (pub, priv) = f s'
  headerDecl h pub
  earlyDecl priv
  pure s'

-- | As 'publicDef', but ignores the public name.
publicDef_ ::
  T.Text ->
  HeaderSection ->
  (T.Text -> (C.Definition, C.Definition)) ->
  CompilerM op s ()
publicDef_ s h f = void $ publicDef s h f

-- | Append a declaration to the given section of the header file.
-- Within a section, declarations keep their insertion order.
headerDecl :: HeaderSection -> C.Definition -> CompilerM op s ()
headerDecl sec def = modify $ \s ->
  s
    { -- insertWith passes (new, old); flip so the result is old <> new.
      compHeaderDecls =
        M.insertWith (flip (<>)) sec (DL.singleton def) (compHeaderDecls s)
    }

-- | Append a definition to 'compLibDecls'.
libDecl :: C.Definition -> CompilerM op s ()
libDecl def = modify $ \s ->
  s {compLibDecls = DL.snoc (compLibDecls s) def}

-- | Append a definition to 'compEarlyDecls'.
earlyDecl :: C.Definition -> CompilerM op s ()
earlyDecl def = modify $ \s ->
  s {compEarlyDecls = DL.snoc (compEarlyDecls s) def}

-- | Register a context field called @name@ of type @ty@, with an
-- optional initial value.  Fields are kept in registration order in
-- 'compCtxFields'.
contextField :: C.Id -> C.Type -> Maybe C.Exp -> CompilerM op s ()
contextField name ty initial =
  modify $ \st -> st {compCtxFields = DL.snoc (compCtxFields st) field}
  where
    -- No create/free statements for a field set up this way.
    field = (name, ty, initial, Nothing)

-- | Register a context field called @name@ of type @ty@ that has no
-- initial value, but is set up by the @create@ statement and torn
-- down by the @free@ statement.  Fields are kept in registration order
-- in 'compCtxFields'.
contextFieldDyn :: C.Id -> C.Type -> C.Stm -> C.Stm -> CompilerM op s ()
contextFieldDyn name ty create free =
  modify $ \st -> st {compCtxFields = DL.snoc (compCtxFields st) field}
  where
    field = (name, ty, Nothing, Just (create, free))

-- | Append a block item to the accumulated profiling-report items
-- ('compProfileItems').
profileReport :: C.BlockItem -> CompilerM op s ()
profileReport x =
  modify $ \st -> st {compProfileItems = DL.snoc (compProfileItems st) x}

-- | Append a block item to the accumulated clear items
-- ('compClearItems').
onClear :: C.BlockItem -> CompilerM op s ()
onClear x =
  modify $ \st -> st {compClearItems = DL.snoc (compClearItems st) x}

-- | Emit a single C statement as a block item.
stm :: C.Stm -> CompilerM op s ()
stm = item . asItem
  where
    asItem s = [C.citem|$stm:s|]

-- | Emit each statement in list order, exactly as repeated 'stm' calls.
stms :: [C.Stm] -> CompilerM op s ()
stms = foldr (\s rest -> stm s >> rest) (pure ())

-- | Emit a C declaration (an init group) as a block item.
decl :: C.InitGroup -> CompilerM op s ()
decl = item . asItem
  where
    asItem x = [C.citem|$decl:x;|]

-- | Public names must have a consistent prefix: the given name with
-- @futhark_@ prepended.
publicName :: T.Text -> CompilerM op s T.Text
publicName = pure . T.append "futhark_"

-- | The C type used for memory block @v@ in @space@.
--
-- If @space@ uses fat (reference-counted) memory and @v@ is not
-- lexically cached, this is 'fatMemType'.  Otherwise it is the raw
-- type from 'rawMemCType'.
memToCType :: VName -> Space -> CompilerM op s C.Type
memToCType v space = do
  fat <- fatMemory space
  cachedAs <- cacheMem v
  case (fat, cachedAs) of
    (True, Nothing) -> pure $ fatMemType space
    _ -> rawMemCType space

-- | The C type of a raw (non-reference-counted) memory block in a
-- space.
--
-- * Default-space memory is 'defaultMemBlockType'.
-- * Named spaces defer to the backend's 'opsMemoryType'.
-- * Scalar spaces become a C array of the element type, sized by the
--   product of the dimensions, or 1 when there are no dimensions.
rawMemCType :: Space -> CompilerM op s C.Type
rawMemCType space =
  case space of
    DefaultSpace -> pure defaultMemBlockType
    Space sid -> do
      memType <- asks (opsMemoryType . envOperations)
      memType sid
    ScalarSpace [] t ->
      pure [C.cty|$ty:(primTypeToCType t)[1]|]
    ScalarSpace ds t ->
      let dims = map (`C.toExp` noLoc) ds
       in pure [C.cty|$ty:(primTypeToCType t)[$exp:(cproduct dims)]|]

-- | The struct type of a fat (reference-counted) memory block.  It is
-- @struct memblock_<sid>@ for a named space and @struct memblock@ for
-- every other space.
fatMemType :: Space -> C.Type
fatMemType space = [C.cty|struct $id:(structTag space)|]
  where
    structTag :: Space -> String
    structTag (Space sid) = "memblock_" <> sid
    structTag _ = "memblock"

-- | Name of the C function that assigns one fat memory block to
-- another.  It is @memblock_set_<sid>@ for a named space and
-- @memblock_set@ otherwise.
fatMemSet :: Space -> String
fatMemSet space = "memblock_set" ++ spaceSuffix
  where
    spaceSuffix = case space of
      Space sid -> '_' : sid
      _ -> ""

-- | Name of the C function that allocates a fat memory block.  It is
-- @memblock_alloc_<sid>@ for a named space and @memblock_alloc@
-- otherwise.
fatMemAlloc :: Space -> String
fatMemAlloc space = "memblock_alloc" ++ spaceSuffix
  where
    spaceSuffix = case space of
      Space sid -> '_' : sid
      _ -> ""

-- | Name of the C function that drops a reference to a fat memory
-- block.  It is @memblock_unref_<sid>@ for a named space and
-- @memblock_unref@ otherwise.
fatMemUnRef :: Space -> String
fatMemUnRef space = "memblock_unref" ++ spaceSuffix
  where
    spaceSuffix = case space of
      Space sid -> '_' : sid
      _ -> ""

-- | The C expression for the raw memory of block @v@.
--
-- This is @v.mem@ when the backend uses fat memory ('opsFatMemory')
-- and @v@ is not lexically cached.  Otherwise it is @v@ itself.
rawMem :: VName -> CompilerM op s C.Exp
rawMem v = do
  fatOps <- asks (opsFatMemory . envOperations)
  notCached <- isNothing <$> cacheMem v
  pure $ rawMem' (fatOps && notCached) v

-- | Select the raw pointer of a memory expression.  When the flag is
-- 'True', this is the expression's @.mem@ field; otherwise it is the
-- expression itself.
rawMem' :: C.ToExp a => Bool -> a -> C.Exp
rawMem' fat e
  | fat = [C.cexp|$exp:e.mem|]
  | otherwise = [C.cexp|$exp:e|]

-- | Allocate @size@ bytes of raw memory into @dest@, tagged with the
-- description @desc@.
--
-- Named spaces use the backend's 'opsAllocate'.  Every other space
-- emits a call to @host_alloc@.
--
-- NOTE(review): the @host_alloc@ call also takes the address of the
-- @size@ expression, so @size@ must be a C lvalue in that case.
allocRawMem ::
  (C.ToExp a, C.ToExp b, C.ToExp c) =>
  a ->
  b ->
  Space ->
  c ->
  CompilerM op s ()
allocRawMem dest size space desc
  | Space sid <- space = do
      allocate <- asks (opsAllocate . envOperations)
      allocate [C.cexp|$exp:dest|] [C.cexp|$exp:size|] [C.cexp|$exp:desc|] sid
  | otherwise =
      stm
        [C.cstm|host_alloc(ctx, (size_t)$exp:size, $exp:desc, (size_t*)&$exp:size, (void*)&$exp:dest);|]

-- | Release raw memory @mem@ of @size@ bytes, tagged with the
-- description @desc@.
--
-- Named spaces use the backend's 'opsDeallocate'.  Every other space
-- emits a call to @host_free@.
freeRawMem ::
  (C.ToExp a, C.ToExp b, C.ToExp c) =>
  a ->
  b ->
  Space ->
  c ->
  CompilerM op s ()
freeRawMem mem size space desc
  | Space sid <- space = do
      deallocate <- asks (opsDeallocate . envOperations)
      deallocate [C.cexp|$exp:mem|] [C.cexp|$exp:size|] [C.cexp|$exp:desc|] sid
  | otherwise =
      item
        [C.citem|host_free(ctx, (size_t)$exp:size, $exp:desc, (void*)$exp:mem);|]

-- | Declare memory block @name@ in @space@.
--
-- * Lexically cached blocks are skipped entirely.
-- * Fat-memory blocks are only recorded in 'compDeclaredMem'.  The C
--   declaration is presumably produced elsewhere from that list
--   (TODO confirm).
-- * Any other block gets a local C declaration of the type given by
--   'memToCType'.
declMem :: VName -> Space -> CompilerM op s ()
declMem name space = do
  cachedAs <- cacheMem name
  fat <- fatMemory space
  case (cachedAs, fat) of
    (Just _, _) -> pure ()
    (Nothing, True) ->
      modify $ \st -> st {compDeclaredMem = (name, space) : compDeclaredMem st}
    (Nothing, False) -> do
      ty <- memToCType name space
      decl [C.cdecl|$ty:ty $id:name;|]

-- | Emit code that resets memory block @mem@.
--
-- * A lexically cached block has its pointer set to @NULL@.
-- * An uncached fat block has its @references@ field set to @NULL@.
-- * Anything else emits nothing.
resetMem :: C.ToExp a => a -> Space -> CompilerM op s ()
resetMem mem space = do
  fat <- fatMemory space
  cachedAs <- cacheMem mem
  case cachedAs of
    Just _ -> stm [C.cstm|$exp:mem = NULL;|]
    Nothing
      | fat -> stm [C.cstm|$exp:mem.references = NULL;|]
      | otherwise -> pure ()

-- | Emit code that makes memory block @dest@ refer to @src@ in @space@.
--
-- * Fat (reference-counted) spaces call the space's @memblock_set@
--   function.  The generated function returns 1 if that call reports
--   failure.
-- * Scalar spaces are copied element by element with a generated
--   @for@ loop, using a 32-bit index.
-- * Anything else is a plain C assignment.
setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> CompilerM op s ()
setMem dest src space = do
  fat <- fatMemory space
  case space of
    _ | fat -> stm refcountedSet
    ScalarSpace ds _ -> copyElements ds
    _ -> stm [C.cstm|$exp:dest = $exp:src;|]
  where
    -- Textual form of the source, passed as the description argument.
    src_s = T.unpack $ expText $ C.toExp src noLoc

    refcountedSet =
      [C.cstm|if ($id:(fatMemSet space)(ctx, &$exp:dest, &$exp:src,
                                             $string:src_s) != 0) {
                     return 1;
                   }|]

    copyElements ds = do
      i' <- newVName "i"
      let i = C.toIdent i'
          it = primTypeToCType $ IntType Int32
          bound = cproduct $ map (`C.toExp` noLoc) ds
      stm
        [C.cstm|for ($ty:it $id:i = 0; $id:i < $exp:bound; $id:i++) {
                          $exp:dest[$id:i] = $exp:src[$id:i];
                }|]

-- | Emit code that drops a reference to fat memory block @mem@.  The
-- generated function returns 1 if the unref call reports failure.
-- Nothing is emitted for non-fat spaces or for lexically cached
-- blocks.
unRefMem :: C.ToExp a => a -> Space -> CompilerM op s ()
unRefMem mem space = do
  fat <- fatMemory space
  cachedAs <- cacheMem mem
  unless (not fat || isJust cachedAs) $
    stm
      [C.cstm|if ($id:(fatMemUnRef space)(ctx, &$exp:mem, $string:mem_s) != 0) {
                  return 1;
                }|]
  where
    -- Textual form of the block, passed as the description argument.
    mem_s = T.unpack $ expText $ C.toExp mem noLoc

-- | Emit code that allocates @size@ bytes for memory block @mem@ in
-- @space@.
--
-- * Fat (reference-counted) spaces call the space's @memblock_alloc@
--   function and run @on_failure@ if it reports an error.
-- * Otherwise @mem@ is first released with 'freeRawMem' and then
--   reallocated with 'allocRawMem'.  Both calls use the textual form
--   of @mem@ as their description.
--
-- NOTE(review): @on_failure@ is not used on the non-fat path.
allocMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  b ->
  Space ->
  C.Stm ->
  CompilerM op s ()
allocMem mem size space on_failure = do
  refcount <- fatMemory space
  let mem_s = T.unpack $ expText $ C.toExp mem noLoc
  if refcount
    then
      stm
        [C.cstm|if ($id:(fatMemAlloc space)(ctx, &$exp:mem, $exp:size,
                                                 $string:mem_s)) {
                       $stm:on_failure
                     }|]
    else do
      freeRawMem mem size space mem_s
      -- Use the same description as the free above.  Passing the bare
      -- C identifier @desc@ only works if the surrounding generated C
      -- function happens to declare such a variable.
      allocRawMem mem size space mem_s

-- | Copy @nbytes@ bytes from @srcmem + srcidx@ to @destmem + destidx@.
-- The copy uses @memmove@, so the ranges may overlap, and it is
-- skipped unless @nbytes > 0@.
copyMemoryDefaultSpace ::
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  CompilerM op s ()
copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
  stm guardedMove
  where
    guardedMove =
      [C.cstm|if ($exp:nbytes > 0) {
              memmove($exp:destmem + $exp:destidx,
                      $exp:srcmem + $exp:srcidx,
                      $exp:nbytes);
            }|]

-- | Run @f@ in an environment where every lexical memory block in
-- 'DefaultSpace' is registered in 'envCachedMem'.
--
-- @f@ receives two arguments:
--
-- * declarations for the cached blocks: a fresh @int64_t@ size
--   variable initialised to 0, and the block pointer initialised to
--   @NULL@;
-- * one @free@ statement per cached block.
cachingMemory ::
  M.Map VName Space ->
  ([C.BlockItem] -> [C.Stm] -> CompilerM op s a) ->
  CompilerM op s a
cachingMemory lexical f = do
  -- We only consider lexical 'DefaultSpace' memory blocks to be
  -- cached.  This is not a deep technical restriction, but merely a
  -- heuristic based on GPU memory usually involving larger
  -- allocations, that do not suffer from the overhead of reference
  -- counting.
  let cachedMems = [mem | (mem, DefaultSpace) <- M.toList lexical]
      sizeVar mem = newVName $ prettyString mem <> "_cached_size"

  -- Pair every cached block with a fresh variable holding its size.
  withSizes <- mapM (\mem -> (,) mem <$> sizeVar mem) cachedMems

  let cacheEntries =
        M.fromList [(C.toExp mem noLoc, size) | (mem, size) <- withSizes]
      addCached env =
        env {envCachedMem = M.union cacheEntries (envCachedMem env)}
      declItems =
        concat
          [ [ [C.citem|typename int64_t $id:size = 0;|],
              [C.citem|$ty:defaultMemBlockType $id:mem = NULL;|]
            ]
            | (mem, size) <- withSizes
          ]
      freeStms = [ [C.cstm|free($id:mem);|] | (mem, _) <- withSizes ]

  local addCached $ f declItems freeStms

-- | Cast @ptr@ to @res_t@ and index it at @i@, producing
-- @((res_t)ptr)[i]@.  Callers in this module pass a pointer type for
-- @res_t@.
derefPointer :: C.Exp -> C.Exp -> C.Type -> C.Exp
derefPointer ptr i res_t = indexed
  where
    indexed = [C.cexp|(($ty:res_t)$exp:ptr)[$exp:i]|]

-- | C type qualifiers for a volatility: @volatile@ for 'Volatile', and
-- none for 'Nonvolatile'.
volQuals :: Volatility -> [C.TypeQual]
volQuals vol = case vol of
  Volatile -> [C.ctyquals|volatile|]
  Nonvolatile -> []

-- | Store scalar @v@ at index @i@ of @dest@.
--
-- @dest@ is first cast to a pointer to @elemtype@, qualified by the
-- volatility qualifiers followed by the qualifiers that @quals_f@
-- gives for the space.
writeScalarPointerWithQuals :: PointerQuals -> WriteScalar op s
writeScalarPointerWithQuals quals_f dest i elemtype space vol v =
  stm [C.cstm|$exp:target = $exp:v;|]
  where
    quals = volQuals vol ++ quals_f space
    target = derefPointer dest i [C.cty|$tyquals:quals $ty:elemtype*|]

-- | Read the scalar at index @i@ of @dest@.
--
-- @dest@ is first cast to a pointer to @elemtype@, qualified by the
-- volatility qualifiers followed by the qualifiers that @quals_f@
-- gives for the space.  No code is emitted; only the expression is
-- returned.
readScalarPointerWithQuals :: PointerQuals -> ReadScalar op s
readScalarPointerWithQuals quals_f dest i elemtype space vol =
  pure $ derefPointer dest i ptrType
  where
    quals = volQuals vol ++ quals_f space
    ptrType = [C.cty|$tyquals:quals $ty:elemtype*|]

-- | Wrap @x@ in a critical section, in this order:
--
-- 1. take @ctx->lock@;
-- 2. run the backend's entry items (first component of 'opsCritical');
-- 3. run @x@;
-- 4. run the backend's exit items (second component);
-- 5. release the lock.
criticalSection :: Operations op s -> [C.BlockItem] -> [C.BlockItem]
criticalSection ops x =
  [C.citems|lock_lock(&ctx->lock);
            $items:before
            $items:x
            $items:after
            lock_unlock(&ctx->lock);
           |]
  where
    (before, after) = opsCritical ops

-- | The generated code must define a context struct with this name
-- (@struct futhark_context@, via 'publicName').
contextType :: CompilerM op s C.Type
contextType = structNamed <$> publicName "context"
  where
    structNamed name = [C.cty|struct $id:name|]

-- | The generated code must define a configuration struct with this
-- name (@struct futhark_context_config@, via 'publicName').
configType :: CompilerM op s C.Type
configType = structNamed <$> publicName "context_config"
  where
    structNamed name = [C.cty|struct $id:name|]