{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}

-- | C code generator framework.
module Futhark.CodeGen.Backends.GenericC
  ( compileProg,
    CParts (..),
    asLibrary,
    asExecutable,
    asServer,

    -- * Pluggable compiler
    Operations (..),
    defaultOperations,
    OpCompiler,
    ErrorCompiler,
    CallCompiler,
    PointerQuals,
    MemoryType,
    WriteScalar,
    writeScalarPointerWithQuals,
    ReadScalar,
    readScalarPointerWithQuals,
    Allocate,
    Deallocate,
    Copy,
    StaticArray,

    -- * Monadic compiler interface
    CompilerM,
    CompilerState (compUserState, compNameSrc),
    getUserState,
    modifyUserState,
    contextContents,
    contextFinalInits,
    runCompilerM,
    inNewFunction,
    cachingMemory,
    blockScope,
    compileFun,
    compileCode,
    compileExp,
    compilePrimExp,
    compilePrimValue,
    compileExpToName,
    rawMem,
    item,
    items,
    stm,
    stms,
    decl,
    atInit,
    headerDecl,
    publicDef,
    publicDef_,
    profileReport,
    onClear,
    HeaderSection (..),
    libDecl,
    earlyDecl,
    publicName,
    contextType,
    contextField,
    memToCType,
    cacheMem,
    fatMemory,
    rawMemCType,
    cproduct,
    fatMemType,

    -- * Building Blocks
    primTypeToCType,
    intTypeToCType,
    copyMemoryDefaultSpace,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.FileEmbed
import Data.Loc
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.CodeGen.Backends.GenericC.CLI (cliDefs)
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Server (serverDefs)
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C

-- | How public an array type definition should be.  Public types show
-- up in the generated API, while private types are used only to
-- implement the members of opaques.
--
-- The derived 'Ord' places 'Private' before 'Public'.
data Publicness = Private | Public
  deriving (Eq, Ord, Show)

-- | Identifies an array type by its memory space, signedness, element
-- type, and an 'Int' (presumably the rank -- TODO confirm).  Used as
-- the key of 'compArrayTypes'.
type ArrayType =
  ( Space,
    Signedness,
    PrimType,
    Int
  )

-- | The state threaded through C code generation.  The type parameter
-- @s@ is backend-specific user state.
data CompilerState s = CompilerState
  { -- | Array types seen so far, with how public each must be.
    compArrayTypes :: M.Map ArrayType Publicness,
    compOpaqueTypes :: M.Map String [ValueDesc],
    compEarlyDecls :: DL.DList C.Definition,
    compInit :: [C.Stm],
    compNameSrc :: VNameSource,
    -- | Backend-specific state.
    compUserState :: s,
    -- | Header declarations, grouped by the section they belong to.
    compHeaderDecls :: M.Map HeaderSection (DL.DList C.Definition),
    compLibDecls :: DL.DList C.Definition,
    compCtxFields :: DL.DList (C.Id, C.Type, Maybe C.Exp),
    compProfileItems :: DL.DList C.BlockItem,
    compClearItems :: DL.DList C.BlockItem,
    -- | Memory blocks declared so far, with their spaces.  The default
    -- error compiler unreferences all of these when reporting an error.
    compDeclaredMem :: [(VName, Space)],
    compItems :: DL.DList C.BlockItem
  }

-- | A fresh compiler state using the given name source and user
-- state; every other component starts out empty.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compArrayTypes = mempty,
      compOpaqueTypes = mempty,
      compEarlyDecls = mempty,
      compInit = [],
      compNameSrc = src,
      compUserState = s,
      compHeaderDecls = mempty,
      compLibDecls = mempty,
      compCtxFields = mempty,
      compProfileItems = mempty,
      compClearItems = mempty,
      compDeclaredMem = mempty,
      compItems = mempty
    }

-- | In which part of the header file we put the declaration.  This is
-- to ensure that the header file remains structured and readable.
--
-- The derived 'Ord' follows constructor order, so when a map keyed on
-- 'HeaderSection' is traversed, all 'ArrayDecl' sections come first
-- and 'InitDecl' comes last.
data HeaderSection
  = ArrayDecl String
  | OpaqueDecl String
  | EntryDecl
  | MiscDecl
  | InitDecl
  deriving (Eq, Ord)

-- | A substitute compiler for the extended operations of type @op@,
-- tried before the main compilation function.
type OpCompiler op s =
  op -> CompilerM op s ()

-- | Compiles a failure report.  It receives the error message and a
-- textual stack trace, and emits the code that signals the error.
type ErrorCompiler op s =
  ErrorMsg Exp -> String -> CompilerM op s ()

-- | Computes the address-space qualifiers of a pointer from the given
-- annotation string.
type PointerQuals op s =
  String -> CompilerM op s [C.TypeQual]

-- | Computes the C type of a memory block living in the given memory
-- space.
type MemoryType op s =
  SpaceId -> CompilerM op s C.Type

-- | Stores a scalar into the given memory block, at the given element
-- index, within the given memory space.
type WriteScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  C.Exp ->
  CompilerM op s ()

-- | Loads a scalar from the given memory block, at the given element
-- index, within the given memory space.
type ReadScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  CompilerM op s C.Exp

-- | Allocates, in the given memory space, a memory block with the
-- given size and tag, and saves a reference to it in the given
-- variable.
type Allocate op s =
  C.Exp -> C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | Releases the given memory block, carrying the given tag, from the
-- given memory space.
type Deallocate op s =
  C.Exp ->
  C.Exp ->
  SpaceId ->
  CompilerM op s ()

-- | Declares a static array, named by the given 'VName', of the given
-- primitive type in the given space; it is initialised at load time
-- from the given contents.
type StaticArray op s =
  VName ->
  SpaceId ->
  PrimType ->
  ArrayContents ->
  CompilerM op s ()

-- | Copies between memory blocks.  The arguments are, in order: the
-- destination block, destination offset and destination space; the
-- source block, source offset and source space; and the number of
-- bytes to copy.
type Copy op s =
  C.Exp -> C.Exp -> Space -> C.Exp -> C.Exp -> Space -> C.Exp -> CompilerM op s ()

-- | Compiles a function call, given the variables receiving the
-- results, the name of the function, and the compiled arguments.
type CallCompiler op s =
  [VName] ->
  Name ->
  [C.Exp] ->
  CompilerM op s ()

-- | The pluggable operations a backend provides to the generic C code
-- generator.  'defaultOperations' supplies versions that only support
-- the default memory space.
data Operations op s = Operations
  { opsWriteScalar :: WriteScalar op s,
    opsReadScalar :: ReadScalar op s,
    opsAllocate :: Allocate op s,
    opsDeallocate :: Deallocate op s,
    opsCopy :: Copy op s,
    opsStaticArray :: StaticArray op s,
    opsMemoryType :: MemoryType op s,
    opsCompiler :: OpCompiler op s,
    opsError :: ErrorCompiler op s,
    opsCall :: CallCompiler op s,
    -- | If true, use reference counting.  Otherwise, bare
    -- pointers.
    opsFatMemory :: Bool,
    -- | Code to bracket critical sections.
    opsCritical :: ([C.BlockItem], [C.BlockItem])
  }

-- | The default 'ErrorCompiler'.  The emitted C code stores a formatted
-- message (followed by the given backtrace) in @ctx->error@,
-- unreferences every memory block recorded in 'compDeclaredMem', and
-- returns 1.
defError :: ErrorCompiler op s
defError (ErrorMsg parts) stacktrace = do
  -- Code that unreferences all memory declared so far; it is emitted
  -- after the message is stored and before the return.
  free_all_mem <- collect $ mapM_ (uncurry unRefMem) =<< gets compDeclaredMem
  let -- Each part contributes one printf conversion plus its argument.
      -- User strings are passed as "%s" arguments, never spliced into
      -- the format string itself.
      onPart (ErrorString s) = return ("%s", [C.cexp|$string:s|])
      onPart (ErrorInt32 x) = ("%d",) <$> compileExp x
      -- int64_t need not be 'long long' (it is 'long' on LP64
      -- platforms), so cast explicitly to match the "%lld" conversion.
      onPart (ErrorInt64 x) = do
        e <- compileExp x
        return ("%lld", [C.cexp|(long long)($exp:e)|])
  (formatstrs, formatargs) <- unzip <$> mapM onPart parts
  let formatstr :: String
      formatstr = "Error: " ++ concat formatstrs ++ "\n\nBacktrace:\n%s"
  items
    [C.citems|ctx->error = msgprintf($string:formatstr, $args:formatargs, $string:stacktrace);
                  $items:free_all_mem
                  return 1;|]

-- | The default 'CallCompiler'.
--
-- A builtin function with exactly one destination is called directly,
-- and its result is assigned to that destination.  Any other call
-- passes @ctx@, pointers to the destinations and then the arguments.
-- A nonzero return code sets @err@ and jumps to @cleanup@.
--
-- NOTE(review): a builtin with zero or several destinations falls
-- through to the error-checking branch and is called with the bare
-- arguments; presumably builtins always have exactly one result.
defCall :: CallCompiler op s
defCall dests fname args = do
  let out_args = [[C.cexp|&$id:d|] | d <- dests]
      args'
        | isBuiltInFunction fname = args
        | otherwise = [C.cexp|ctx|] : out_args ++ args
  case dests of
    [dest]
      | isBuiltInFunction fname ->
        stm [C.cstm|$id:dest = $id:(funName fname)($args:args');|]
    _ ->
      item [C.citem|if ($id:(funName fname)($args:args') != 0) { err = 1; goto cleanup; }|]

-- | A set of operations that fail for every operation involving
-- non-default memory spaces, and for every extended operation.
-- Copies within the default space use 'copyMemoryDefaultSpace'.
-- Memory is reference-counted ('opsFatMemory' is 'True'), errors are
-- reported by 'defError', calls are compiled by 'defCall', and
-- critical sections are not bracketed by any code.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsDeallocate = defDeallocate,
      opsCopy = defCopy,
      opsStaticArray = defStaticArray,
      opsMemoryType = defMemoryType,
      opsCompiler = defCompiler,
      opsFatMemory = True,
      opsError = defError,
      opsCall = defCall,
      opsCritical = mempty
    }
  where
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space because I am dumb"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defDeallocate _ _ =
      error "Cannot deallocate in non-default memory space"
    -- Only copies where both ends are in the default space are supported.
    defCopy destmem destoffset DefaultSpace srcmem srcoffset DefaultSpace size =
      copyMemoryDefaultSpace destmem destoffset srcmem srcoffset size
    defCopy _ _ _ _ _ _ _ =
      error "Cannot copy to or from non-default memory space"
    defStaticArray _ _ _ _ =
      error "Cannot create static array in non-default memory space"
    defMemoryType _ =
      error "Has no type for non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"

-- | The read-only environment of the C code generator.
data CompilerEnv op s = CompilerEnv
  { -- | The backend's pluggable operations.
    envOperations :: Operations op s,
    -- | Maps memory blocks (as C expressions) to the names of the
    -- variables recording their sizes.  These memory blocks are CPU
    -- memory that we know are used in particularly simple ways (no
    -- reference counting necessary).  To cut down on allocator
    -- pressure, we keep these allocations around for a long time, and
    -- record their sizes so we can reuse them if possible (and
    -- realloc() when needed).
    envCachedMem :: M.Map C.Exp VName
  }

-- | The 'OpCompiler' from the environment's 'Operations'.
envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler = opsCompiler . envOperations

-- | The 'MemoryType' operation from the environment's 'Operations'.
envMemoryType :: CompilerEnv op s -> MemoryType op s
envMemoryType = opsMemoryType . envOperations

-- | The 'ReadScalar' operation from the environment's 'Operations'.
envReadScalar :: CompilerEnv op s -> ReadScalar op s
envReadScalar = opsReadScalar . envOperations

-- | The 'WriteScalar' operation from the environment's 'Operations'.
envWriteScalar :: CompilerEnv op s -> WriteScalar op s
envWriteScalar = opsWriteScalar . envOperations

-- | The 'Allocate' operation from the environment's 'Operations'.
envAllocate :: CompilerEnv op s -> Allocate op s
envAllocate = opsAllocate . envOperations

-- | The 'Deallocate' operation from the environment's 'Operations'.
envDeallocate :: CompilerEnv op s -> Deallocate op s
envDeallocate = opsDeallocate . envOperations

-- | The 'Copy' operation from the environment's 'Operations'.
envCopy :: CompilerEnv op s -> Copy op s
envCopy = opsCopy . envOperations

-- | The 'StaticArray' operation from the environment's 'Operations'.
envStaticArray :: CompilerEnv op s -> StaticArray op s
envStaticArray = opsStaticArray . envOperations

-- | Whether the environment's 'Operations' use reference-counted
-- ("fat") memory.
envFatMemory :: CompilerEnv op s -> Bool
envFatMemory = opsFatMemory . envOperations

-- | The header declarations belonging to a particular kind of
-- 'HeaderSection', concatenated in ascending section order (the order
-- of 'M.toList').
initDecls, arrayDecls, opaqueDecls, entryDecls, miscDecls :: CompilerState s -> [C.Definition]
-- Declarations in the 'InitDecl' section.
initDecls =
  concatMap (DL.toList . snd)
    . filter ((== InitDecl) . fst)
    . M.toList
    . compHeaderDecls
-- Declarations in every 'ArrayDecl' section, whatever its name.
arrayDecls =
  concatMap (DL.toList . snd)
    . filter (isArrayDecl . fst)
    . M.toList
    . compHeaderDecls
  where
    isArrayDecl ArrayDecl {} = True
    isArrayDecl _ = False
-- Declarations in every 'OpaqueDecl' section, whatever its name.
opaqueDecls =
  concatMap (DL.toList . snd)
    . filter (isOpaqueDecl . fst)
    . M.toList
    . compHeaderDecls
  where
    isOpaqueDecl OpaqueDecl {} = True
    isOpaqueDecl _ = False
-- Declarations in the 'EntryDecl' section.
entryDecls =
  concatMap (DL.toList . snd)
    . filter ((== EntryDecl) . fst)
    . M.toList
    . compHeaderDecls
-- | Every definition recorded in the 'MiscDecl' header section, in the
-- order in which the section's definitions were accumulated.
miscDecls :: CompilerState s -> [C.Definition]
miscDecls st =
  [ d
    | (sec, defs) <- M.toList (compHeaderDecls st),
      sec == MiscDecl,
      d <- DL.toList defs
  ]

-- | The C struct field declarations for every field registered with
-- 'contextField', together with one @ctx->name = e;@ statement for
-- each field that was registered with an initial value.
contextContents :: CompilerM op s ([C.FieldGroup], [C.Stm])
contextContents = do
  ctx_fields <- gets (DL.toList . compCtxFields)
  let decls =
        [ [C.csdecl|$ty:ty $id:name;|]
          | (name, ty, _) <- ctx_fields
        ]
      -- Only fields with a 'Just' initial value get an initialiser.
      inits =
        [ [C.cstm|ctx->$id:name = $exp:e;|]
          | (name, _, Just e) <- ctx_fields
        ]
  pure (decls, inits)

-- | The statements registered so far with 'atInit'.
contextFinalInits :: CompilerM op s [C.Stm]
contextFinalInits = compInit <$> get

-- | The code generation monad: a read-only 'CompilerEnv' layered over a
-- mutable 'CompilerState'.  @op@ and @s@ are the type parameters of the
-- 'Operations' in use (@s@ being the user state type).
newtype CompilerM op s a
  = CompilerM (ReaderT (CompilerEnv op s) (State (CompilerState s)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (CompilerEnv op s),
      MonadState (CompilerState s)
    )

-- | Fresh names are read from, and written back to, the 'compNameSrc'
-- field of the compiler state.
instance MonadFreshNames (CompilerM op s) where
  getNameSource = compNameSrc <$> get
  putNameSource new_src = modify $ \st -> st {compNameSrc = new_src}

-- | Run a compiler action with the given operations, name source and
-- initial user state.  The environment starts with an empty
-- cached-memory map.  Returns the action's result together with the
-- final compiler state.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  (a, CompilerState s)
runCompilerM ops src userstate (CompilerM m) =
  runState (runReaderT m env) initial_state
  where
    env = CompilerEnv ops mempty
    initial_state = newCompilerState src userstate

-- | Read the backend-specific user state.
getUserState :: CompilerM op s s
getUserState = compUserState <$> get

-- | Apply a function to the backend-specific user state.
modifyUserState :: (s -> s) -> CompilerM op s ()
modifyUserState f = modify update
  where
    update st = st {compUserState = f (compUserState st)}

-- | Register a statement to be returned by 'contextFinalInits'.
-- Statements are kept in registration order.
atInit :: C.Stm -> CompilerM op s ()
atInit init_stm = modify $ \st -> st {compInit = compInit st <> [init_stm]}

-- | Run an action and return the block items it emitted.  The items are
-- not added to the surrounding item list (see 'collect'').
collect :: CompilerM op s () -> CompilerM op s [C.BlockItem]
collect = fmap snd . collect'

-- | Run an action starting from an empty item list.  Returns the
-- action's result together with the block items it emitted.  The item
-- list that was current before the call is restored afterwards, so the
-- emitted items are not added to it.
collect' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
collect' m = do
  -- Stash the current items and start from an empty list.
  saved <- state $ \st -> (compItems st, st {compItems = mempty})
  x <- m
  -- Grab what the action emitted and put the stashed items back.
  emitted <- state $ \st -> (compItems st, st {compItems = saved})
  pure (x, DL.toList emitted)

-- | Used when we, inside an existing 'CompilerM' action, want to
-- generate code for a new function.  Use this so that the compiler
-- understands that previously declared memory doesn't need to be
-- freed inside this action.
--
-- The declared-memory list is emptied while the action runs and is
-- restored afterwards.  Unless the 'Bool' argument is 'True', the
-- action also runs with an empty cached-memory map.
inNewFunction :: Bool -> CompilerM op s a -> CompilerM op s a
inNewFunction keep_cached m = do
  outer_mem <- state $ \st -> (compDeclaredMem st, st {compDeclaredMem = mempty})
  res <- local adjustEnv m
  modify $ \st -> st {compDeclaredMem = outer_mem}
  pure res
  where
    adjustEnv env
      | keep_cached = env
      | otherwise = env {envCachedMem = mempty}

-- | Append a single block item to the end of the current item list.
item :: C.BlockItem -> CompilerM op s ()
item x = modify addItem
  where
    addItem st = st {compItems = DL.snoc (compItems st) x}

-- | Append a list of block items, in order, to the end of the current
-- item list.
items :: [C.BlockItem] -> CompilerM op s ()
items xs = modify $ \st -> st {compItems = compItems st <> DL.fromList xs}

-- | Whether memory in the given space uses the fat representation.
-- Always 'False' for 'ScalarSpace'; otherwise decided by
-- 'envFatMemory'.
fatMemory :: Space -> CompilerM op s Bool
fatMemory space = case space of
  ScalarSpace {} -> pure False
  _ -> asks envFatMemory

-- | Look up the given value, converted to a C expression, in the
-- environment's cached-memory map and return the associated name, if
-- any.
cacheMem :: C.ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem a = M.lookup (C.toExp a noLoc) <$> asks envCachedMem

-- | Construct a publicly visible definition using the specified name
-- as the template.  The first returned definition is put in the
-- header file, and the second is the implementation.  Returns the public
-- name.
--
-- The template name is passed through 'publicName'.  The header
-- definition is recorded with 'headerDecl' in the given section, and
-- the implementation with 'earlyDecl'.
publicDef ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s String
publicDef s h f = do
  public_name <- publicName s
  let (header_def, impl_def) = f public_name
  headerDecl h header_def
  earlyDecl impl_def
  pure public_name

-- | As 'publicDef', but ignores the public name.
publicDef_ ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s ()
publicDef_ s h f = () <$ publicDef s h f

-- | Append a definition to the given header section.  Within a section,
-- definitions keep the order in which they were added.
headerDecl :: HeaderSection -> C.Definition -> CompilerM op s ()
headerDecl sec def = modify $ \st ->
  st
    { -- insertWith passes (new, old); flip so the result is old <> new.
      compHeaderDecls =
        M.insertWith (flip (<>)) sec (DL.singleton def) (compHeaderDecls st)
    }

-- | Append a definition to the end of the library declarations.
libDecl :: C.Definition -> CompilerM op s ()
libDecl def = modify $ \st -> st {compLibDecls = DL.snoc (compLibDecls st) def}

-- | Append a definition to the end of the early declarations.
earlyDecl :: C.Definition -> CompilerM op s ()
earlyDecl def = modify $ \st -> st {compEarlyDecls = DL.snoc (compEarlyDecls st) def}

-- | Register a field of the generated context struct: its name, its C
-- type, and an optional initial-value expression.  Fields are recorded
-- in 'compCtxFields' in the order they are registered.
contextField :: C.Id -> C.Type -> Maybe C.Exp -> CompilerM op s ()
contextField name ty initial =
  modify $ \st ->
    let field = (name, ty, initial)
     in st {compCtxFields = DL.snoc (compCtxFields st) field}

-- | Append a block item to the profiling-report items
-- ('compProfileItems').
profileReport :: C.BlockItem -> CompilerM op s ()
profileReport x =
  modify $ \st ->
    st {compProfileItems = DL.snoc (compProfileItems st) x}

-- | Append a block item to the items that are run when the context is
-- cleared ('compClearItems').
onClear :: C.BlockItem -> CompilerM op s ()
onClear x =
  modify $ \st ->
    st {compClearItems = DL.snoc (compClearItems st) x}

-- | Emit a C statement by wrapping it as a block item and passing it to
-- 'item'.
stm :: C.Stm -> CompilerM op s ()
stm x = item stmItem
  where
    stmItem = [C.citem|$stm:x|]

-- | Emit each statement in list order, as if by repeated calls to 'stm'.
stms :: [C.Stm] -> CompilerM op s ()
stms ss = forM_ ss stm

-- | Emit a C declaration (an init group) as a block item via 'item'.
decl :: C.InitGroup -> CompilerM op s ()
decl x = item declItem
  where
    declItem = [C.citem|$decl:x;|]

-- | Public names must have a consistent prefix; this prepends
-- @futhark_@ to the given name.
publicName :: String -> CompilerM op s String
publicName = pure . ("futhark_" ++)

-- | The generated code must define a struct with this name: the
-- 'publicName' of @context@, used as a @struct@ type.
contextType :: CompilerM op s C.Type
contextType = asStruct <$> publicName "context"
  where
    asStruct name = [C.cty|struct $id:name|]

-- | The C type used to declare the memory variable @v@ in @space@.
--
-- It is the fat (reference-counted) block struct when fat memory applies
-- to the space and @v@ is not a cached allocation.  Otherwise it is the
-- raw memory type of the space.
memToCType :: VName -> Space -> CompilerM op s C.Type
memToCType v space = do
  fat <- fatMemory space
  hit <- cacheMem v
  case hit of
    Nothing | fat -> pure $ fatMemType space
    _ -> rawMemCType space

-- | The C type of raw (non-reference-counted) memory in a space.
--
-- * 'DefaultSpace' uses 'defaultMemBlockType'.
-- * Named spaces defer to the backend's 'envMemoryType' hook.
-- * Scalar spaces become a fixed-size C array of the element type whose
--   length is the product of the dimensions, or 1 when there are no
--   dimensions.
rawMemCType :: Space -> CompilerM op s C.Type
rawMemCType space =
  case space of
    DefaultSpace -> pure defaultMemBlockType
    Space sid -> do
      memType <- asks envMemoryType
      memType sid
    ScalarSpace [] t ->
      pure [C.cty|$ty:(primTypeToCType t)[1]|]
    ScalarSpace dims t ->
      let dimExps = [C.toExp d noLoc | d <- dims]
       in pure [C.cty|$ty:(primTypeToCType t)[$exp:(cproduct dimExps)]|]

-- | The C struct type of a fat (reference-counted) memory block:
-- @struct memblock_<sid>@ for a named space, @struct memblock@ otherwise.
fatMemType :: Space -> C.Type
fatMemType space = [C.cty|struct $id:(blockStructName space)|]
  where
    blockStructName :: Space -> String
    blockStructName (Space sid) = "memblock_" ++ sid
    blockStructName _ = "memblock"

-- | Name of the generated C function that assigns one fat memory block
-- to another in the given space.
fatMemSet :: Space -> String
fatMemSet space =
  case space of
    Space sid -> "memblock_set_" <> sid
    _ -> "memblock_set"

-- | Name of the generated C function that allocates a fat memory block
-- in the given space.
fatMemAlloc :: Space -> String
fatMemAlloc space =
  case space of
    Space sid -> "memblock_alloc_" <> sid
    _ -> "memblock_alloc"

-- | Name of the generated C function that drops one reference to a fat
-- memory block in the given space.
fatMemUnRef :: Space -> String
fatMemUnRef space =
  case space of
    Space sid -> "memblock_unref_" <> sid
    _ -> "memblock_unref"

-- | The C expression that denotes the raw memory behind @v@.
--
-- This is @v.mem@ when fat memory is enabled ('envFatMemory') and @v@ is
-- not a cached allocation; otherwise it is @v@ itself.
rawMem :: VName -> CompilerM op s C.Exp
rawMem v = do
  fatEnabled <- asks envFatMemory
  notCached <- isNothing <$> cacheMem v
  pure $ rawMem' (fatEnabled && notCached) v

-- | Select the @mem@ field of a fat block when the flag is set;
-- otherwise return the expression unchanged.
rawMem' :: C.ToExp a => Bool -> a -> C.Exp
rawMem' fat e
  | fat = [C.cexp|$exp:e.mem|]
  | otherwise = [C.cexp|$exp:e|]

-- | Allocate @size@ bytes of raw memory into @dest@.
--
-- Named spaces delegate to the backend's 'envAllocate' hook, which also
-- receives the description @desc@ and the space name.  Every other space
-- assigns the result of @malloc@ and ignores @desc@.
allocRawMem ::
  (C.ToExp a, C.ToExp b, C.ToExp c) =>
  a ->
  b ->
  Space ->
  c ->
  CompilerM op s ()
allocRawMem dest size space desc
  | Space sid <- space = do
      allocate <- asks envAllocate
      allocate [C.cexp|$exp:dest|] [C.cexp|$exp:size|] [C.cexp|$exp:desc|] sid
  | otherwise =
      stm [C.cstm|$exp:dest = (char*) malloc($exp:size);|]

-- | Release the raw memory @mem@.
--
-- Named spaces delegate to the backend's 'envDeallocate' hook, passing
-- the description @desc@ and the space name.  Every other space emits a
-- plain @free@.
freeRawMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  Space ->
  b ->
  CompilerM op s ()
freeRawMem mem space desc =
  case space of
    Space sid ->
      asks envDeallocate >>= \deallocate ->
        deallocate [C.cexp|$exp:mem|] [C.cexp|$exp:desc|] sid
    _ -> item [C.citem|free($exp:mem);|]

-- | Generate the C support code for fat (reference-counted) memory blocks
-- in the given space.
--
-- Returns three things:
--
-- * the memory block struct definition;
-- * the unreference, allocation and assignment functions, in that order;
-- * a block item that reports peak memory usage.  For 'DefaultSpace' this
--   item is empty.
--
-- It also has side effects.  It registers two @int64_t@ context fields,
-- peak usage and current usage, both initialised to 0.  It also
-- registers a clear-time item that resets the peak counter.
defineMemorySpace :: Space -> CompilerM op s (C.Definition, [C.Definition], C.BlockItem)
defineMemorySpace space = do
  rm <- rawMemCType space
  let structdef =
        [C.cedecl|struct $id:sname { int *references;
                                     $ty:rm mem;
                                     typename int64_t size;
                                     const char *desc; };|]

  contextField peakname [C.cty|typename int64_t|] $ Just [C.cexp|0|]
  contextField usagename [C.cty|typename int64_t|] $ Just [C.cexp|0|]

  -- Unreferencing a memory block consists of decreasing its reference
  -- count and freeing the corresponding memory if the count reaches
  -- zero.
  free <- collect $ freeRawMem [C.cexp|block->mem|] space [C.cexp|desc|]
  ctx_ty <- contextType
  let unrefdef =
        [C.cedecl|static int $id:(fatMemUnRef space) ($ty:ctx_ty *ctx, $ty:mty *block, const char *desc) {
  if (block->references != NULL) {
    *(block->references) -= 1;
    if (ctx->detail_memory) {
      fprintf(ctx->log, "Unreferencing block %s (allocated as %s) in %s: %d references remaining.\n",
                      desc, block->desc, $string:spacedesc, *(block->references));
    }
    if (*(block->references) == 0) {
      ctx->$id:usagename -= block->size;
      $items:free
      free(block->references);
      if (ctx->detail_memory) {
        fprintf(ctx->log, "%lld bytes freed (now allocated: %lld bytes)\n",
                (long long) block->size, (long long) ctx->$id:usagename);
      }
    }
    block->references = NULL;
  }
  return 0;
}|]

  -- When allocating a memory block we initialise the reference count to 1.
  alloc <-
    collect $
      allocRawMem [C.cexp|block->mem|] [C.cexp|size|] space [C.cexp|desc|]
  -- The negative-size panic message has exactly three conversions
  -- (%lld, %s, %s), so exactly three arguments follow the format string.
  -- A fourth, unused usage-counter argument used to be passed as well.
  let allocdef =
        [C.cedecl|static int $id:(fatMemAlloc space) ($ty:ctx_ty *ctx, $ty:mty *block, typename int64_t size, const char *desc) {
  if (size < 0) {
    futhark_panic(1, "Negative allocation of %lld bytes attempted for %s in %s.\n",
          (long long)size, desc, $string:spacedesc);
  }
  int ret = $id:(fatMemUnRef space)(ctx, block, desc);

  ctx->$id:usagename += size;
  if (ctx->detail_memory) {
    fprintf(ctx->log, "Allocating %lld bytes for %s in %s (then allocated: %lld bytes)",
            (long long) size,
            desc, $string:spacedesc,
            (long long) ctx->$id:usagename);
  }
  if (ctx->$id:usagename > ctx->$id:peakname) {
    ctx->$id:peakname = ctx->$id:usagename;
    if (ctx->detail_memory) {
      fprintf(ctx->log, " (new peak).\n");
    }
  } else if (ctx->detail_memory) {
    fprintf(ctx->log, ".\n");
  }

  $items:alloc
  block->references = (int*) malloc(sizeof(int));
  *(block->references) = 1;
  block->size = size;
  block->desc = desc;
  return ret;
  }|]

  -- Memory setting - unreference the destination and increase the
  -- count of the source by one.
  let setdef =
        [C.cedecl|static int $id:(fatMemSet space) ($ty:ctx_ty *ctx, $ty:mty *lhs, $ty:mty *rhs, const char *lhs_desc) {
  int ret = $id:(fatMemUnRef space)(ctx, lhs, lhs_desc);
  if (rhs->references != NULL) {
    (*(rhs->references))++;
  }
  *lhs = *rhs;
  return ret;
}
|]

  onClear [C.citem|ctx->$id:peakname = 0;|]

  let peakmsg = "Peak memory usage for " ++ spacedesc ++ ": %lld bytes.\n"
  return
    ( structdef,
      [unrefdef, allocdef, setdef],
      -- Do not report memory usage for DefaultSpace (CPU memory),
      -- because it would not be accurate anyway.  This whole
      -- tracking probably needs to be rethought.
      if space == DefaultSpace
        then [C.citem|{}|]
        else [C.citem|str_builder(&builder, $string:peakmsg, (long long) ctx->$id:peakname);|]
    )
  where
    mty = fatMemType space
    -- Context-field identifiers, block struct name, and a human-readable
    -- space description used in the generated log and panic messages.
    (peakname, usagename, sname, spacedesc) = case space of
      Space sid ->
        ( C.toIdent ("peak_mem_usage_" ++ sid) noLoc,
          C.toIdent ("cur_mem_usage_" ++ sid) noLoc,
          C.toIdent ("memblock_" ++ sid) noLoc,
          "space '" ++ sid ++ "'"
        )
      _ ->
        ( "peak_mem_usage_default",
          "cur_mem_usage_default",
          "memblock",
          "default space"
        )

-- | Declare the memory variable @name@ in @space@, unless it is a cached
-- allocation, in which case nothing is emitted.
--
-- A fresh declaration is emitted with its 'memToCType' type.  The
-- variable is then reset with 'resetMem' and recorded in
-- 'compDeclaredMem'.
declMem :: VName -> Space -> CompilerM op s ()
declMem name space = cacheMem name >>= maybe declareFresh (const (pure ()))
  where
    declareFresh = do
      ty <- memToCType name space
      decl [C.cdecl|$ty:ty $id:name;|]
      resetMem name space
      modify $ \st -> st {compDeclaredMem = (name, space) : compDeclaredMem st}

-- | Reset a memory variable to its empty state.
--
-- * A cached allocation is set to @NULL@.
-- * A fat block has its @references@ pointer cleared.
-- * Anything else is left untouched.
resetMem :: C.ToExp a => a -> Space -> CompilerM op s ()
resetMem mem space = do
  fat <- fatMemory space
  hit <- cacheMem mem
  case hit of
    Just _ -> stm [C.cstm|$exp:mem = NULL;|]
    Nothing
      | fat -> stm [C.cstm|$exp:mem.references = NULL;|]
      | otherwise -> pure ()

-- | Make @dest@ refer to the same memory as @src@.
--
-- * Fat blocks call the generated @memblock_set@ function and return 1
--   from the enclosing C function if it fails.
-- * Scalar spaces are copied element by element in a loop.
-- * Anything else is a plain C assignment.
setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> CompilerM op s ()
setMem dest src space = do
  fat <- fatMemory space
  if fat then setFat else setRaw
  where
    srcDesc = pretty $ C.toExp src noLoc

    setFat =
      stm
        [C.cstm|if ($id:(fatMemSet space)(ctx, &$exp:dest, &$exp:src,
                                               $string:srcDesc) != 0) {
                       return 1;
                     }|]

    setRaw = case space of
      ScalarSpace dims _ -> do
        ivar <- newVName "i"
        let idx = C.toIdent ivar
            idxTy = primTypeToCType $ IntType Int32
            bound = cproduct [C.toExp d noLoc | d <- dims]
        stm
          [C.cstm|for ($ty:idxTy $id:idx = 0; $id:idx < $exp:bound; $id:idx++) {
                            $exp:dest[$id:idx] = $exp:src[$id:idx];
                  }|]
      _ -> stm [C.cstm|$exp:dest = $exp:src;|]

-- | Drop one reference to the fat memory block @mem@.
--
-- The generated code returns 1 from the enclosing C function if the
-- unref call fails.  Nothing is emitted for non-fat spaces or cached
-- allocations.
unRefMem :: C.ToExp a => a -> Space -> CompilerM op s ()
unRefMem mem space = do
  fat <- fatMemory space
  notCached <- isNothing <$> cacheMem mem
  let memDesc = pretty $ C.toExp mem noLoc
      unrefStm =
        [C.cstm|if ($id:(fatMemUnRef space)(ctx, &$exp:mem, $string:memDesc) != 0) {
                  return 1;
                }|]
  when (fat && notCached) $ stm unrefStm

-- | Allocate @size@ bytes into the memory variable @mem@ in @space@.
--
-- With fat (reference-counted) memory, this calls the generated
-- @memblock_alloc@ function for the space and runs @on_failure@ if that
-- call returns non-zero.
--
-- Otherwise the old raw memory is freed and fresh raw memory is
-- allocated directly.  @on_failure@ is not used on this path.
allocMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  b ->
  Space ->
  C.Stm ->
  CompilerM op s ()
allocMem mem size space on_failure = do
  refcount <- fatMemory space
  let mem_s = pretty $ C.toExp mem noLoc
  if refcount
    then
      stm
        [C.cstm|if ($id:(fatMemAlloc space)(ctx, &$exp:mem, $exp:size,
                                                 $string:mem_s)) {
                       $stm:on_failure
                     }|]
    else do
      freeRawMem mem space mem_s
      -- Describe the allocation by the variable's own name, as the free
      -- above does.  This used to pass the bare C identifier @desc@, but
      -- nothing in the code emitted here declares it; it is only a
      -- parameter of the generated memblock functions.
      allocRawMem mem size space mem_s

-- | Copy @nbytes@ bytes from @srcmem + srcidx@ to @destmem + destidx@
-- with @memmove@, so overlapping regions are handled.
copyMemoryDefaultSpace ::
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  CompilerM op s ()
copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
  let dst = [C.cexp|$exp:destmem + $exp:destidx|]
      src = [C.cexp|$exp:srcmem + $exp:srcidx|]
   in stm [C.cstm|memmove($exp:dst, $exp:src, $exp:nbytes);|]

--- Entry points.

-- | Wrap block items in the context lock.
--
-- The backend's 'opsCritical' entry items run right after the lock is
-- taken, and its exit items run right before the lock is released.
criticalSection :: Operations op s -> [C.BlockItem] -> [C.BlockItem]
criticalSection ops x =
  [C.citems|lock_lock(&ctx->lock);
            $items:enterItems
            $items:x
            $items:exitItems
            lock_unlock(&ctx->lock);
           |]
  where
    (enterItems, exitItems) = opsCritical ops

-- | Generate the C library functions for one array type: @new_@,
-- @new_raw_@, @free_@, @values_@, @values_raw_@ and @shape_@.
--
-- Prototypes are declared with 'headerDecl' for 'Public' array types and
-- with 'libDecl' for 'Private' ones; the function definitions are
-- returned to the caller.
arrayLibraryFunctions ::
  Publicness ->
  Space ->
  PrimType ->
  Signedness ->
  Int ->
  CompilerM op s [C.Definition]
arrayLibraryFunctions pub space pt signed rank = do
  let pt' = signedPrimTypeToCType signed pt
      name = arrayName pt signed rank
      arr_name = "futhark_" ++ name
      array_type = [C.cty|struct $id:arr_name|]

  new_array <- publicName $ "new_" ++ name
  new_raw_array <- publicName $ "new_raw_" ++ name
  free_array <- publicName $ "free_" ++ name
  values_array <- publicName $ "values_" ++ name
  values_raw_array <- publicName $ "values_raw_" ++ name
  shape_array <- publicName $ "shape_" ++ name

  let shape_names = ["dim" ++ show i | i <- [0 .. rank - 1]]
      shape_params = [[C.cparam|typename int64_t $id:k|] | k <- shape_names]
      arr_size = cproduct [[C.cexp|$id:k|] | k <- shape_names]
      arr_size_array = cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [0 .. rank - 1]]
      elem_size = primByteSize pt :: Int
      -- Byte size of an array described by the shape parameters, and
      -- byte size of an existing array (read from arr->shape).
      new_bytes = [C.cexp|((size_t)$exp:arr_size) * $int:elem_size|]
      arr_bytes = [C.cexp|((size_t)$exp:arr_size_array) * $int:elem_size|]
  copy <- asks envCopy

  memty <- rawMemCType space

  let -- Body of new_/new_raw_: allocate arr->mem, record the shape, and
      -- (only if the allocation succeeded) copy the initial contents from
      -- 'data' at the given offset in the given space.
      --
      -- This code is spliced inside 'criticalSection', so an allocation
      -- failure must not 'return' from the C function: that would skip
      -- the unlock (and the backend's critical-section epilogue) and
      -- leak 'arr'.  Instead the failure sets the C variable 'err', which
      -- the generated function checks after the lock has been released.
      newBody src_offset src_space = collect $ do
        resetMem [C.cexp|arr->mem|] space
        allocMem [C.cexp|arr->mem|] new_bytes space [C.cstm|err = 1;|]
        forM_ (zip [0 :: Int ..] shape_names) $ \(i, dim_s) ->
          stm [C.cstm|arr->shape[$int:i] = $id:dim_s;|]
        fill <-
          collect $
            copy
              [C.cexp|arr->mem.mem|]
              [C.cexp|0|]
              space
              [C.cexp|data|]
              src_offset
              src_space
              new_bytes
        stm [C.cstm|if (err == 0) { $items:fill }|]

  new_body <- newBody [C.cexp|0|] DefaultSpace
  new_raw_body <- newBody [C.cexp|offset|] space

  free_body <- collect $ unRefMem [C.cexp|arr->mem|] space

  values_body <-
    collect $
      copy
        [C.cexp|data|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        arr_bytes

  ctx_ty <- contextType
  ops <- asks envOperations

  let proto = case pub of
        Public -> headerDecl (ArrayDecl name)
        Private -> libDecl

  proto
    [C.cedecl|struct $id:arr_name;|]
  proto
    [C.cedecl|$ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params);|]
  proto
    [C.cedecl|$ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, int offset, $params:shape_params);|]
  proto
    [C.cedecl|int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data);|]
  proto
    [C.cedecl|$ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]

  return
    [C.cunit|
          $ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_body)
            if (err != 0) {
              free(arr);
              return bad;
            }
            return arr;
          }

          $ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, int offset,
                                            $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_raw_body)
            if (err != 0) {
              free(arr);
              return bad;
            }
            return arr;
          }

          int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            $items:(criticalSection ops free_body)
            free(arr);
            return 0;
          }

          int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data) {
            $items:(criticalSection ops values_body)
            return 0;
          }

          $ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->mem.mem;
          }

          const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->shape;
          }
          |]

-- | Generate the C library functions for one opaque type: @free_@,
-- @store_@ (serialise to a byte buffer) and @restore_@ (deserialise from
-- a byte buffer).  Their prototypes are emitted with 'headerDecl' under
-- the opaque's 'OpaqueDecl' section; the definitions are returned.
--
-- Array components are handled by calling the corresponding array API
-- functions; scalar components are copied with @memcpy@.
opaqueLibraryFunctions ::
  String ->
  [ValueDesc] ->
  CompilerM op s [C.Definition]
opaqueLibraryFunctions desc vds = do
  name <- publicName $ opaqueName desc vds
  free_opaque <- publicName $ "free_" ++ opaqueName desc vds
  store_opaque <- publicName $ "store_" ++ opaqueName desc vds
  restore_opaque <- publicName $ "restore_" ++ opaqueName desc vds

  let opaque_type = [C.cty|struct $id:name|]

      -- Emit code freeing component 'i'; scalars need nothing.  The
      -- generated code accumulates a nonzero status into C's 'ret'.
      freeComponent _ ScalarValue {} =
        return ()
      freeComponent i (ArrayValue _ _ pt signed shape) = do
        let rank = length shape
            field = tupleField i
        free_array <- publicName $ "free_" ++ arrayName pt signed rank
        -- Protect against NULL here, because we also want to use this
        -- to free partially loaded opaques.
        stm
          [C.cstm|if (obj->$id:field != NULL && (tmp = $id:free_array(ctx, obj->$id:field)) != 0) {
                ret = tmp;
             }|]

      -- For component 'i': (storage size in bytes, statements writing the
      -- component to 'out' and advancing 'out').
      storeComponent i (ScalarValue pt sign _) =
        let field = tupleField i
         in ( storageSize pt 0 [C.cexp|NULL|],
              storeValueHeader sign pt 0 [C.cexp|NULL|] [C.cexp|out|]
                ++ [C.cstms|memcpy(out, &obj->$id:field, sizeof(obj->$id:field));
                            out += sizeof(obj->$id:field);|]
            )
      storeComponent i (ArrayValue _ _ pt sign shape) =
        let rank = length shape
            arr_name = arrayName pt sign rank
            field = tupleField i
            shape_array = "futhark_shape_" ++ arr_name
            values_array = "futhark_values_" ++ arr_name
            shape' = [C.cexp|$id:shape_array(ctx, obj->$id:field)|]
            num_elems = cproduct [[C.cexp|$exp:shape'[$int:j]|] | j <- [0 .. rank - 1]]
         in ( storageSize pt rank shape',
              storeValueHeader sign pt rank shape' [C.cexp|out|]
                ++ [C.cstms|ret |= $id:values_array(ctx, obj->$id:field, (void*)out);
                            out += $exp:num_elems * $int:(primByteSize pt::Int);|]
            )

  ctx_ty <- contextType

  free_body <- collect $ zipWithM_ freeComponent [0 ..] vds

  store_body <- collect $ do
    let (sizes, stores) = unzip $ zipWith storeComponent [0 ..] vds
        size_vars = map (("size_" ++) . show) [0 .. length sizes - 1]
        size_sum = csum [[C.cexp|$id:size|] | size <- size_vars]
    forM_ (zip size_vars sizes) $ \(v, e) ->
      item [C.citem|typename int64_t $id:v = $exp:e;|]
    stm [C.cstm|*n = $exp:size_sum;|]
    -- malloc may fail; without this check the stores below would write
    -- through a NULL 'out'.  A NULL result for a zero-byte request is
    -- harmless, because then nothing is written.
    stm
      [C.cstm|if (p != NULL && *p == NULL) {
                *p = malloc(*n);
                if (*p == NULL && *n != 0) {
                  return 1;
                }
              }|]
    stm [C.cstm|if (p != NULL) { unsigned char *out = *p; $stms:(concat stores) }|]

  -- Each component's header is read (advancing 'src') unconditionally;
  -- the returned statements, which actually build the component, run
  -- only if no error has been flagged in C's 'err'.
  let restoreComponent i (ScalarValue pt sign _) = do
        let field = tupleField i
            dataptr = "data_" ++ show i
        stms $ loadValueHeader sign pt 0 [C.cexp|NULL|] [C.cexp|src|]
        item [C.citem|const void* $id:dataptr = src;|]
        stm [C.cstm|src += sizeof(obj->$id:field);|]
        pure [C.cstms|memcpy(&obj->$id:field, $id:dataptr, sizeof(obj->$id:field));|]
      restoreComponent i (ArrayValue _ _ pt sign shape) = do
        let field = tupleField i
            rank = length shape
            arr_name = arrayName pt sign rank
            new_array = "futhark_new_" ++ arr_name
            dataptr = "data_" ++ show i
            shapearr = "shape_" ++ show i
            dims = [[C.cexp|$id:shapearr[$int:j]|] | j <- [0 .. rank - 1]]
            num_elems = cproduct dims
        item [C.citem|typename int64_t $id:shapearr[$int:rank];|]
        stms $ loadValueHeader sign pt rank [C.cexp|$id:shapearr|] [C.cexp|src|]
        item [C.citem|const void* $id:dataptr = src;|]
        -- NULL until constructed, so the cleanup path can skip it.
        stm [C.cstm|obj->$id:field = NULL;|]
        stm [C.cstm|src += $exp:num_elems * $int:(primByteSize pt::Int);|]
        pure
          [C.cstms|
             obj->$id:field = $id:new_array(ctx, $id:dataptr, $args:dims);
             if (obj->$id:field == NULL) { err = 1; }|]

  load_body <- collect $ do
    loads <- concat <$> zipWithM restoreComponent [0 ..] vds
    stm
      [C.cstm|if (err == 0) {
                $stms:loads
              }|]

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|struct $id:name;|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj);|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:store_opaque($ty:ctx_ty *ctx, const $ty:opaque_type *obj, void **p, size_t *n);|]
  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|$ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx, const void *p);|]

  -- We do not need to enclose the body in a critical section, because
  -- when we operate on the components of the opaque, we are calling
  -- public API functions that do their own locking.
  return
    [C.cunit|
          int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj) {
            int ret = 0, tmp;
            $items:free_body
            free(obj);
            return ret;
          }

          int $id:store_opaque($ty:ctx_ty *ctx,
                               const $ty:opaque_type *obj, void **p, size_t *n) {
            int ret = 0;
            $items:store_body
            return ret;
          }

          $ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx,
                                              const void *p) {
            int err = 0;
            const unsigned char *src = p;
            $ty:opaque_type* obj = malloc(sizeof($ty:opaque_type));
            if (obj == NULL) {
              return NULL;
            }
            $items:load_body
            if (err != 0) {
              int ret = 0, tmp;
              $items:free_body
              free(obj);
              obj = NULL;
            }
            return obj;
          }
    |]

-- | The C type used for a value of the given description.  Scalars map
-- directly to their C primitive type.  For arrays, the array type is
-- also registered in 'compArrayTypes' (keeping the maximum 'Publicness'
-- seen for it) so that its struct and library functions get generated.
valueDescToCType :: Publicness -> ValueDesc -> CompilerM op s C.Type
valueDescToCType pub vd = case vd of
  ScalarValue pt signed _ ->
    pure $ signedPrimTypeToCType signed pt
  ArrayValue _ space pt signed shape -> do
    let rank = length shape
        key = (space, signed, pt, rank)
    arr_ty_name <- publicName (arrayName pt signed rank)
    modify $ \st ->
      st {compArrayTypes = M.insertWith max key pub (compArrayTypes st)}
    pure [C.cty|struct $id:arr_ty_name|]

-- | The C struct type for an opaque value.  The opaque is recorded in
-- 'compOpaqueTypes' under its description, and each component is passed
-- through 'valueDescToCType' as 'Private' so that any array types it
-- contains are registered as well.
opaqueToCType :: String -> [ValueDesc] -> CompilerM op s C.Type
opaqueToCType desc vds = do
  struct_name <- publicName (opaqueName desc vds)
  modify $ \st ->
    st {compOpaqueTypes = M.insert desc vds (compOpaqueTypes st)}
  -- Make sure the constituent array types will exist.
  forM_ vds $ valueDescToCType Private
  pure [C.cty|struct $id:struct_name|]

generateAPITypes :: CompilerM op s ()
generateAPITypes :: forall {op} {s}. CompilerM op s ()
generateAPITypes = do
  ((ArrayType, Publicness) -> CompilerM op s [()])
-> [(ArrayType, Publicness)] -> CompilerM op s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (ArrayType, Publicness) -> CompilerM op s [()]
forall {op} {s}. (ArrayType, Publicness) -> CompilerM op s [()]
generateArray ([(ArrayType, Publicness)] -> CompilerM op s ())
-> (Map ArrayType Publicness -> [(ArrayType, Publicness)])
-> Map ArrayType Publicness
-> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map ArrayType Publicness -> [(ArrayType, Publicness)]
forall k a. Map k a -> [(k, a)]
M.toList (Map ArrayType Publicness -> CompilerM op s ())
-> CompilerM op s (Map ArrayType Publicness) -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (CompilerState s -> Map ArrayType Publicness)
-> CompilerM op s (Map ArrayType Publicness)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets CompilerState s -> Map ArrayType Publicness
forall s. CompilerState s -> Map ArrayType Publicness
compArrayTypes
  ((String, [ValueDesc]) -> CompilerM op s [()])
-> [(String, [ValueDesc])] -> CompilerM op s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (String, [ValueDesc]) -> CompilerM op s [()]
forall {op} {s}. (String, [ValueDesc]) -> CompilerM op s [()]
generateOpaque ([(String, [ValueDesc])] -> CompilerM op s ())
-> (Map String [ValueDesc] -> [(String, [ValueDesc])])
-> Map String [ValueDesc]
-> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Map String [ValueDesc] -> [(String, [ValueDesc])]
forall k a. Map k a -> [(k, a)]
M.toList (Map String [ValueDesc] -> CompilerM op s ())
-> CompilerM op s (Map String [ValueDesc]) -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (CompilerState s -> Map String [ValueDesc])
-> CompilerM op s (Map String [ValueDesc])
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets CompilerState s -> Map String [ValueDesc]
forall s. CompilerState s -> Map String [ValueDesc]
compOpaqueTypes
  where
    generateArray :: (ArrayType, Publicness) -> CompilerM op s [()]
generateArray ((Space
space, Signedness
signed, PrimType
pt, Int
rank), Publicness
pub) = do
      String
name <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ PrimType -> Signedness -> Int -> String
arrayName PrimType
pt Signedness
signed Int
rank
      let memty :: Type
memty = Space -> Type
fatMemType Space
space
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
libDecl [C.cedecl|struct $id:name { $ty:memty mem; typename int64_t shape[$int:rank]; };|]
      (Definition -> CompilerM op s ())
-> [Definition] -> CompilerM op s [()]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
libDecl ([Definition] -> CompilerM op s [()])
-> CompilerM op s [Definition] -> CompilerM op s [()]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Publicness
-> Space
-> PrimType
-> Signedness
-> Int
-> CompilerM op s [Definition]
forall op s.
Publicness
-> Space
-> PrimType
-> Signedness
-> Int
-> CompilerM op s [Definition]
arrayLibraryFunctions Publicness
pub Space
space PrimType
pt Signedness
signed Int
rank

    generateOpaque :: (String, [ValueDesc]) -> CompilerM op s [()]
generateOpaque (String
desc, [ValueDesc]
vds) = do
      String
name <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String -> [ValueDesc] -> String
opaqueName String
desc [ValueDesc]
vds
      [FieldGroup]
members <- (ValueDesc -> Int -> CompilerM op s FieldGroup)
-> [ValueDesc] -> [Int] -> CompilerM op s [FieldGroup]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM ValueDesc -> Int -> CompilerM op s FieldGroup
forall {op} {s}. ValueDesc -> Int -> CompilerM op s FieldGroup
field [ValueDesc]
vds [(Int
0 :: Int) ..]
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
libDecl [C.cedecl|struct $id:name { $sdecls:members };|]
      (Definition -> CompilerM op s ())
-> [Definition] -> CompilerM op s [()]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
libDecl ([Definition] -> CompilerM op s [()])
-> CompilerM op s [Definition] -> CompilerM op s [()]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< String -> [ValueDesc] -> CompilerM op s [Definition]
forall op s. String -> [ValueDesc] -> CompilerM op s [Definition]
opaqueLibraryFunctions String
desc [ValueDesc]
vds

    field :: ValueDesc -> Int -> CompilerM op s FieldGroup
field vd :: ValueDesc
vd@ScalarValue {} Int
i = do
      Type
ct <- Publicness -> ValueDesc -> CompilerM op s Type
forall op s. Publicness -> ValueDesc -> CompilerM op s Type
valueDescToCType Publicness
Private ValueDesc
vd
      FieldGroup -> CompilerM op s FieldGroup
forall (m :: * -> *) a. Monad m => a -> m a
return [C.csdecl|$ty:ct $id:(tupleField i);|]
    field ValueDesc
vd Int
i = do
      Type
ct <- Publicness -> ValueDesc -> CompilerM op s Type
forall op s. Publicness -> ValueDesc -> CompilerM op s Type
valueDescToCType Publicness
Private ValueDesc
vd
      FieldGroup -> CompilerM op s FieldGroup
forall (m :: * -> *) a. Monad m => a -> m a
return [C.csdecl|$ty:ct *$id:(tupleField i);|]

-- | Conjunction of a list of C boolean expressions.
--
-- The empty list yields @true@, a single expression is returned
-- unchanged, and longer lists are combined right-associatively with
-- @&&@ (so @[a, b, c]@ becomes @a && (b && c)@).
allTrue :: [C.Exp] -> C.Exp
allTrue [] = [C.cexp|true|]
allTrue (e : es) = conj e es
  where
    -- The last expression is used as-is, so no trailing @&& true@ appears.
    conj x [] = x
    conj x (y : ys) = [C.cexp|$exp:x && $exp:(conj y ys)|]

-- | Generate the code that unpacks the arguments of a public entry
-- point into the variables used by the internal function.
--
-- For each external argument this returns the C parameter exposed by
-- the entry point, paired with a C boolean expression that must hold
-- for the argument's sizes to be valid.  It also returns the block
-- items that perform the unpacking.  Dimension variables not bound
-- directly by some argument are assigned from the incoming array's
-- @shape@; every dimension is also checked against that shape.
prepareEntryInputs ::
  [ExternalValue] ->
  CompilerM op s ([(C.Param, C.Exp)], [C.BlockItem])
prepareEntryInputs args =
  collect' $ zipWithM prepareArg [(0 :: Int) ..] args
  where
    -- Variables bound directly by some argument.  A shape dimension
    -- named by one of these is not re-assigned, only checked.
    argNames = namesFromList $ concatMap externalNames args

    externalNames (OpaqueValue _ _ vds) = map valueName vds
    externalNames (TransparentValue _ vd) = [valueName vd]

    valueName (ArrayValue v _ _ _ _) = v
    valueName (ScalarValue _ _ v) = v

    prepareArg idx (TransparentValue _ vd) = do
      let pname = "in" ++ show idx
      (ty, conds) <- unpackValue Public [C.cexp|$id:pname|] vd
      pure ([C.cparam|const $ty:ty $id:pname|], allTrue conds)
    prepareArg idx (OpaqueValue _ desc vds) = do
      ty <- opaqueToCType desc vds
      let pname = "in" ++ show idx
          -- Scalars and arrays alike are reached through the matching
          -- tuple field of the opaque struct.
          fieldExp i = [C.cexp|$id:pname->$id:(tupleField i)|]
      conds <-
        zipWithM
          (\i vd -> snd <$> unpackValue Private (fieldExp i) vd)
          [0 :: Int ..]
          vds
      pure ([C.cparam|const $ty:ty *$id:pname|], allTrue (concat conds))

    -- Emit the assignments for a single value and return the C type
    -- of its parameter, together with its size checks.
    unpackValue _ src (ScalarValue pt signed name) = do
      stm [C.cstm|$id:name = $exp:src;|]
      pure (signedPrimTypeToCType signed pt, [])
    unpackValue pub src vd@(ArrayValue mem _ _ _ shape) = do
      ty <- valueDescToCType pub vd
      stm [C.cstm|$exp:mem = $exp:src->mem;|]
      let dimInfo (Var d) i
            | not (d `nameIn` argNames) =
                ( Just [C.cstm|$id:d = $exp:src->shape[$int:i];|],
                  [C.cexp|$id:d == $exp:src->shape[$int:i]|]
                )
          dimInfo se i =
            (Nothing, [C.cexp|$exp:se == $exp:src->shape[$int:i]|])
          (assigns, conds) = unzip $ zipWith dimInfo shape [0 :: Int ..]
      stms $ catMaybes assigns
      pure ([C.cty|$ty:ty*|], conds)

-- | Generate the code that packs the results of the internal function
-- into the out-parameters of a public entry point.
--
-- Returns one C out-parameter per external result, plus the block
-- items that fill them in.  Array results and opaque results (and any
-- non-scalar field of an opaque result) are allocated with @malloc@
-- inside an @assert@.
prepareEntryOutputs :: [ExternalValue] -> CompilerM op s ([C.Param], [C.BlockItem])
prepareEntryOutputs results =
  collect' $ zipWithM packResult [(0 :: Int) ..] results
  where
    packResult idx (TransparentValue _ vd) = do
      let pname = "out" ++ show idx
          target = [C.cexp|*$id:pname|]
      ty <- valueDescToCType Public vd
      case vd of
        ScalarValue {} -> do
          packValue target vd
          pure [C.cparam|$ty:ty *$id:pname|]
        ArrayValue {} -> do
          stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]
          packValue target vd
          pure [C.cparam|$ty:ty **$id:pname|]
    packResult idx (OpaqueValue _ desc vds) = do
      let pname = "out" ++ show idx
      ty <- opaqueToCType desc vds
      fieldTys <- mapM (valueDescToCType Private) vds

      stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]

      forM_ (zip3 [0 :: Int ..] fieldTys vds) $ \(i, fty, vd) -> do
        let fieldExp = [C.cexp|(*$id:pname)->$id:(tupleField i)|]
        -- Scalar fields are stored inline.  Every other field is a
        -- pointer that needs its own allocation.
        case vd of
          ScalarValue {} -> pure ()
          _ -> stm [C.cstm|assert(($exp:fieldExp = ($ty:fty*) malloc(sizeof($ty:fty))) != NULL);|]
        packValue fieldExp vd

      pure [C.cparam|$ty:ty **$id:pname|]

    -- Store one value (and, for arrays, its shape) into @dest@.
    packValue dest (ScalarValue _ _ name) =
      stm [C.cstm|$exp:dest = $id:name;|]
    packValue dest (ArrayValue mem _ _ _ shape) = do
      stm [C.cstm|$exp:dest->mem = $id:mem;|]
      let setDim i (Constant v) = [C.cstm|$exp:dest->shape[$int:i] = $exp:v;|]
          setDim i (Var d) = [C.cstm|$exp:dest->shape[$int:i] = $id:d;|]
      stms $ zipWith setDim [0 :: Int ..] shape

-- | Generate the public C wrapper for a function marked as an entry
-- point.  Returns 'Nothing' for functions that are not entry points.
--
-- As a side effect, the wrapper's prototype is added to the
-- 'EntryDecl' section of the header.  The wrapper does the following:
--
--   * declares locals for the internal function's parameters;
--   * unpacks the inputs and checks their sizes;
--   * calls the internal function;
--   * if that returns 0, runs @get_consts@ and packs the outputs.
--
-- Everything after the local declarations is wrapped by the backend's
-- 'criticalSection'.  If the size check fails, @ret@ is set to 1 and
-- @ctx->error@ is set, unless an error is already recorded.
onEntryPoint ::
  [C.BlockItem] ->
  Name ->
  Function op ->
  CompilerM op s (Maybe C.Definition)
onEntryPoint _ _ (Function Nothing _ _ _ _ _) = pure Nothing
onEntryPoint get_consts fname (Function (Just ename) outputs inputs _ results args) = do
  -- Locals standing in for the internal function's parameters.
  inputdecls <- collect $ mapM_ declareParam inputs
  outputdecls <- collect $ mapM_ declareParam outputs

  wrapper_name <- publicName $ "entry_" ++ nameToString ename

  (in_params_and_checks, unpack_inputs) <- prepareEntryInputs args
  let (in_params, in_checks) = unzip in_params_and_checks

  (out_params, pack_outputs) <- prepareEntryOutputs results

  ctx_ty <- contextType

  headerDecl
    EntryDecl
    [C.cedecl|int $id:wrapper_name
                                     ($ty:ctx_ty *ctx,
                                      $params:out_params,
                                      $params:in_params);|]

  ops <- asks envOperations

  let out_args = [[C.cexp|&$id:(paramName p)|] | p <- outputs]
      in_args = [[C.cexp|$id:(paramName p)|] | p <- inputs]
      -- Size checks run only after every input has been unpacked, so
      -- a dimension shared by several inputs is compared against all
      -- of them.
      body =
        [C.citems|
         $items:unpack_inputs

         if (!($exp:(allTrue in_checks))) {
           ret = 1;
           if (!ctx->error) {
             ctx->error = msgprintf("Error: entry point arguments have invalid sizes.\n");
           }
         } else {
           ret = $id:(funName fname)(ctx, $args:out_args, $args:in_args);

           if (ret == 0) {
             $items:get_consts

             $items:pack_outputs
           }
         }
        |]

  pure $
    Just
      [C.cedecl|
       int $id:wrapper_name
           ($ty:ctx_ty *ctx,
            $params:out_params,
            $params:in_params) {
         $items:inputdecls
         $items:outputdecls

         int ret = 0;

         $items:(criticalSection ops body)

         return ret;
       }|]
  where
    declareParam (MemParam name space) = declMem name space
    declareParam (ScalarParam name pt) = do
      let cty = primTypeToCType pt
      decl [C.cdecl|$ty:cty $id:name;|]

-- | The result of compilation to C, split into parts that can be
-- combined in various ways.  'asExecutable' and 'asServer' combine
-- them into a single program with a command-line or server driver.
-- 'asLibrary' instead produces a library implementation and a header
-- file to call into it.
data CParts = CParts
  { -- | Declarations making up the public header.
    cHeader :: String,
    -- | Utility definitions that must be visible
    -- to both CLI and library parts.
    cUtils :: String,
    -- | Definitions for the command-line driver.
    cCLI :: String,
    -- | Definitions for the server driver.
    cServer :: String,
    -- | The library implementation itself.
    cLib :: String
  }

-- | C preamble that defines @_GNU_SOURCE@ unless it is already defined.
--
-- It must precede every header include, so that the usage statistics
-- of a thread (@RUSAGE_THREAD@) are available on GNU/Linux.
gnuSource :: String
gnuSource = pretty preamble
  where
    preamble =
      [C.cunit|
// We need to define _GNU_SOURCE before
// _any_ headers files are imported to get
// the usage statistics of a thread (i.e. have RUSAGE_THREAD) on GNU/Linux
// https://manpages.courier-mta.org/htmlman2/getrusage.2.html
$esc:("#ifndef _GNU_SOURCE") // Avoid possible double-definition warning.
$esc:("#define _GNU_SOURCE")
$esc:("#endif")
|]

-- | Pragmas that silence Clang and GCC warnings about cosmetic issues
-- in generated code.
--
-- We may generate variables that are never used (e.g. for
-- certificates) or functions that are never called (e.g. unused
-- intrinsics).  The warnings would only clutter the compilation logs.
disableWarnings :: String
disableWarnings = pretty pragmas
  where
    pragmas =
      [C.cunit|
$esc:("#ifdef __clang__")
$esc:("#pragma clang diagnostic ignored \"-Wunused-function\"")
$esc:("#pragma clang diagnostic ignored \"-Wunused-variable\"")
$esc:("#pragma clang diagnostic ignored \"-Wparentheses\"")
$esc:("#pragma clang diagnostic ignored \"-Wunused-label\"")
$esc:("#elif __GNUC__")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-function\"")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-variable\"")
$esc:("#pragma GCC diagnostic ignored \"-Wparentheses\"")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-label\"")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"")
$esc:("#endif")

|]

-- | Produce header and implementation files.
--
-- The header is the public declarations, guarded by @#pragma once@.
-- The implementation is self-contained: it repeats those declarations
-- before the utilities and the library code.
asLibrary :: CParts -> (String, String)
asLibrary parts = (header, implementation)
  where
    header = "#pragma once\n\n" <> cHeader parts
    implementation =
      concat
        [ gnuSource,
          disableWarnings,
          cHeader parts,
          cUtils parts,
          cLib parts
        ]

-- | As executable with command-line interface: every part except the
-- server driver, concatenated in order.
asExecutable :: CParts -> String
asExecutable parts =
  concat
    [ gnuSource,
      disableWarnings,
      cHeader parts,
      cUtils parts,
      cCLI parts,
      cLib parts
    ]

-- | As server executable: every part except the command-line driver,
-- concatenated in order.
asServer :: CParts -> String
asServer parts =
  concat
    [ gnuSource,
      disableWarnings,
      cHeader parts,
      cUtils parts,
      cServer parts,
      cLib parts
    ]

-- | Compile an imperative program to C, split into the parts described
-- by 'CParts'.
--
-- A public C entry point is generated for every function marked as an
-- entry point (one whose 'functionEntry' is 'Just').  The command-line
-- and server drivers are generated for those same functions.
--
-- Arguments, in order:
--
--   * the backend name, used to define a @FUTHARK_BACKEND_<name>@
--     macro in the header;
--   * the backend operations;
--   * an extra backend-specific compiler action, run after the entry
--     points have been generated;
--   * extra header text, emitted in both the header and the library
--     part;
--   * the memory spaces to define;
--   * the options for the command-line and server drivers;
--   * the program itself.
compileProg ::
  MonadFreshNames m =>
  String ->
  Operations op () ->
  CompilerM op () () ->
  String ->
  [Space] ->
  [Option] ->
  Definitions op ->
  m CParts
compileProg backend ops extra header_extra spaces options prog = do
  src <- getNameSource
  let ((prototypes, definitions, entry_point_decls), endstate) =
        runCompilerM ops src () compileProg'

  let headerdefs =
        [C.cunit|
$esc:("// Headers\n")
$esc:("#include <stdint.h>")
$esc:("#include <stddef.h>")
$esc:("#include <stdbool.h>")
$esc:("#include <stdio.h>")
$esc:("#include <float.h>")
$esc:(header_extra)

$esc:("#ifdef __cplusplus")
$esc:("extern \"C\" {")
$esc:("#endif")

$esc:("\n// Initialisation\n")
$edecls:(initDecls endstate)

$esc:("\n// Arrays\n")
$edecls:(arrayDecls endstate)

$esc:("\n// Opaque values\n")
$edecls:(opaqueDecls endstate)

$esc:("\n// Entry points\n")
$edecls:(entryDecls endstate)

$esc:("\n// Miscellaneous\n")
$edecls:(miscDecls endstate)
$esc:("#define FUTHARK_BACKEND_"++backend)

$esc:("#ifdef __cplusplus")
$esc:("}")
$esc:("#endif")
                           |]

  let utildefs =
        [C.cunit|
$esc:("#include <stdio.h>")
$esc:("#include <stdlib.h>")
$esc:("#include <stdbool.h>")
$esc:("#include <math.h>")
$esc:("#include <stdint.h>")
// If NDEBUG is set, the assert() macro will do nothing. Since Futhark
// (unfortunately) makes use of assert() for error detection (and even some
// side effects), we want to avoid that.
$esc:("#undef NDEBUG")
$esc:("#include <assert.h>")
$esc:("#include <stdarg.h>")

$esc:util_h

$esc:timing_h
|]

  let early_decls = DL.toList $ compEarlyDecls endstate
  let lib_decls = DL.toList $ compLibDecls endstate
  let clidefs = cliDefs options $ Functions entry_funs
  let serverdefs = serverDefs options $ Functions entry_funs
  -- <string.h> was previously included twice here.  One include is
  -- enough.
  let libdefs =
        [C.cunit|
$esc:("#ifdef _MSC_VER\n#define inline __inline\n#endif")
$esc:("#include <string.h>")
$esc:("#include <errno.h>")
$esc:("#include <assert.h>")
$esc:("#include <ctype.h>")

$esc:header_extra

$esc:lock_h

$edecls:builtin

$edecls:early_decls

$edecls:prototypes

$edecls:lib_decls

$edecls:definitions

$edecls:entry_point_decls
  |]

  return
    CParts
      { cHeader = pretty headerdefs,
        cUtils = pretty utildefs,
        cCLI = pretty clidefs,
        cServer = pretty serverdefs,
        cLib = pretty libdefs
      }
  where
    Definitions consts (Functions funs) = prog

    -- Functions that get a public entry point and a CLI/server driver.
    entry_funs = filter (isJust . functionEntry . snd) funs

    compileProg' = do
      (memstructs, memfuns, memreport) <- unzip3 <$> mapM defineMemorySpace spaces

      get_consts <- compileConstants consts

      ctx_ty <- contextType

      (prototypes, functions) <-
        unzip <$> mapM (compileFun get_consts [[C.cparam|$ty:ctx_ty *ctx|]]) funs

      mapM_ earlyDecl memstructs
      entry_points <-
        catMaybes <$> mapM (uncurry (onEntryPoint get_consts)) funs

      extra

      mapM_ earlyDecl $ concat memfuns

      commonLibFuns memreport

      return (prototypes, map funcToDef functions, entry_points)

    -- Turn a compiled function into a top-level definition, reusing the
    -- function's own source location.
    funcToDef func = C.FuncDef func loc
      where
        loc = case func of
          C.OldFunc _ _ _ _ _ _ l -> l
          C.Func _ _ _ _ _ l -> l

    builtin =
      cIntOps ++ cFloat32Ops ++ cFloat64Ops ++ cFloatConvOps
        ++ cFloat32Funs
        ++ cFloat64Funs

    util_h = $(embedStringFile "rts/c/util.h")
    timing_h = $(embedStringFile "rts/c/timing.h")
    lock_h = $(embedStringFile "rts/c/lock.h")

-- | Emit the public API functions shared by all C backends.
--
-- It first runs 'generateAPITypes'. It then registers, via 'publicDef_',
-- functions under these base names:
--
-- * @get_num_sizes@
-- * @get_size_name@
-- * @get_size_class@
-- * @context_report@
-- * @context_get_error@
-- * @context_set_logging_file@
-- * @context_pause_profiling@
-- * @context_unpause_profiling@
-- * @context_clear_caches@
--
-- The argument holds the memory-report items. @context_report@ runs them
-- when @ctx->detail_memory@, @ctx->profiling@ or @ctx->logging@ is set.
-- The accumulated profiling items ('compProfileItems') run only when
-- @ctx->profiling@ is set.
commonLibFuns :: [C.BlockItem] -> CompilerM op s ()
commonLibFuns memreport = do
  generateAPITypes
  ctx <- contextType
  ops <- asks envOperations
  profilereport <- gets $ DL.toList . compProfileItems

  publicDef_ "get_num_sizes" InitDecl $ \s ->
    ( [C.cedecl|int $id:s(void);|],
      [C.cedecl|int $id:s(void) {
                return sizeof(size_names)/sizeof(size_names[0]);
              }|]
    )

  -- NOTE(review): the two accessors below index size_names/size_classes
  -- without a bounds check. Callers must pass an index below the value
  -- returned by the get_num_sizes function.
  publicDef_ "get_size_name" InitDecl $ \s ->
    ( [C.cedecl|const char* $id:s(int);|],
      [C.cedecl|const char* $id:s(int i) {
                return size_names[i];
              }|]
    )

  publicDef_ "get_size_class" InitDecl $ \s ->
    ( [C.cedecl|const char* $id:s(int);|],
      [C.cedecl|const char* $id:s(int i) {
                return size_classes[i];
              }|]
    )

  -- The report first synchronises the context. If that fails it
  -- returns NULL instead of a (possibly stale) report.
  sync <- publicName "context_sync"
  publicDef_ "context_report" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx *ctx);|],
      [C.cedecl|char* $id:s($ty:ctx *ctx) {
                 if ($id:sync(ctx) != 0) {
                   return NULL;
                 }

                 struct str_builder builder;
                 str_builder_init(&builder);
                 if (ctx->detail_memory || ctx->profiling || ctx->logging) {
                   $items:memreport
                 }
                 if (ctx->profiling) {
                   $items:profilereport
                 }
                 return builder.str;
               }|]
    )

  -- Hands the pending error string to the caller and clears it from the
  -- context.
  publicDef_ "context_get_error" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx* ctx);|],
      [C.cedecl|char* $id:s($ty:ctx* ctx) {
                         char* error = ctx->error;
                         ctx->error = NULL;
                         return error;
                       }|]
    )

  publicDef_ "context_set_logging_file" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx, typename FILE* f);|],
      [C.cedecl|void $id:s($ty:ctx* ctx, typename FILE* f) {
                  ctx->log = f;
                }|]
    )

  publicDef_ "context_pause_profiling" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx);|],
      [C.cedecl|void $id:s($ty:ctx* ctx) {
                 ctx->profiling_paused = 1;
               }|]
    )

  publicDef_ "context_unpause_profiling" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx);|],
      [C.cedecl|void $id:s($ty:ctx* ctx) {
                 ctx->profiling_paused = 0;
               }|]
    )

  -- The registered clear items run inside the backend's critical
  -- section. The result is nonzero iff an error is pending afterwards.
  clears <- gets $ DL.toList . compClearItems
  publicDef_ "context_clear_caches" MiscDecl $ \s ->
    ( [C.cedecl|int $id:s($ty:ctx* ctx);|],
      [C.cedecl|int $id:s($ty:ctx* ctx) {
                         $items:(criticalSection ops clears)
                         return ctx->error != NULL;
                       }|]
    )

-- | Compile the program-wide constants.
--
-- This function:
--
-- * adds a @constants@ struct field to the context, holding one member
--   per constant;
-- * declares and defines the static C functions @init_constants@ and
--   @free_constants@. @init_constants@ runs the initialisation code.
--   @free_constants@ unreferences every memory-block constant.
--
-- The returned block items copy each constant out of @ctx->constants@
-- into a local C variable of the same name. They are meant to be spliced
-- at the start of function bodies (cf. the @get_constants@ argument of
-- 'compileFun').
compileConstants :: Constants op -> CompilerM op s [C.BlockItem]
compileConstants (Constants ps init_consts) = do
  ctx_ty <- contextType
  const_fields <- mapM constParamField ps
  -- A struct with no named members is undefined behaviour in C
  -- (C11 6.7.2.1), so emit a dummy member when there are no constants.
  let const_fields'
        | null const_fields = [[C.csdecl|int dummy;|]]
        | otherwise = const_fields
  contextField "constants" [C.cty|struct { $sdecls:const_fields' }|] Nothing
  earlyDecl [C.cedecl|static int init_constants($ty:ctx_ty*);|]
  earlyDecl [C.cedecl|static int free_constants($ty:ctx_ty*);|]

  -- We locally define macros for the constants, so that when we
  -- generate assignments to local variables, we actually assign into
  -- the constants struct.  This is not needed for functions, because
  -- they can only read constants, not write them.
  let (defs, undefs) = unzip $ map constMacro ps
  init_consts' <- blockScope $ do
    mapM_ resetMemConst ps
    compileCode init_consts
  libDecl
    [C.cedecl|static int init_constants($ty:ctx_ty *ctx) {
      (void)ctx;
      int err = 0;
      $items:defs
      $items:init_consts'
      $items:undefs
      cleanup:
      return err;
    }|]

  free_consts <- collect $ mapM_ freeConst ps
  libDecl
    [C.cedecl|static int free_constants($ty:ctx_ty *ctx) {
      (void)ctx;
      $items:free_consts
      return 0;
    }|]

  mapM getConst ps
  where
    -- Struct member declaration for one constant.
    constParamField (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      return [C.csdecl|$ty:ctp $id:name;|]
    constParamField (MemParam name space) = do
      ty <- memToCType name space
      return [C.csdecl|$ty:ty $id:name;|]

    -- A @#define@/@#undef@ pair that redirects the constant's name to its
    -- slot in @ctx->constants@.
    constMacro :: Param -> (C.BlockItem, C.BlockItem)
    constMacro p = ([C.citem|$escstm:def|], [C.citem|$escstm:undef|])
      where
        p' :: String
        p' = pretty (C.toIdent (paramName p) mempty)
        def :: String
        def = "#define " ++ p' ++ " (" ++ "ctx->constants." ++ p' ++ ")"
        undef :: String
        undef = "#undef " ++ p'

    -- Memory constants are reset before initialisation; scalars need
    -- nothing.
    resetMemConst ScalarParam {} = return ()
    resetMemConst (MemParam name space) = resetMem name space

    -- Memory constants are unreferenced on teardown; scalars need nothing.
    freeConst ScalarParam {} = return ()
    freeConst (MemParam name space) = unRefMem [C.cexp|ctx->constants.$id:name|] space

    -- Local variable initialised from the context's copy of the constant.
    getConst (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      return [C.citem|$ty:ctp $id:name = ctx->constants.$id:name;|]
    getConst (MemParam name space) = do
      ty <- memToCType name space
      return [C.citem|$ty:ty $id:name = ctx->constants.$id:name;|]

-- | Enable function-local caching of lexically scoped memory blocks.
--
-- Each block in the given map that lives in 'DefaultSpace' gets a fresh
-- @<name>_cached_size@ variable. The block is also added to
-- 'envCachedMem' while the continuation runs. The continuation receives:
--
-- * the C declarations for the cached blocks (size @0@, pointer @NULL@);
-- * the statements that @free@ each cached pointer.
cachingMemory ::
  M.Map VName Space ->
  ([C.BlockItem] -> [C.Stm] -> CompilerM op s a) ->
  CompilerM op s a
cachingMemory lexical f = do
  -- We only consider lexical 'DefaultSpace' memory blocks to be
  -- cached.  This is not a deep technical restriction, but merely a
  -- heuristic based on GPU memory usually involving larger
  -- allocations, that do not suffer from the overhead of reference
  -- counting.
  let cached = M.keys $ M.filter (== DefaultSpace) lexical

  cached' <- forM cached $ \mem -> do
    size <- newVName $ pretty mem <> "_cached_size"
    return (mem, size)

  let lexMem env =
        env
          { envCachedMem =
              M.fromList (map (first (`C.toExp` noLoc)) cached')
                <> envCachedMem env
          }

      declCached (mem, size) =
        [ [C.citem|size_t $id:size = 0;|],
          [C.citem|$ty:defaultMemBlockType $id:mem = NULL;|]
        ]

      freeCached (mem, _) =
        [C.cstm|free($id:mem);|]

  local lexMem $ f (concatMap declCached cached') (map freeCached cached')

-- | Compile one function into a C prototype and a C function definition.
--
-- The C parameters appear in this order:
--
-- 1. the @extra@ parameters;
-- 2. one pointer parameter per output;
-- 3. the inputs.
--
-- The generated function returns the @int@ variable @err@. Its body:
--
-- * casts the @extra@ parameters to @void@;
-- * declares the cached memory (see 'cachingMemory');
-- * splices @get_constants@, followed by the compiled body;
-- * ends with a @cleanup:@ label, after which cached memory is freed
--   and @err@ is returned.
compileFun :: [C.BlockItem] -> [C.Param] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileFun get_constants extra (fname, func@(Function _ outputs inputs body _ _)) = do
  (outparams, out_ptrs) <- unzip <$> mapM compileOutput outputs
  inparams <- mapM compileInput inputs

  cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
    body' <- blockScope $ compileFunBody out_ptrs outputs body

    return
      ( [C.cedecl|static int $id:(funName fname)($params:extra, $params:outparams, $params:inparams);|],
        [C.cfun|static int $id:(funName fname)($params:extra, $params:outparams, $params:inparams) {
               $stms:ignores
               int err = 0;
               $items:decl_cached
               $items:get_constants
               $items:body'
              cleanup:
               {}
               $stms:free_cached
               return err;
  }|]
      )
  where
    -- Ignore all the boilerplate parameters, just in case we don't
    -- actually need to use them.
    ignores :: [C.Stm]
    ignores = [[C.cstm|(void)$id:p;|] | C.Param (Just p) _ _ _ <- extra]

    -- Inputs are passed by value.
    compileInput (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      return [C.cparam|$ty:ctp $id:name|]
    compileInput (MemParam name space) = do
      ty <- memToCType name space
      return [C.cparam|$ty:ty $id:name|]

    -- Outputs are passed as pointers with fresh names. The returned
    -- expression names that pointer for 'compileFunBody'.
    compileOutput (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      p_name <- newVName $ "out_" ++ baseString name
      return ([C.cparam|$ty:ctp *$id:p_name|], [C.cexp|$id:p_name|])
    compileOutput (MemParam name space) = do
      ty <- memToCType name space
      p_name <- newVName $ baseString name ++ "_p"
      return ([C.cparam|$ty:ty *$id:p_name|], [C.cexp|$id:p_name|])

-- | Render a primitive constant as a C expression.
--
-- * 8-, 16- and 64-bit integers are cast to their fixed-width C types.
--   32-bit integers are emitted as plain literals.
-- * Non-finite floats become @INFINITY@, @-INFINITY@ or @NAN@.
-- * Booleans become @1@ or @0@.
-- * The unit value becomes @0@.
compilePrimValue :: PrimValue -> C.Exp
compilePrimValue (IntValue (Int8Value k)) = [C.cexp|(typename int8_t)$int:k|]
compilePrimValue (IntValue (Int16Value k)) = [C.cexp|(typename int16_t)$int:k|]
compilePrimValue (IntValue (Int32Value k)) = [C.cexp|$int:k|]
compilePrimValue (IntValue (Int64Value k)) = [C.cexp|(typename int64_t)$int:k|]
compilePrimValue (FloatValue (Float64Value x))
  | isInfinite x =
    if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
  | isNaN x =
    [C.cexp|NAN|]
  | otherwise =
    [C.cexp|$double:x|]
compilePrimValue (FloatValue (Float32Value x))
  | isInfinite x =
    if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
  | isNaN x =
    [C.cexp|NAN|]
  | otherwise =
    [C.cexp|$float:x|]
compilePrimValue (BoolValue b) =
  [C.cexp|$int:b'|]
  where
    b' :: Int
    b' = if b then 1 else 0
compilePrimValue UnitValue =
  [C.cexp|0|]

-- | @derefPointer ptr i t@ builds the C expression @((t)ptr)[i]@. It casts
-- @ptr@ to the pointer type @t@ and indexes it with @i@.
derefPointer :: C.Exp -> C.Exp -> C.Type -> C.Exp
derefPointer ptr i res_t =
  [C.cexp|(($ty:res_t)$exp:ptr)[$exp:i]|]

-- | C type qualifiers for a 'Volatility': @volatile@ for 'Volatile',
-- none for 'Nonvolatile'.
volQuals :: Volatility -> [C.TypeQual]
volQuals Volatile = [C.ctyquals|volatile|]
volQuals Nonvolatile = []

-- | A 'WriteScalar' that emits @((quals elemtype*)dest)[i] = v;@.
--
-- The qualifiers are the volatility qualifiers ('volQuals') followed by
-- those that the given 'PointerQuals' function returns for the memory
-- space.
writeScalarPointerWithQuals :: PointerQuals op s -> WriteScalar op s
writeScalarPointerWithQuals quals_f dest i elemtype space vol v = do
  quals <- quals_f space
  let quals' = volQuals vol ++ quals
      deref =
        derefPointer
          dest
          i
          [C.cty|$tyquals:quals' $ty:elemtype*|]
  stm [C.cstm|$exp:deref = $exp:v;|]

-- | A 'ReadScalar' that returns the expression
-- @((quals elemtype*)dest)[i]@. It emits no statements.
--
-- The qualifiers are the volatility qualifiers ('volQuals') followed by
-- those that the given 'PointerQuals' function returns for the memory
-- space.
readScalarPointerWithQuals :: PointerQuals op s -> ReadScalar op s
readScalarPointerWithQuals quals_f dest i elemtype space vol = do
  quals <- quals_f space
  let quals' = volQuals vol ++ quals
  return $ derefPointer dest i [C.cty|$tyquals:quals' $ty:elemtype*|]

-- | Make an expression available under a C variable name.
--
-- * A bare scalar variable is returned unchanged and nothing is emitted.
-- * Any other expression is compiled and bound to a fresh variable. The
--   variable is based on @desc@ and has the C type corresponding to @t@.
compileExpToName :: String -> PrimType -> Exp -> CompilerM op s VName
compileExpToName _ _ (LeafExp (ScalarVar v) _) =
  return v
compileExpToName desc t e = do
  desc' <- newVName desc
  e' <- compileExp e
  decl [C.cdecl|$ty:(primTypeToCType t) $id:desc' = $e';|]
  return desc'

-- | Compile an ImpCode expression to a C expression.
--
-- Leaves are handled as follows:
--
-- * A scalar variable becomes a C identifier.
-- * A read of 'Unit' type becomes the unit constant.
-- * A read from 'DefaultSpace' memory dereferences the raw pointer directly.
-- * A read from any other named space goes through the backend's
--   'envReadScalar'.
-- * A read from a 'ScalarSpace' indexes the named C array.
compileExp :: Exp -> CompilerM op s C.Exp
compileExp = compilePrimExp compileLeaf
  where
    compileLeaf (ScalarVar src) =
      return [C.cexp|$id:src|]
    compileLeaf (Index _ _ Unit _ _) =
      return $ compilePrimValue UnitValue
    compileLeaf (Index src (Count iexp) restype DefaultSpace vol) = do
      src' <- rawMem src
      derefPointer src'
        <$> compileExp (untyped iexp)
        <*> pure [C.cty|$tyquals:(volQuals vol) $ty:(primTypeToCType restype)*|]
    compileLeaf (Index src (Count iexp) restype (Space space) vol) =
      join $
        asks envReadScalar
          <*> rawMem src
          <*> compileExp (untyped iexp)
          <*> pure (primTypeToCType restype)
          <*> pure space
          <*> pure vol
    compileLeaf (Index src (Count iexp) _ ScalarSpace {} _) = do
      iexp' <- compileExp $ untyped iexp
      return [C.cexp|$id:src[$exp:iexp']|]

-- | Tell me how to compile a @v@, and I'll compile any @PrimExp v@ for you.
--
-- The function argument is used only for 'LeafExp' nodes; every other
-- constructor is translated structurally.  Operators that correspond
-- directly to a C operator are emitted inline.  Anything else is emitted
-- as a call to a C function whose name is the pretty-printed operator
-- (comparisons, conversions, and overflow-wrapping integer arithmetic
-- among them).
compilePrimExp :: Monad m => (v -> m C.Exp) -> PrimExp v -> m C.Exp
compilePrimExp _ (ValueExp val) =
  return $ compilePrimValue val
compilePrimExp f (LeafExp v _) =
  f v
compilePrimExp f (UnOpExp Complement {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|~$exp:x'|]
compilePrimExp f (UnOpExp Not {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|!$exp:x'|]
compilePrimExp f (UnOpExp Abs {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|abs($exp:x')|]
compilePrimExp f (UnOpExp (FAbs Float32) x) = do
  x' <- compilePrimExp f x
  return [C.cexp|(float)fabs($exp:x')|]
compilePrimExp f (UnOpExp (FAbs Float64) x) = do
  x' <- compilePrimExp f x
  return [C.cexp|fabs($exp:x')|]
compilePrimExp f (UnOpExp SSignum {} x) = do
  x' <- compilePrimExp f x
  -- Evaluates to -1, 0 or 1.
  return [C.cexp|($exp:x' > 0) - ($exp:x' < 0)|]
compilePrimExp f (UnOpExp USignum {} x) = do
  x' <- compilePrimExp f x
  -- C's '-' binds tighter than '!=', so this is 1 for any nonzero x and 0
  -- otherwise, which is the signum of an unsigned value.
  return [C.cexp|($exp:x' > 0) - ($exp:x' < 0) != 0|]
compilePrimExp f (UnOpExp (FSignum Float32) x) = do
  x' <- compilePrimExp f x
  return [C.cexp|fsignum32($exp:x')|]
compilePrimExp f (UnOpExp (FSignum Float64) x) = do
  x' <- compilePrimExp f x
  -- Double-precision operands need the 64-bit helper, mirroring the
  -- Float32/fsignum32 pairing above.  NOTE(review): assumes the C prelude
  -- defines fsignum64 next to fsignum32.
  return [C.cexp|fsignum64($exp:x')|]
compilePrimExp f (CmpOpExp cmp x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  return $ case cmp of
    CmpEq {} -> [C.cexp|$exp:x' == $exp:y'|]
    FCmpLt {} -> [C.cexp|$exp:x' < $exp:y'|]
    FCmpLe {} -> [C.cexp|$exp:x' <= $exp:y'|]
    CmpLlt {} -> [C.cexp|$exp:x' < $exp:y'|]
    CmpLle {} -> [C.cexp|$exp:x' <= $exp:y'|]
    _ -> [C.cexp|$id:(pretty cmp)($exp:x', $exp:y')|]
compilePrimExp f (ConvOpExp conv x) = do
  x' <- compilePrimExp f x
  return [C.cexp|$id:(pretty conv)($exp:x')|]
compilePrimExp f (BinOpExp bop x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  -- Note that integer addition, subtraction, and multiplication with
  -- OverflowWrap are not handled by explicit operators, but rather by
  -- functions.  This is because we want to implicitly convert them to
  -- unsigned numbers, so we can do overflow without invoking
  -- undefined behaviour.
  return $ case bop of
    Add _ OverflowUndef -> [C.cexp|$exp:x' + $exp:y'|]
    Sub _ OverflowUndef -> [C.cexp|$exp:x' - $exp:y'|]
    Mul _ OverflowUndef -> [C.cexp|$exp:x' * $exp:y'|]
    FAdd {} -> [C.cexp|$exp:x' + $exp:y'|]
    FSub {} -> [C.cexp|$exp:x' - $exp:y'|]
    FMul {} -> [C.cexp|$exp:x' * $exp:y'|]
    FDiv {} -> [C.cexp|$exp:x' / $exp:y'|]
    Xor {} -> [C.cexp|$exp:x' ^ $exp:y'|]
    And {} -> [C.cexp|$exp:x' & $exp:y'|]
    Or {} -> [C.cexp|$exp:x' | $exp:y'|]
    Shl {} -> [C.cexp|$exp:x' << $exp:y'|]
    LogAnd {} -> [C.cexp|$exp:x' && $exp:y'|]
    LogOr {} -> [C.cexp|$exp:x' || $exp:y'|]
    _ -> [C.cexp|$id:(pretty bop)($exp:x', $exp:y')|]
compilePrimExp f (FunExp h args _) = do
  args' <- mapM (compilePrimExp f) args
  return [C.cexp|$id:(funName (nameFromString h))($args:args')|]

-- | Flatten a tree of ':>>:' sequences into the list of non-sequence
-- statements it contains, in left-to-right (execution) order.
linearCode :: Code op -> [Code op]
linearCode code = flatten code []
  where
    -- Difference-list style: prepend the leaves of the first argument
    -- onto the already-flattened remainder, so no final reverse is needed.
    flatten :: Code a -> [Code a] -> [Code a]
    flatten (l :>>: r) rest = flatten l (flatten r rest)
    flatten leaf rest = leaf : rest

-- | Translate an imperative 'Code' statement into C, emitting the resulting
-- block items into the current 'CompilerM' scope.
compileCode :: Code op -> CompilerM op s ()
compileCode (Op op) = do
  -- Backend-specific operations are delegated to the pluggable op compiler.
  opCompiler <- asks envOpCompiler
  opCompiler op
compileCode Skip = pure ()
compileCode (Comment s code) = do
  -- The commented code is wrapped in its own C block, preceded by the comment.
  body <- blockScope (compileCode code)
  let commentText = "// " ++ s
  stm [C.cstm|$comment:commentText { $items:body }|]
compileCode (DebugPrint s (Just e)) = do
  -- When debugging is enabled at runtime, print "label: value" to the log.
  e' <- compileExp e
  stm
    [C.cstm|if (ctx->debugging) {
          fprintf(ctx->log, $string:fmtstr, $exp:s, ($ty:ety)$exp:e', '\n');
       }|]
  where
    -- Each value is cast to a C type matching its printf conversion.
    -- Integers are widened to a *signed* long long, so they must be printed
    -- with "%lld"; "%llu" would render negative values as huge unsigned
    -- numbers.
    (fmt, ety) = case primExpType e of
      IntType _ -> ("lld", [C.cty|long long int|])
      FloatType _ -> ("f", [C.cty|double|])
      _ -> ("d", [C.cty|int|])
    fmtstr = "%s: %" ++ fmt ++ "%c"
compileCode (DebugPrint s Nothing) =
  -- A message without a value: only the label is printed.
  stm [C.cstm|if (ctx->debugging) { fprintf(ctx->log, "%s\n", $exp:s); }|]
-- :>>: is treated in a special way to detect declare-set pairs in
-- order to generate prettier code.
compileCode (c1 :>>: c2) = emit $ linearCode $ c1 :>>: c2
  where
    emit [] = pure ()
    -- A declaration immediately followed by an assignment to the same
    -- variable is fused into a single initialised declaration.
    emit (DeclareScalar name vol t : SetScalar dest e : rest)
      | name == dest = do
          e' <- compileExp e
          item [C.citem|$tyquals:(volQuals vol) $ty:(primTypeToCType t) $id:name = $exp:e';|]
          emit rest
    emit (c : rest) = do
      compileCode c
      emit rest
compileCode (Assert e msg (loc, locs)) = do
  cond <- compileExp e
  -- The failure path is produced by the backend's error compiler.
  onFailure <- collect $ do
    errorCompiler <- asks (opsError . envOperations)
    errorCompiler msg stacktrace
  stm [C.cstm|if (!$exp:cond) { $items:onFailure }|]
  where
    -- The primary location comes first, followed by the remaining ones.
    stacktrace = prettyStacktrace 0 (locStr <$> (loc : locs))
compileCode (Allocate _ _ ScalarSpace {}) =
  -- Handled by the declaration of the memory block, which is
  -- translated to an actual array.
  return ()
compileCode (Allocate name (Count (TPrimExp e)) space) = do
  size <- compileExp e
  cached <- cacheMem name
  case cached of
    Just cur_size -> do
      -- The block is cached: grow it in place only when it is too small.
      -- realloc() may fail and return NULL while leaving the old block
      -- intact, so the result goes into a temporary first.  On failure we
      -- take the same error path as the uncached branch below, instead of
      -- overwriting (and leaking) the old pointer.
      new_mem <- newVName "realloc_mem"
      stm
        [C.cstm|if ($exp:cur_size < (size_t)$exp:size) {
                    void *$id:new_mem = realloc($exp:name, $exp:size);
                    if ($id:new_mem == NULL) {
                      err = 1;
                      goto cleanup;
                    }
                    $exp:name = $id:new_mem;
                    $exp:cur_size = $exp:size;
                  }|]
    Nothing ->
      allocMem name size space [C.cstm|{err = 1; goto cleanup;}|]
compileCode (Free name space) =
  -- Memory tracked by the lexical cache is not unreferenced here.
  cacheMem name >>= \cached ->
    case cached of
      Just _ -> pure ()
      Nothing -> unRefMem name space
compileCode (For i bound body) = do
  boundExp <- compileExp bound
  loopBody <- blockScope (compileCode body)
  let counter = C.toIdent i
      -- The loop counter has the same C type as the bound.
      counterType = primTypeToCType (primExpType bound)
  stm
    [C.cstm|for ($ty:counterType $id:counter = 0; $id:counter < $exp:boundExp; $id:counter++) {
            $items:loopBody
          }|]
compileCode (While cond body) = do
  condExp <- compileExp (untyped cond)
  loopBody <- blockScope (compileCode body)
  stm [C.cstm|while ($exp:condExp) { $items:loopBody }|]
compileCode (If cond tbranch fbranch) = do
  condExp <- compileExp (untyped cond)
  thenItems <- blockScope (compileCode tbranch)
  elseItems <- blockScope (compileCode fbranch)
  -- Avoid emitting empty branches: drop an empty else, and negate the
  -- condition when only the else branch has content.
  stm $
    if null elseItems
      then [C.cstm|if ($exp:condExp) { $items:thenItems }|]
      else
        if null thenItems
          then [C.cstm|if (!($exp:condExp)) { $items:elseItems }|]
          else [C.cstm|if ($exp:condExp) { $items:thenItems } else { $items:elseItems }|]
compileCode (Copy dest (Count destoffset) DefaultSpace src (Count srcoffset) DefaultSpace (Count size)) = do
  -- Host-to-host copies use the built-in default-space copy.
  destMem <- rawMem dest
  destOff <- compileExp (untyped destoffset)
  srcMem <- rawMem src
  srcOff <- compileExp (untyped srcoffset)
  bytes <- compileExp (untyped size)
  copyMemoryDefaultSpace destMem destOff srcMem srcOff bytes
compileCode (Copy dest (Count destoffset) destspace src (Count srcoffset) srcspace (Count size)) = do
  -- Any other combination of spaces is handled by the backend's copier.
  copier <- asks envCopy
  destMem <- rawMem dest
  destOff <- compileExp (untyped destoffset)
  srcMem <- rawMem src
  srcOff <- compileExp (untyped srcoffset)
  bytes <- compileExp (untyped size)
  copier destMem destOff destspace srcMem srcOff srcspace bytes
compileCode (Write _ _ Unit _ _ _) =
  -- Unit values carry no data, so there is nothing to store.
  pure ()
compileCode (Write dest (Count idx) elemtype DefaultSpace vol elemexp) = do
  destMem <- rawMem dest
  idxExp <- compileExp (untyped idx)
  let ptrType = [C.cty|$tyquals:(volQuals vol) $ty:(primTypeToCType elemtype)*|]
      target = derefPointer destMem idxExp ptrType
  valExp <- compileExp elemexp
  stm [C.cstm|$exp:target = $exp:valExp;|]
compileCode (Write dest (Count idx) _ ScalarSpace {} _ elemexp) = do
  idxExp <- compileExp (untyped idx)
  valExp <- compileExp elemexp
  -- Scalar-space memory is declared as an actual C array; index it directly.
  stm [C.cstm|$id:dest[$exp:idxExp] = $exp:valExp;|]
compileCode (Write dest (Count idx) elemtype (Space space) vol elemexp) = do
  -- Named spaces are written through the backend's scalar writer.
  writeScalar <- asks envWriteScalar
  destMem <- rawMem dest
  idxExp <- compileExp (untyped idx)
  valExp <- compileExp elemexp
  writeScalar destMem idxExp (primTypeToCType elemtype) space vol valExp
compileCode (DeclareMem name space) =
  declMem name space
compileCode (DeclareScalar name vol t) =
  decl [C.cdecl|$tyquals:(volQuals vol) $ty:(primTypeToCType t) $id:name;|]
compileCode (DeclareArray name ScalarSpace {} _ _) =
  error $ concat ["Cannot declare array ", pretty name, " in scalar space."]
compileCode (DeclareArray name DefaultSpace t vs) = do
  backing <- newVName (baseString name ++ "_realtype")
  let elemType = primTypeToCType t
      -- A static C array holding either the given values or zeroes.
      staticDecl = case vs of
        ArrayValues vals ->
          let initializers = map (\v -> [C.cinit|$exp:(compilePrimValue v)|]) vals
           in [C.cedecl|static $ty:elemType $id:backing[$int:(length vals)] = {$inits:initializers};|]
        ArrayZeros n ->
          [C.cedecl|static $ty:elemType $id:backing[$int:n];|]
  earlyDecl staticDecl
  -- Fake a memory block: a context field wraps the static array as a
  -- memblock, and a local of the same name is bound to that field.
  let fakeBlock = [C.cexp|(struct memblock){NULL, (char*)$id:backing, 0}|]
  contextField (C.toIdent name noLoc) [C.cty|struct memblock|] (Just fakeBlock)
  item [C.citem|struct memblock $id:name = ctx->$id:name;|]
compileCode (DeclareArray arr (Space space) elem_pt contents) = do
  -- Arrays in a named memory space are handed to the backend-supplied
  -- static-array compiler found in the environment.
  static_array <- asks envStaticArray
  static_array arr space elem_pt contents
-- When a scalar is updated as 'x = x OP e' and OP has a C compound
-- assignment form, emit that form (e.g. x += e) instead.  This only makes
-- the generated code slightly nicer to read; it has no effect on
-- performance.
compileCode (SetScalar dest (BinOpExp op (LeafExp (ScalarVar x) _) y))
  | x == dest,
    Just compound <- assignmentOperator op = do
    rhs <- compileExp y
    stm [C.cstm|$exp:(compound dest rhs);|]
compileCode (SetScalar dest src) =
  -- General case: compile the right-hand side and assign it directly.
  compileExp src >>= \rhs -> stm [C.cstm|$id:dest = $exp:rhs;|]
compileCode (SetMem dest_mem src_mem mem_space) =
  -- Assignment between memory blocks is delegated to 'setMem'.
  setMem dest_mem src_mem mem_space
compileCode (Call dests fname args) = do
  -- Look up the backend's call compiler, compile every argument, and let
  -- the call compiler emit the call itself.
  call_compiler <- asks (opsCall . envOperations)
  arg_exps <- traverse argExp args
  call_compiler dests fname arg_exps
  where
    -- Memory arguments are passed by name; expression arguments are
    -- compiled as ordinary expressions.
    argExp (MemArg m) = pure [C.cexp|$exp:m|]
    argExp (ExpArg e) = compileExp e

-- | Like 'blockScope'', but discards the action's (unit) result and
-- returns only the collected block items.
blockScope :: CompilerM op s () -> CompilerM op s [C.BlockItem]
blockScope action = snd <$> blockScope' action

-- | Run an action and return its result together with the block items
-- gathered by 'collect''.  Memory entries added to 'compDeclaredMem'
-- while the action ran are passed to 'unRefMem', and the statements that
-- produces are appended to the returned items.  'compDeclaredMem' is then
-- reset to what it was before the action.
blockScope' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
blockScope' action = do
  before <- gets compDeclaredMem
  (result, body) <- collect' action
  after <- gets compDeclaredMem
  -- Entries that appeared during the action, in their original order.
  let fresh = [entry | entry <- after, entry `notElem` before]
  modify $ \st -> st {compDeclaredMem = before}
  cleanup <- collect $ mapM_ (\(mem, space) -> unRefMem mem space) fresh
  pure (result, body ++ cleanup)

-- | Compile a function body.  Each output parameter is first declared as
-- a local, then the body code is compiled, and finally each output is
-- written through the pointer expression at the same position in
-- @output_ptrs@ (the two lists are paired with 'zipWith').
compileFunBody :: [C.Exp] -> [Param] -> Code op -> CompilerM op s ()
compileFunBody output_ptrs outputs code = do
  mapM_ declareOutput outputs
  compileCode code
  sequence_ (zipWith storeOutput output_ptrs outputs)
  where
    declareOutput param = case param of
      MemParam name space -> declMem name space
      ScalarParam name pt -> decl [C.cdecl|$ty:(primTypeToCType pt) $id:name;|]

    -- Memory outputs: reset the pointed-to block, then set it to the
    -- local one.  Scalar outputs: plain store through the pointer.
    storeOutput ptr (MemParam name space) = do
      let target = [C.cexp|*$exp:ptr|]
      resetMem target space
      setMem target name space
    storeOutput ptr (ScalarParam name _) =
      stm [C.cstm|*$exp:ptr = $id:name;|]

-- | The C compound assignment form of a binary operator, if it has one.
-- The returned function takes the destination variable and the compiled
-- right-hand side and builds e.g. @d += e@.  Only 'Add', 'Sub' and 'Mul'
-- are handled; every other operator yields 'Nothing'.
--
-- NOTE(review): if these operators can carry a wrapping-overflow mode,
-- a plain C compound assignment on signed integers does not wrap on
-- overflow.  Confirm this matches how 'compileExp' renders them.
assignmentOperator :: BinOp -> Maybe (VName -> C.Exp -> C.Exp)
assignmentOperator op = case op of
  Add {} -> Just (\d e -> [C.cexp|$id:d += $exp:e|])
  Sub {} -> Just (\d e -> [C.cexp|$id:d -= $exp:e|])
  Mul {} -> Just (\d e -> [C.cexp|$id:d *= $exp:e|])
  _ -> Nothing