{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | C code generator framework.
module Futhark.CodeGen.Backends.GenericC
  ( compileProg,
    CParts (..),
    asLibrary,
    asExecutable,

    -- * Pluggable compiler
    Operations (..),
    defaultOperations,
    OpCompiler,
    ErrorCompiler,
    CallCompiler,
    PointerQuals,
    MemoryType,
    WriteScalar,
    writeScalarPointerWithQuals,
    ReadScalar,
    readScalarPointerWithQuals,
    Allocate,
    Deallocate,
    Copy,
    StaticArray,

    -- * Monadic compiler interface
    CompilerM,
    CompilerState (compUserState, compNameSrc),
    getUserState,
    modifyUserState,
    contextContents,
    contextFinalInits,
    runCompilerM,
    inNewFunction,
    cachingMemory,
    blockScope,
    compileFun,
    compileCode,
    compileExp,
    compilePrimExp,
    compilePrimValue,
    compileExpToName,
    rawMem,
    item,
    items,
    stm,
    stms,
    decl,
    atInit,
    headerDecl,
    publicDef,
    publicDef_,
    profileReport,
    HeaderSection (..),
    libDecl,
    earlyDecl,
    publicName,
    contextType,
    contextField,
    memToCType,
    cacheMem,
    fatMemory,
    rawMemCType,
    cproduct,
    fatMemType,

    -- * Building Blocks
    primTypeToCType,
    intTypeToCType,
    copyMemoryDefaultSpace,
  )
where

import Control.Monad.Identity
import Control.Monad.RWS
import Data.Bifunctor (first)
import Data.Bits (shiftR, xor)
import Data.Char (isAlphaNum, isDigit, ord)
import qualified Data.DList as DL
import Data.FileEmbed
import Data.List (unzip4)
import Data.Loc
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C
import Text.Printf

-- | Mutable state threaded through the C code generator.
data CompilerState s = CompilerState
  { -- | Array struct definitions, keyed by a C type and an integer
    -- (presumably element type and rank — confirm).  Collected by
    -- 'arrayDefinitions'.
    compArrayStructs :: [((C.Type, Int), (C.Type, [C.Definition]))],
    -- | Opaque struct definitions keyed by name.  Collected by
    -- 'opaqueDefinitions'.
    compOpaqueStructs :: [(String, (C.Type, [C.Definition]))],
    compEarlyDecls :: DL.DList C.Definition,
    compInit :: [C.Stm],
    -- | Source of fresh names.
    compNameSrc :: VNameSource,
    -- | Backend-specific state.
    compUserState :: s,
    -- | Header declarations, grouped by the 'HeaderSection' they belong to.
    compHeaderDecls :: M.Map HeaderSection (DL.DList C.Definition),
    compLibDecls :: DL.DList C.Definition,
    compCtxFields :: DL.DList (C.Id, C.Type, Maybe C.Exp),
    compProfileItems :: DL.DList C.BlockItem,
    -- | Memory blocks declared so far, with their spaces.  'defError'
    -- unreferences each of them before returning on error.
    compDeclaredMem :: [(VName, Space)]
  }

-- | An initial compiler state: the given name source and user state,
-- with every collection empty.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compArrayStructs = [],
      compOpaqueStructs = [],
      compEarlyDecls = mempty,
      compInit = [],
      compNameSrc = src,
      compUserState = s,
      compHeaderDecls = mempty,
      compLibDecls = mempty,
      compCtxFields = mempty,
      compProfileItems = mempty,
      compDeclaredMem = mempty
    }

-- | In which part of the header file we put the declaration.  This is
-- to ensure that the header file remains structured and readable.
--
-- The derived 'Ord' instance orders values by constructor declaration
-- order (then by the 'String' payload).  That in turn fixes the
-- iteration order of maps keyed by 'HeaderSection', such as
-- 'compHeaderDecls', so do not reorder the constructors casually.
data HeaderSection
  = ArrayDecl String
  | OpaqueDecl String
  | EntryDecl
  | MiscDecl
  | InitDecl
  deriving (Eq, Ord)

-- | A substitute expression compiler, tried before the main
-- compilation function.
type OpCompiler op s = op -> CompilerM op s ()

-- | Compile an error report.  The arguments are the error message and a
-- stack-trace string.  See 'defError' for the default behaviour, which
-- prints the trace after the message.
type ErrorCompiler op s = ErrorMsg Exp -> String -> CompilerM op s ()

-- | The address space qualifiers for a pointer of the given type with
-- the given annotation.
type PointerQuals op s = String -> CompilerM op s [C.TypeQual]

-- | The type of a memory block in the given memory space.
type MemoryType op s = SpaceId -> CompilerM op s C.Type

-- | Write a scalar to the given memory block with the given element
-- index and in the given memory space.  The final expression is the
-- value to write.
type WriteScalar op s =
  C.Exp -> C.Exp -> C.Type -> SpaceId -> Volatility -> C.Exp -> CompilerM op s ()

-- | Read a scalar from the given memory block with the given element
-- index and in the given memory space.
type ReadScalar op s =
  C.Exp -> C.Exp -> C.Type -> SpaceId -> Volatility -> CompilerM op s C.Exp

-- | Allocate a memory block of the given size and with the given tag
-- in the given memory space, saving a reference in the given variable
-- name.
type Allocate op s =
  C.Exp ->
  C.Exp ->
  C.Exp ->
  SpaceId ->
  CompilerM op s ()

-- | De-allocate the given memory block with the given tag, which is
-- in the given memory space.
type Deallocate op s = C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | Create a static array of values - initialised at load time.
type StaticArray op s = VName -> SpaceId -> PrimType -> ArrayContents -> CompilerM op s ()

-- | Copy from one memory block to another.  The arguments are, in
-- order: destination memory, destination offset, destination space,
-- source memory, source offset, source space, and the size to copy.
type Copy op s =
  C.Exp ->
  C.Exp ->
  Space ->
  C.Exp ->
  C.Exp ->
  Space ->
  C.Exp ->
  CompilerM op s ()

-- | Call a function.  The arguments are the variables that receive the
-- results, the name of the function, and the argument expressions.
type CallCompiler op s = [VName] -> Name -> [C.Exp] -> CompilerM op s ()

-- | The pluggable, backend-specific hooks used by the C code generator.
-- 'defaultOperations' provides an instance that only supports the
-- default memory space.
data Operations op s = Operations
  { -- | How to write a scalar to memory.
    opsWriteScalar :: WriteScalar op s,
    -- | How to read a scalar from memory.
    opsReadScalar :: ReadScalar op s,
    -- | How to allocate a memory block.
    opsAllocate :: Allocate op s,
    -- | How to free a memory block.
    opsDeallocate :: Deallocate op s,
    -- | How to copy between memory blocks.
    opsCopy :: Copy op s,
    -- | How to create a static array.
    opsStaticArray :: StaticArray op s,
    -- | The C type of a memory block in a given space.
    opsMemoryType :: MemoryType op s,
    -- | Compiler for backend-specific operations.
    opsCompiler :: OpCompiler op s,
    -- | How to report an error.
    opsError :: ErrorCompiler op s,
    -- | How to compile a function call.
    opsCall :: CallCompiler op s,
    -- | If true, use reference counting.  Otherwise, bare
    -- pointers.
    opsFatMemory :: Bool,
    -- | Code to bracket critical sections.
    opsCritical :: ([C.BlockItem], [C.BlockItem])
  }

-- | The default error compiler.  It sets @ctx->error@ to a formatted
-- message (prefixed with @"Error: "@ and followed by the backtrace),
-- unreferences every memory block recorded in 'compDeclaredMem', and
-- returns 1 from the generated C function.
defError :: ErrorCompiler op s
defError (ErrorMsg parts) stacktrace = do
  free_all_mem <- collect $ mapM_ (uncurry unRefMem) =<< gets compDeclaredMem
  -- Each message part contributes a printf directive and its argument.
  -- String parts are passed through %s rather than spliced into the
  -- format string, so a '%' in them is harmless.
  let onPart (ErrorString s) = return ("%s", [C.cexp|$string:s|])
      onPart (ErrorInt32 x) = ("%d",) <$> compileExp x
      -- int64_t is not necessarily 'long long' (on LP64 platforms it is
      -- usually 'long'), so cast explicitly to match %lld.
      onPart (ErrorInt64 x) = do
        e <- compileExp x
        return ("%lld", [C.cexp|(long long)$exp:e|])
  (formatstrs, formatargs) <- unzip <$> mapM onPart parts
  let formatstr = "Error: " ++ concat formatstrs ++ "\n\nBacktrace:\n%s"
  items
    [C.citems|ctx->error = msgprintf($string:formatstr, $args:formatargs, $string:stacktrace);
                  $items:free_all_mem
                  return 1;|]

-- | The default call compiler.
--
-- A built-in function with exactly one destination is called as a
-- plain C expression whose result is assigned to that destination.
-- In every other case the call's return value is checked: if it is
-- non-zero, the generated code sets @err = 1@ and jumps to @cleanup@.
-- Non-built-in functions also receive the context and pointers to the
-- destination variables ahead of the ordinary arguments.
defCall :: CallCompiler op s
defCall dests fname args = do
  let builtin = isBuiltInFunction fname
      out_args = [[C.cexp|&$id:d|] | d <- dests]
      args'
        | builtin = args
        | otherwise = [C.cexp|ctx|] : out_args ++ args
  case dests of
    [dest]
      | builtin ->
        stm [C.cstm|$id:dest = $id:(funName fname)($args:args');|]
    _ ->
      item [C.citem|if ($id:(funName fname)($args:args') != 0) { err = 1; goto cleanup; }|]

-- | A set of operations that fail for every operation involving
-- non-default memory spaces.  Memory is reference counted
-- ('opsFatMemory' is 'True').  Copies within the default space go
-- through 'copyMemoryDefaultSpace'.  There are no critical-section
-- brackets, the default 'ErrorCompiler' is 'defError', and the default
-- 'CallCompiler' is 'defCall'.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsDeallocate = defDeallocate,
      opsCopy = defCopy,
      opsStaticArray = defStaticArray,
      opsMemoryType = defMemoryType,
      opsCompiler = defCompiler,
      opsFatMemory = True,
      opsError = defError,
      opsCall = defCall,
      opsCritical = mempty
    }
  where
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space because I am dumb"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defDeallocate _ _ =
      error "Cannot deallocate in non-default memory space"
    -- Only copies where both sides live in the default space are supported.
    defCopy destmem destoffset DefaultSpace srcmem srcoffset DefaultSpace size =
      copyMemoryDefaultSpace destmem destoffset srcmem srcoffset size
    defCopy _ _ _ _ _ _ _ =
      error "Cannot copy to or from non-default memory space"
    defStaticArray _ _ _ _ =
      error "Cannot create static array in non-default memory space"
    defMemoryType _ =
      error "Has no type for non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"

-- | Read-only environment of the C code generator.
data CompilerEnv op s = CompilerEnv
  { -- | The backend-specific operations in use.
    envOperations :: Operations op s,
    -- | Maps memory blocks (as C expressions) to the variable recording
    -- their size.  These memory blocks are CPU memory that we know is
    -- used in particularly simple ways (no reference counting
    -- necessary).  To cut down on allocator pressure, we keep these
    -- allocations around for a long time, and record their sizes so we
    -- can reuse them if possible (and realloc() when needed).
    envCachedMem :: M.Map C.Exp VName
  }

-- | Output accumulated by the compiler: a difference list of C block
-- items.  The @op@ and @s@ parameters are phantom.
newtype CompilerAcc op s = CompilerAcc
  { accItems :: DL.DList C.BlockItem
  }

-- | Concatenate the accumulated block items, left before right.
instance Semigroup (CompilerAcc op s) where
  CompilerAcc items1 <> CompilerAcc items2 =
    CompilerAcc (items1 <> items2)

-- | The empty accumulator holds no block items.
instance Monoid (CompilerAcc op s) where
  mempty = CompilerAcc mempty

-- | The 'opsCompiler' of the environment's operations.
envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler = opsCompiler . envOperations

-- | The 'opsMemoryType' of the environment's operations.
envMemoryType :: CompilerEnv op s -> MemoryType op s
envMemoryType = opsMemoryType . envOperations

-- | The 'opsReadScalar' of the environment's operations.
envReadScalar :: CompilerEnv op s -> ReadScalar op s
envReadScalar = opsReadScalar . envOperations

-- | The 'opsWriteScalar' of the environment's operations.
envWriteScalar :: CompilerEnv op s -> WriteScalar op s
envWriteScalar = opsWriteScalar . envOperations

-- | The 'opsAllocate' of the environment's operations.
envAllocate :: CompilerEnv op s -> Allocate op s
envAllocate = opsAllocate . envOperations

-- | The 'opsDeallocate' of the environment's operations.
envDeallocate :: CompilerEnv op s -> Deallocate op s
envDeallocate = opsDeallocate . envOperations

-- | The 'opsCopy' of the environment's operations.
envCopy :: CompilerEnv op s -> Copy op s
envCopy = opsCopy . envOperations

-- | The 'opsStaticArray' of the environment's operations.
envStaticArray :: CompilerEnv op s -> StaticArray op s
envStaticArray = opsStaticArray . envOperations

-- | Whether the environment's operations use reference-counted
-- ("fat") memory; see 'opsFatMemory'.
envFatMemory :: CompilerEnv op s -> Bool
envFatMemory = opsFatMemory . envOperations

-- | All C definitions belonging to the array structs
-- ('compArrayStructs') and opaque structs ('compOpaqueStructs'),
-- respectively, in list order.
arrayDefinitions, opaqueDefinitions :: CompilerState s -> [C.Definition]
arrayDefinitions = concatMap (snd . snd) . compArrayStructs
opaqueDefinitions = concatMap (snd . snd) . compOpaqueStructs

initDecls, arrayDecls, opaqueDecls, entryDecls, miscDecls :: CompilerState s -> [C.Definition]
-- | Header declarations in the 'InitDecl' section.
initDecls = headerSectionDecls (== InitDecl)

-- | Header declarations in any 'ArrayDecl' section.
arrayDecls = headerSectionDecls isArrayDecl
  where
    isArrayDecl ArrayDecl {} = True
    isArrayDecl _ = False

-- | Header declarations in any 'OpaqueDecl' section.
opaqueDecls = headerSectionDecls isOpaqueDecl
  where
    isOpaqueDecl OpaqueDecl {} = True
    isOpaqueDecl _ = False

-- | Header declarations in the 'EntryDecl' section.
entryDecls = headerSectionDecls (== EntryDecl)

-- | All header declarations whose section satisfies the predicate,
-- concatenated in ascending 'HeaderSection' order.  This is the same
-- order that 'M.toList' would visit them in.
headerSectionDecls :: (HeaderSection -> Bool) -> CompilerState s -> [C.Definition]
headerSectionDecls keep =
  concatMap DL.toList . M.elems . M.filterWithKey (\k _ -> keep k) . compHeaderDecls
-- Every definition filed under the 'MiscDecl' header section, in the
-- order in which they were added.  'compHeaderDecls' is a 'M.Map' keyed
-- by section, so at most one entry can match; a keyed lookup replaces
-- the previous linear scan over 'M.toList' with identical results.
--
-- NOTE(review): no type signature sits next to this binding; it is
-- presumably covered by a joint signature declared earlier with the
-- other header-section accessors -- confirm before adding one here.
miscDecls = maybe [] DL.toList . M.lookup MiscDecl . compHeaderDecls

-- | The fields of the context struct, in registration order, together
-- with one assignment statement (@ctx->name = e;@) for every field that
-- was registered with an initial value.
contextContents :: CompilerM op s ([C.FieldGroup], [C.Stm])
contextContents = do
  ctx_fields <- DL.toList <$> gets compCtxFields
  let fields =
        [ [C.csdecl|$ty:ty $id:name;|]
          | (name, ty, _) <- ctx_fields
        ]
      -- Fields without an initial value fail the pattern and are skipped.
      init_fields =
        [ [C.cstm|ctx->$id:name = $exp:e;|]
          | (name, _, Just e) <- ctx_fields
        ]
  return (fields, init_fields)

-- | The statements registered via 'atInit', in registration order.
contextFinalInits :: CompilerM op s [C.Stm]
contextFinalInits = compInit <$> get

-- | The C code generation monad.  It reads a 'CompilerEnv', accumulates
-- output in a 'CompilerAcc', and threads a 'CompilerState' through the
-- computation.  All instances below are inherited from the underlying
-- 'RWS' monad by newtype deriving.
newtype CompilerM op s a
  = CompilerM (RWS (CompilerEnv op s) (CompilerAcc op s) (CompilerState s) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (CompilerState s),
      MonadReader (CompilerEnv op s),
      MonadWriter (CompilerAcc op s)
    )

-- | The name source lives in the 'compNameSrc' field of the compiler
-- state.
instance MonadFreshNames (CompilerM op s) where
  getNameSource = compNameSrc <$> get
  putNameSource new_src = modify $ \st -> st {compNameSrc = new_src}

-- | Run a compiler action, starting from an environment built from the
-- given operations (with an empty memory cache) and a fresh state built
-- from the name source and user state.  Returns the action's result and
-- the final compiler state; the writer output is discarded.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  (a, CompilerState s)
runCompilerM ops src userstate (CompilerM m) = (result, final_state)
  where
    initial_env = CompilerEnv ops mempty
    initial_state = newCompilerState src userstate
    (result, final_state, _) = runRWS m initial_env initial_state

-- | Retrieve the user-defined part of the compiler state.
getUserState :: CompilerM op s s
getUserState = compUserState <$> get

-- | Apply a function to the user-defined part of the compiler state,
-- leaving every other field untouched.
modifyUserState :: (s -> s) -> CompilerM op s ()
modifyUserState f = do
  current <- get
  put current {compUserState = f (compUserState current)}

-- | Append a statement to the initialisation statements stored in
-- 'compInit'.  Because 'compInit' is a plain list, each append costs
-- time linear in the number of statements already registered.
atInit :: C.Stm -> CompilerM op s ()
atInit stmt = modify $ \st -> st {compInit = compInit st <> [stmt]}

-- | Run an action and return the block items it emitted, instead of
-- emitting them in the enclosing scope.  The action's result is dropped.
collect :: CompilerM op s () -> CompilerM op s [C.BlockItem]
collect = fmap snd . collect'

-- | Run an action and capture the block items it emits.  The captured
-- items are returned with the result and removed from the writer
-- output; every other component of the accumulator is passed through
-- unchanged.
collect' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
collect' action = pass $ do
  (res, acc) <- listen action
  let captured = DL.toList (accItems acc)
      -- Replace the listened output with a copy lacking the block items.
      strip_items = const acc {accItems = mempty}
  return ((res, captured), strip_items)

-- | Used when we, inside an existing 'CompilerM' action, want to
-- generate code for a new function.  Use this so that the compiler
-- understands that previously declared memory doesn't need to be
-- freed inside this action.
--
-- 'compDeclaredMem' is emptied while the action runs and restored
-- afterwards.  Unless the 'Bool' is 'True', the action also sees an
-- empty 'envCachedMem'.
inNewFunction :: Bool -> CompilerM op s a -> CompilerM op s a
inNewFunction keep_cached m = do
  saved_mem <- gets compDeclaredMem
  setDeclaredMem mempty
  res <- local adjustEnv m
  setDeclaredMem saved_mem
  return res
  where
    setDeclaredMem :: [(VName, Space)] -> CompilerM op' s' ()
    setDeclaredMem mem = modify $ \st -> st {compDeclaredMem = mem}

    adjustEnv env =
      if keep_cached
        then env
        else env {envCachedMem = mempty}

-- | Emit one block item into the current scope by writing an
-- accumulator that contains only that item.
item :: C.BlockItem -> CompilerM op s ()
item bi = tell (mempty {accItems = DL.singleton bi})

-- | Emit several block items, left to right, into the current scope.
items :: [C.BlockItem] -> CompilerM op s ()
items = foldr (\bi rest -> item bi >> rest) (return ())

-- | Whether memory in the given space is treated as fat memory.
-- 'ScalarSpace' never is; for any other space the answer is the
-- 'envFatMemory' flag of the environment.
fatMemory :: Space -> CompilerM op s Bool
fatMemory space = case space of
  ScalarSpace {} -> return False
  _ -> asks envFatMemory

-- | Look up the cached name for a memory expression, if there is one.
-- The argument is converted to a C expression, which is used as the key
-- into 'envCachedMem'.
cacheMem :: C.ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem a = do
  cached <- asks envCachedMem
  return $ M.lookup (C.toExp a noLoc) cached

-- | A 'Name' becomes a C identifier by z-encoding its string form.
instance C.ToIdent Name where
  toIdent name = C.toIdent (zEncodeString (nameToString name))

-- | A 'VName' becomes a C identifier by z-encoding its pretty-printed
-- form.
instance C.ToIdent VName where
  toIdent vname = C.toIdent (zEncodeString (pretty vname))

-- | A 'VName' used as an expression is a plain C identifier (see the
-- 'C.ToIdent' instance); the source location is ignored.
instance C.ToExp VName where
  toExp v _loc = [C.cexp|$id:v|]

-- | Integer constants delegate to language-c-quote's 'C.ToExp' instance
-- for the corresponding fixed-width Haskell integer type.
instance C.ToExp IntValue where
  toExp iv = case iv of
    Int8Value x -> C.toExp x
    Int16Value x -> C.toExp x
    Int32Value x -> C.toExp x
    Int64Value x -> C.toExp x

-- | Floating-point constants delegate to language-c-quote's 'C.ToExp'
-- instances for 'Float' and 'Double'.
--
-- NOTE(review): values are passed through unchanged; confirm that NaN
-- and infinities are handled by callers before they reach this instance.
instance C.ToExp FloatValue where
  toExp fv = case fv of
    Float32Value x -> C.toExp x
    Float64Value x -> C.toExp x

-- | Primitive constants.  Numbers delegate to their own instances, while
-- booleans and 'Checked' are encoded as 'Int8' literals (1 for 'True'
-- and 'Checked', 0 for 'False').
instance C.ToExp PrimValue where
  toExp pv = case pv of
    IntValue iv -> C.toExp iv
    FloatValue fv -> C.toExp fv
    BoolValue b -> C.toExp ((if b then 1 else 0) :: Int8)
    Checked -> C.toExp (1 :: Int8)

-- | A subexpression is either a variable or a constant; each delegates
-- to the corresponding instance.
instance C.ToExp SubExp where
  toExp se = case se of
    Var v -> C.toExp v
    Constant c -> C.toExp c

-- | Construct a publicly visible definition using the specified name
-- as the template.  The first returned definition is put in the
-- header file (in the given section, via 'headerDecl'), and the second
-- is the implementation (registered via 'earlyDecl').  Returns the
-- public name.
publicDef ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s String
publicDef template section mk_defs = do
  public_name <- publicName template
  let (header_def, impl_def) = mk_defs public_name
  headerDecl section header_def
  earlyDecl impl_def
  return public_name

-- | As 'publicDef', but ignores the public name.
publicDef_ ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s ()
publicDef_ template section mk_defs = () <$ publicDef template section mk_defs

-- | Add a definition to the given section of the header declarations.
-- Within a section, definitions keep the order in which they were added.
headerDecl :: HeaderSection -> C.Definition -> CompilerM op s ()
headerDecl sec def = modify $ \st ->
  let -- 'M.insertWith' passes (new, old); flip so the old entries come first.
      append_new = flip (<>)
      decls' = M.insertWith append_new sec (DL.singleton def) (compHeaderDecls st)
   in st {compHeaderDecls = decls'}

-- | Append a definition to the library declarations ('compLibDecls').
libDecl :: C.Definition -> CompilerM op s ()
libDecl def = modify $ \st -> st {compLibDecls = DL.snoc (compLibDecls st) def}

-- | Append a definition to the early declarations ('compEarlyDecls').
earlyDecl :: C.Definition -> CompilerM op s ()
earlyDecl def = modify $ \st -> st {compEarlyDecls = DL.snoc (compEarlyDecls st) def}

-- | Record a context field: its identifier, its C type, and an optional
-- initial value.  Fields are kept in 'compCtxFields' in the order they
-- are registered.
contextField :: C.Id -> C.Type -> Maybe C.Exp -> CompilerM op s ()
contextField name ty initial =
  modify $ \st ->
    st {compCtxFields = DL.snoc (compCtxFields st) (name, ty, initial)}

-- | Append a block item to the collected profiling-report items
-- ('compProfileItems').
profileReport :: C.BlockItem -> CompilerM op s ()
profileReport x = modify $ \st ->
  let reportItems = compProfileItems st <> DL.singleton x
   in st {compProfileItems = reportItems}

-- | Emit a single C statement as a block item.
stm :: C.Stm -> CompilerM op s ()
stm = item . asItem
  where
    asItem statement = [C.citem|$stm:statement|]

-- | Emit each of the given C statements, in order.
stms :: [C.Stm] -> CompilerM op s ()
stms statements = forM_ statements stm

-- | Emit a C declaration (an init group) as a block item.
decl :: C.InitGroup -> CompilerM op s ()
decl = item . asItem
  where
    asItem ig = [C.citem|$decl:ig;|]

-- | Take the address of a C expression, producing @&target@.
addrOf :: C.Exp -> C.Exp
addrOf target = [C.cexp|&$exp:target|]

-- | Public names must have a consistent prefix; this prepends
-- @futhark_@ to the given name.
publicName :: String -> CompilerM op s String
publicName s = pure ("futhark_" <> s)

-- | The generated code must define a struct with this name: the public
-- name of @context@ wrapped in a @struct@ type.
contextType :: CompilerM op s C.Type
contextType = mkStructType <$> publicName "context"
  where
    mkStructType name = [C.cty|struct $id:name|]

-- | The C type of a memory variable in the given space.  When the space
-- uses reference-counted ("fat") memory and the variable is not cached,
-- this is the memblock struct type; otherwise it is the raw memory type.
memToCType :: VName -> Space -> CompilerM op s C.Type
memToCType v space = do
  isFat <- fatMemory space
  isCached <- isJust <$> cacheMem v
  if not isFat || isCached
    then rawMemCType space
    else pure (fatMemType space)

-- | The C type of raw (non-reference-counted) memory in a space.
--
-- * 'DefaultSpace' uses 'defaultMemBlockType'.
-- * Named spaces defer to the backend's 'envMemoryType'.
-- * Scalar spaces become a fixed-size C array of the element type, whose
--   length is the product of the dimensions (or 1 with no dimensions).
rawMemCType :: Space -> CompilerM op s C.Type
rawMemCType space = case space of
  DefaultSpace -> pure defaultMemBlockType
  Space sid -> do
    memoryType <- asks envMemoryType
    memoryType sid
  ScalarSpace [] t -> pure [C.cty|$ty:(primTypeToCType t)[1]|]
  ScalarSpace ds t ->
    let count = cproduct $ map (`C.toExp` noLoc) ds
     in pure [C.cty|$ty:(primTypeToCType t)[$exp:count]|]

-- | The struct type of reference-counted memory blocks in a space:
-- @struct memblock_<sid>@ for a named space, @struct memblock@ otherwise.
fatMemType :: Space -> C.Type
fatMemType space = [C.cty|struct $id:structName|]
  where
    structName :: String
    structName = case space of
      Space sid -> "memblock_" ++ sid
      _ -> "memblock"

-- | Name of the generated C function that assigns one memory block to
-- another in the given space.
fatMemSet :: Space -> String
fatMemSet space = case space of
  Space sid -> "memblock_set_" ++ sid
  _ -> "memblock_set"

-- | Name of the generated C function that allocates a memory block in
-- the given space (suffixed with @_<sid>@ for named spaces).
fatMemAlloc :: Space -> String
fatMemAlloc space = "memblock_alloc" ++ suffix
  where
    suffix :: String
    suffix = case space of
      Space sid -> '_' : sid
      _ -> ""

-- | Name of the generated C function that unreferences a memory block in
-- the given space.
fatMemUnRef :: Space -> String
fatMemUnRef space
  | Space sid <- space = "memblock_unref_" ++ sid
  | otherwise = "memblock_unref"

-- | A C expression for the raw memory behind a memory variable.  When
-- fat memory is enabled in the environment and the variable is not
-- cached, this is the variable's @.mem@ field; otherwise it is the
-- variable itself.
rawMem :: VName -> CompilerM op s C.Exp
rawMem v = do
  fatEnabled <- asks envFatMemory
  uncached <- isNothing <$> cacheMem v
  pure $ rawMem' (fatEnabled && uncached) v

-- | Project the @.mem@ field out of the expression when the flag is
-- 'True'; otherwise return the expression unchanged.
rawMem' :: C.ToExp a => Bool -> a -> C.Exp
rawMem' isFat e
  | isFat = [C.cexp|$exp:e.mem|]
  | otherwise = [C.cexp|$exp:e|]

-- | Emit code allocating @size@ bytes of raw memory into @dest@.
-- Named spaces use the backend's 'envAllocate' operation, which also
-- receives the description @desc@ and the space name.  Every other
-- space uses a plain @malloc@ (and @desc@ is ignored).
allocRawMem ::
  (C.ToExp a, C.ToExp b, C.ToExp c) =>
  a ->
  b ->
  Space ->
  c ->
  CompilerM op s ()
allocRawMem dest size space desc =
  case space of
    Space sid -> do
      allocate <- asks envAllocate
      allocate
        [C.cexp|$exp:dest|]
        [C.cexp|$exp:size|]
        [C.cexp|$exp:desc|]
        sid
    _ -> stm [C.cstm|$exp:dest = (char*) malloc($exp:size);|]

-- | Emit code releasing raw memory.  Named spaces use the backend's
-- 'envDeallocate' operation (given the description and space name);
-- every other space uses a plain @free@.
freeRawMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  Space ->
  b ->
  CompilerM op s ()
freeRawMem mem space desc
  | Space sid <- space = do
      deallocate <- asks envDeallocate
      deallocate memExp descExp sid
  | otherwise = item [C.citem|free($exp:mem);|]
  where
    memExp = [C.cexp|$exp:mem|]
    descExp = [C.cexp|$exp:desc|]

-- | Generate the C support code for reference-counted memory blocks in
-- the given space.
--
-- Registers two @int64_t@ context fields (peak and current memory usage,
-- both initialised to 0) and returns:
--
-- * the @struct@ definition of the memory block type;
-- * the unreference, allocation, and assignment functions for it;
-- * a block item reporting peak memory usage (an empty block for
--   'DefaultSpace').
defineMemorySpace :: Space -> CompilerM op s (C.Definition, [C.Definition], C.BlockItem)
defineMemorySpace space = do
  rm <- rawMemCType space
  let structdef =
        [C.cedecl|struct $id:sname { int *references;
                                     $ty:rm mem;
                                     typename int64_t size;
                                     const char *desc; };|]

  contextField peakname [C.cty|typename int64_t|] $ Just [C.cexp|0|]
  contextField usagename [C.cty|typename int64_t|] $ Just [C.cexp|0|]

  -- Unreferencing a memory block consists of decreasing its reference
  -- count and freeing the corresponding memory if the count reaches
  -- zero.
  free <- collect $ freeRawMem [C.cexp|block->mem|] space [C.cexp|desc|]
  ctx_ty <- contextType
  let unrefdef =
        [C.cedecl|static int $id:(fatMemUnRef space) ($ty:ctx_ty *ctx, $ty:mty *block, const char *desc) {
  if (block->references != NULL) {
    *(block->references) -= 1;
    if (ctx->detail_memory) {
      fprintf(stderr, "Unreferencing block %s (allocated as %s) in %s: %d references remaining.\n",
                      desc, block->desc, $string:spacedesc, *(block->references));
    }
    if (*(block->references) == 0) {
      ctx->$id:usagename -= block->size;
      $items:free
      free(block->references);
      if (ctx->detail_memory) {
        fprintf(stderr, "%lld bytes freed (now allocated: %lld bytes)\n",
                (long long) block->size, (long long) ctx->$id:usagename);
      }
    }
    block->references = NULL;
  }
  return 0;
}|]

  -- When allocating a memory block we initialise the reference count to 1.
  alloc <-
    collect $
      allocRawMem [C.cexp|block->mem|] [C.cexp|size|] space [C.cexp|desc|]
  -- The panic message has exactly three conversions (%lld, %s, %s), so
  -- exactly three arguments follow the format string.
  let allocdef =
        [C.cedecl|static int $id:(fatMemAlloc space) ($ty:ctx_ty *ctx, $ty:mty *block, typename int64_t size, const char *desc) {
  if (size < 0) {
    futhark_panic(1, "Negative allocation of %lld bytes attempted for %s in %s.\n",
          (long long)size, desc, $string:spacedesc);
  }
  int ret = $id:(fatMemUnRef space)(ctx, block, desc);

  ctx->$id:usagename += size;
  if (ctx->detail_memory) {
    fprintf(stderr, "Allocating %lld bytes for %s in %s (then allocated: %lld bytes)",
            (long long) size,
            desc, $string:spacedesc,
            (long long) ctx->$id:usagename);
  }
  if (ctx->$id:usagename > ctx->$id:peakname) {
    ctx->$id:peakname = ctx->$id:usagename;
    if (ctx->detail_memory) {
      fprintf(stderr, " (new peak).\n");
    }
  } else if (ctx->detail_memory) {
    fprintf(stderr, ".\n");
  }

  $items:alloc
  block->references = (int*) malloc(sizeof(int));
  *(block->references) = 1;
  block->size = size;
  block->desc = desc;
  return ret;
  }|]

  -- Memory setting - unreference the destination and increase the
  -- count of the source by one.
  let setdef =
        [C.cedecl|static int $id:(fatMemSet space) ($ty:ctx_ty *ctx, $ty:mty *lhs, $ty:mty *rhs, const char *lhs_desc) {
  int ret = $id:(fatMemUnRef space)(ctx, lhs, lhs_desc);
  if (rhs->references != NULL) {
    (*(rhs->references))++;
  }
  *lhs = *rhs;
  return ret;
}
|]

  let peakmsg = "Peak memory usage for " ++ spacedesc ++ ": %lld bytes.\n"
  return
    ( structdef,
      [unrefdef, allocdef, setdef],
      -- Do not report memory usage for DefaultSpace (CPU memory),
      -- because it would not be accurate anyway.  This whole
      -- tracking probably needs to be rethought.
      if space == DefaultSpace
        then [C.citem|{}|]
        else [C.citem|str_builder(&builder, $string:peakmsg, (long long) ctx->$id:peakname);|]
    )
  where
    mty = fatMemType space
    (peakname, usagename, sname, spacedesc) = case space of
      Space sid ->
        ( C.toIdent ("peak_mem_usage_" ++ sid) noLoc,
          C.toIdent ("cur_mem_usage_" ++ sid) noLoc,
          C.toIdent ("memblock_" ++ sid) noLoc,
          "space '" ++ sid ++ "'"
        )
      _ ->
        ( "peak_mem_usage_default",
          "cur_mem_usage_default",
          "memblock",
          "default space"
        )

-- | Declare a memory variable in the given space, unless it is cached.
-- An uncached variable is declared with the type from 'memToCType',
-- reset via 'resetMem', and recorded in 'compDeclaredMem'.
declMem :: VName -> Space -> CompilerM op s ()
declMem name space = do
  cacheEntry <- cacheMem name
  when (isNothing cacheEntry) $ do
    ty <- memToCType name space
    decl [C.cdecl|$ty:ty $id:name;|]
    resetMem name space
    modify $ \st -> st {compDeclaredMem = (name, space) : compDeclaredMem st}

-- | Emit code resetting a memory variable.  A cached variable is set to
-- @NULL@; an uncached variable of reference-counted memory gets its
-- @references@ field set to @NULL@; otherwise nothing is emitted.
resetMem :: C.ToExp a => a -> Space -> CompilerM op s ()
resetMem mem space = do
  isFat <- fatMemory space
  cacheEntry <- cacheMem mem
  case cacheEntry of
    Just _ -> stm [C.cstm|$exp:mem = NULL;|]
    Nothing
      | isFat -> stm [C.cstm|$exp:mem.references = NULL;|]
      | otherwise -> pure ()

-- | Emit code making @dest@ refer to the memory of @src@.
--
-- * Reference-counted memory calls the space's @memblock_set@ function
--   and makes the enclosing generated function return 1 on failure.
-- * Scalar spaces copy element-wise over the product of the dimensions.
-- * Anything else is a plain C assignment.
setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> CompilerM op s ()
setMem dest src space = do
  isFat <- fatMemory space
  let srcDesc = pretty $ C.toExp src noLoc
  case space of
    _
      | isFat ->
          stm
            [C.cstm|if ($id:(fatMemSet space)(ctx, &$exp:dest, &$exp:src,
                                                   $string:srcDesc) != 0) {
                           return 1;
                         }|]
    ScalarSpace ds _ -> do
      i <- newVName "i"
      let it = primTypeToCType $ IntType Int32
          bound = cproduct $ map (`C.toExp` noLoc) ds
      stm
        [C.cstm|for ($ty:it $id:i = 0; $id:i < $exp:bound; $id:i++) {
                  $exp:dest[$id:i] = $exp:src[$id:i];
                }|]
    _ -> stm [C.cstm|$exp:dest = $exp:src;|]

-- | Emit code unreferencing a memory variable.  This is done only for
-- uncached variables of reference-counted memory.  The enclosing
-- generated function returns 1 if the unreference call fails.
unRefMem :: C.ToExp a => a -> Space -> CompilerM op s ()
unRefMem mem space = do
  isFat <- fatMemory space
  isCached <- isJust <$> cacheMem mem
  let memDesc = pretty $ C.toExp mem noLoc
  unless (not isFat || isCached) $
    stm
      [C.cstm|if ($id:(fatMemUnRef space)(ctx, &$exp:mem, $string:memDesc) != 0) {
                  return 1;
                }|]

-- | Emit code allocating @size@ bytes for the memory variable @mem@.
--
-- For reference-counted memory this calls the space's @memblock_alloc@
-- function and runs @on_failure@ if it returns non-zero.  Otherwise the
-- previous raw allocation is freed and a fresh raw allocation is made.
allocMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  b ->
  Space ->
  C.Stm ->
  CompilerM op s ()
allocMem mem size space on_failure = do
  refcount <- fatMemory space
  let mem_s = pretty $ C.toExp mem noLoc
  if refcount
    then
      stm
        [C.cstm|if ($id:(fatMemAlloc space)(ctx, &$exp:mem, $exp:size,
                                                 $string:mem_s)) {
                       $stm:on_failure
                     }|]
    else do
      freeRawMem mem space mem_s
      -- Describe the allocation by the variable's printed name, as the
      -- free above does.  The identifier 'desc' is only a parameter of
      -- the generated memblock_alloc functions and is presumably not
      -- bound at this call site.
      allocRawMem mem size space mem_s

-- | The C expression naming the runtime type-info object for a primitive
-- type.  Signedness only matters for integer types; certificates share
-- the boolean info.
primTypeInfo :: PrimType -> Signedness -> C.Exp
primTypeInfo (IntType it) signed
  | signed == TypeUnsigned = case it of
      Int8 -> [C.cexp|u8_info|]
      Int16 -> [C.cexp|u16_info|]
      Int32 -> [C.cexp|u32_info|]
      Int64 -> [C.cexp|u64_info|]
  | otherwise = case it of
      Int8 -> [C.cexp|i8_info|]
      Int16 -> [C.cexp|i16_info|]
      Int32 -> [C.cexp|i32_info|]
      Int64 -> [C.cexp|i64_info|]
primTypeInfo (FloatType Float32) _ = [C.cexp|f32_info|]
primTypeInfo (FloatType Float64) _ = [C.cexp|f64_info|]
primTypeInfo Bool _ = [C.cexp|bool_info|]
primTypeInfo Cert _ = [C.cexp|bool_info|]

-- | Emit a @memmove@ of @nbytes@ bytes from @srcmem + srcidx@ to
-- @destmem + destidx@ in the default memory space.
copyMemoryDefaultSpace ::
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  CompilerM op s ()
copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
  stm [C.cstm|memmove($exp:target, $exp:source, $exp:nbytes);|]
  where
    target = [C.cexp|$exp:destmem + $exp:destidx|]
    source = [C.cexp|$exp:srcmem + $exp:srcidx|]

--- Entry points.

-- | The base name of an array type: the (signedness-aware) element type
-- name, followed by @_<rank>d@.
arrayName :: PrimType -> Signedness -> Int -> String
arrayName pt signed rank =
  concat [prettySigned (signed == TypeUnsigned) pt, "_", show rank, "d"]

-- | The name used for an opaque type.
--
-- A name that is non-empty, does not start with @_@ or a digit, and
-- contains only characters satisfying 'isAlphaNum' or equal to @_@ is
-- used directly, prefixed with @opaque_@.  Any other name (including the
-- empty string) is replaced by a hexadecimal hash of the name together
-- with a rendering of its value descriptors.
opaqueName :: String -> [ValueDesc] -> String
opaqueName s _
  | valid = "opaque_" ++ s
  where
    -- Pattern match instead of the partial 'head', so that an empty name
    -- falls through to the hashing case instead of crashing.
    valid = case s of
      c : _ -> c /= '_' && not (isDigit c) && all ok s
      [] -> False
    ok ch = isAlphaNum ch || ch == '_'
opaqueName s vds =
  "opaque_" ++ hash (zipWith xor [0 ..] $ map ord (s ++ concatMap p vds))
  where
    p (ScalarValue pt signed _) =
      show (pt, signed)
    p (ArrayValue _ space pt signed dims) =
      show (space, pt, signed, length dims)

    -- FIXME: a stupid hash algorithm; may have collisions.
    hash :: [Int] -> String
    hash =
      printf "%x" . foldl xor 0
        . map
          ( iter . (* 0x45d9f3b)
              . iter
              . (* 0x45d9f3b)
              . iter
              . fromIntegral
          )
    iter x = ((x :: Word32) `shiftR` 16) `xor` x

-- | Wrap block items in a critical section.  The items are surrounded by
-- the backend's critical-section prologue and epilogue ('opsCritical'),
-- and the whole sequence is bracketed by locking and unlocking
-- @ctx->lock@.
criticalSection :: Operations op s -> [C.BlockItem] -> [C.BlockItem]
criticalSection ops x =
  concat
    [ [C.citems|lock_lock(&ctx->lock);|],
      enter,
      x,
      leave,
      [C.citems|lock_unlock(&ctx->lock);|]
    ]
  where
    (enter, leave) = opsCritical ops

-- | Generate the public C API for arrays whose element type is @pt@
-- (with signedness @signed@) and whose rank is the length of @shape@.
-- This consists of the @new_@, @new_raw_@, @free_@, @values_@,
-- @values_raw_@ and @shape_@ functions.  Their prototypes are registered
-- with 'headerDecl' under the array's 'ArrayDecl' section, and their
-- definitions are returned.
--
-- The @new_@ functions @malloc@ the array struct and then allocate and
-- fill its memory block inside 'criticalSection'.  If allocating the
-- memory block fails, the generated code leaves the critical section
-- exactly as 'criticalSection' would (the 'opsCritical' epilogue, then
-- @lock_unlock@).  It then frees the half-built struct and returns @NULL@.
arrayLibraryFunctions ::
  Space ->
  PrimType ->
  Signedness ->
  [DimSize] ->
  CompilerM op s [C.Definition]
arrayLibraryFunctions space pt signed shape = do
  let rank = length shape
      pt' = signedPrimTypeToCType signed pt
      name = arrayName pt signed rank

  -- Name the struct through 'publicName', just like 'valueDescToCType'
  -- does when it declares the struct, so the two names cannot diverge.
  arr_name <- publicName name
  let array_type = [C.cty|struct $id:arr_name|]

  new_array <- publicName $ "new_" ++ name
  new_raw_array <- publicName $ "new_raw_" ++ name
  free_array <- publicName $ "free_" ++ name
  values_array <- publicName $ "values_" ++ name
  values_raw_array <- publicName $ "values_raw_" ++ name
  shape_array <- publicName $ "shape_" ++ name

  let shape_names = ["dim" ++ show i | i <- [0 .. rank - 1]]
      shape_params = [[C.cparam|typename int64_t $id:k|] | k <- shape_names]
      arr_size = cproduct [[C.cexp|$id:k|] | k <- shape_names]
      arr_size_array = cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [0 .. rank - 1]]
      -- Size in bytes of an array built from the dimN parameters.
      arr_bytes = [C.cexp|((size_t)$exp:arr_size) * sizeof($ty:pt')|]
  copy <- asks envCopy
  ops <- asks envOperations

  memty <- rawMemCType space

  let -- Executed when the memory block cannot be allocated.  We are inside
      -- 'criticalSection' at that point, so a bare 'return NULL;' would
      -- leave ctx->lock held forever (deadlocking every later API call)
      -- and leak 'arr'.  Mirror criticalSection's exit path first.
      alloc_failed =
        [C.cstm|{
           $items:(snd (opsCritical ops))
           lock_unlock(&ctx->lock);
           free(arr);
           return NULL;
         }|]

      -- Allocate arr->mem and record the shape from the dimN parameters.
      prepare_new = do
        resetMem [C.cexp|arr->mem|] space
        allocMem [C.cexp|arr->mem|] arr_bytes space alloc_failed
        forM_ [0 .. rank - 1] $ \i ->
          let dim_s = "dim" ++ show i
           in stm [C.cstm|arr->shape[$int:i] = $id:dim_s;|]

  -- new_: copy from host (default-space) memory.
  new_body <- collect $ do
    prepare_new
    copy
      [C.cexp|arr->mem.mem|]
      [C.cexp|0|]
      space
      [C.cexp|data|]
      [C.cexp|0|]
      DefaultSpace
      arr_bytes

  -- new_raw_: copy from memory already in the array's space, at an offset.
  new_raw_body <- collect $ do
    prepare_new
    copy
      [C.cexp|arr->mem.mem|]
      [C.cexp|0|]
      space
      [C.cexp|data|]
      [C.cexp|offset|]
      space
      arr_bytes

  free_body <- collect $ unRefMem [C.cexp|arr->mem|] space

  -- values_: copy the array contents back into a host buffer.
  values_body <-
    collect $
      copy
        [C.cexp|data|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        [C.cexp|((size_t)$exp:arr_size_array) * sizeof($ty:pt')|]

  ctx_ty <- contextType

  headerDecl
    (ArrayDecl name)
    [C.cedecl|struct $id:arr_name;|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|$ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params);|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|$ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, int offset, $params:shape_params);|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data);|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|$ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  headerDecl
    (ArrayDecl name)
    [C.cedecl|const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]

  return
    [C.cunit|
          $ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params) {
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_body)
            return arr;
          }

          $ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, int offset,
                                            $params:shape_params) {
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_raw_body)
            return arr;
          }

          int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            $items:(criticalSection ops free_body)
            free(arr);
            return 0;
          }

          int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data) {
            $items:(criticalSection ops values_body)
            return 0;
          }

          $ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->mem.mem;
          }

          const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->shape;
          }
          |]

-- | Generate the public C API for an opaque type.  Currently this is only
-- the @free_@ function.  It frees every array component through that
-- array type's own public @free_@ function, remembers the last non-zero
-- status, frees the struct itself and returns the status.
opaqueLibraryFunctions ::
  String ->
  [ValueDesc] ->
  CompilerM op s [C.Definition]
opaqueLibraryFunctions desc vds = do
  let base_name = opaqueName desc vds
  name <- publicName base_name
  free_opaque <- publicName ("free_" ++ base_name)

  let opaque_ty = [C.cty|struct $id:name|]

  ctx_ty <- contextType

  free_body <- collect $ zipWithM_ freeField [0 ..] vds

  headerDecl
    (OpaqueDecl desc)
    [C.cedecl|int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_ty *obj);|]

  -- No critical section is needed here: freeing the components goes
  -- through public API functions, which do their own locking.
  pure
    [C.cunit|
          int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_ty *obj) {
            int ret = 0, tmp;
            $items:free_body
            free(obj);
            return ret;
          }
           |]
  where
    -- Scalar fields need no cleanup.  Array fields are released via the
    -- matching array free function, and a failure is recorded in 'ret'.
    freeField :: Int -> ValueDesc -> CompilerM op s ()
    freeField _ ScalarValue {} = pure ()
    freeField i (ArrayValue _ _ pt signed shape) = do
      free_array <- publicName $ "free_" ++ arrayName pt signed (length shape)
      stm
        [C.cstm|if ((tmp = $id:free_array(ctx, obj->$id:(tupleField i))) != 0) {
                ret = tmp;
             }|]

-- | The C type used to pass a value of the given description across the
-- public API.  A scalar maps to its (signedness-aware) primitive C type.
-- An array maps to a @struct@ holding its memory block and shape.
--
-- Array structs are cached in 'compArrayStructs', keyed on the element C
-- type and the rank.  On first use, the struct definition and its library
-- functions (see 'arrayLibraryFunctions') are generated and recorded there.
valueDescToCType :: ValueDesc -> CompilerM op s C.Type
valueDescToCType (ScalarValue pt signed _) =
  pure (signedPrimTypeToCType signed pt)
valueDescToCType (ArrayValue mem space pt signed shape) = do
  cached <- gets (lookup cache_key . compArrayStructs)
  maybe declareStruct (pure . fst) cached
  where
    rank = length shape
    cache_key = (signedPrimTypeToCType signed pt, rank)

    -- Emit the struct and its API functions, then register them.
    declareStruct = do
      memty <- memToCType mem space
      struct_name <- publicName (arrayName pt signed rank)
      let struct_def =
            [C.cedecl|struct $id:struct_name { $ty:memty mem; typename int64_t shape[$int:rank]; };|]
          struct_ty = [C.cty|struct $id:struct_name|]
      api_defs <- arrayLibraryFunctions space pt signed shape
      modify $ \st ->
        st
          { compArrayStructs =
              (cache_key, (struct_ty, struct_def : api_defs)) : compArrayStructs st
          }
      pure struct_ty

-- | The C type (a @struct@) used for an opaque value built from the given
-- components.  Structs are cached in 'compOpaqueStructs' by their public
-- name.  On first use, the forward declaration is put in the header and
-- the struct definition plus its library functions (see
-- 'opaqueLibraryFunctions') are recorded.
opaqueToCType :: String -> [ValueDesc] -> CompilerM op s C.Type
opaqueToCType desc vds = do
  struct_name <- publicName $ opaqueName desc vds
  cached <- gets (lookup struct_name . compOpaqueStructs)
  maybe (declareStruct struct_name) (pure . fst) cached
  where
    declareStruct struct_name = do
      members <- zipWithM member [(0 :: Int) ..] vds
      let struct_def = [C.cedecl|struct $id:struct_name { $sdecls:members };|]
          struct_ty = [C.cty|struct $id:struct_name|]
      headerDecl (OpaqueDecl desc) [C.cedecl|struct $id:struct_name;|]
      api_defs <- opaqueLibraryFunctions desc vds
      modify $ \st ->
        st
          { compOpaqueStructs =
              (struct_name, (struct_ty, struct_def : api_defs)) :
              compOpaqueStructs st
          }
      pure struct_ty

    -- Scalar components are stored inline; every other component is
    -- stored as a pointer to its C type.
    member i vd = do
      ct <- valueDescToCType vd
      pure $ case vd of
        ScalarValue {} -> [C.csdecl|$ty:ct $id:(tupleField i);|]
        _ -> [C.csdecl|$ty:ct *$id:(tupleField i);|]

-- | The C type exposed for an entry point value: transparent values use
-- 'valueDescToCType', opaque values use 'opaqueToCType'.
externalValueToCType :: ExternalValue -> CompilerM op s C.Type
externalValueToCType ev =
  case ev of
    TransparentValue vd -> valueDescToCType vd
    OpaqueValue desc vds -> opaqueToCType desc vds

-- | Produce the C parameters @in0@, @in1@, ... for an entry point's
-- inputs.  It also emits statements that unpack each parameter into the
-- internal variables named by its 'ValueDesc': scalar variables, array
-- memory blocks, and non-constant dimension variables.
prepareEntryInputs :: [ExternalValue] -> CompilerM op s [C.Param]
prepareEntryInputs = zipWithM prepare [(0 :: Int) ..]
  where
    argName pno = "in" ++ show pno

    prepare pno (TransparentValue vd) = do
      let arg = argName pno
      ty <- unpackValue [C.cexp|$id:arg|] vd
      pure [C.cparam|const $ty:ty $id:arg|]
    prepare pno (OpaqueValue desc vds) = do
      ty <- opaqueToCType desc vds
      let arg = argName pno
          component i = [C.cexp|$id:arg->$id:(tupleField i)|]
      zipWithM_ unpackValue (map component [0 ..]) vds
      pure [C.cparam|const $ty:ty *$id:arg|]

    -- Copy the value at 'src' into the internal variables and return the
    -- C type that 'src' has.
    unpackValue src (ScalarValue pt signed name) = do
      stm [C.cstm|$id:name = $exp:src;|]
      pure (signedPrimTypeToCType signed pt)
    unpackValue src vd@(ArrayValue mem _ _ _ shape) = do
      ty <- valueDescToCType vd

      stm [C.cstm|$exp:mem = $exp:src->mem;|]

      -- Only variable dimensions need binding; constant ones are known.
      stms
        [ [C.cstm|$id:d = $exp:src->shape[$int:i];|]
          | (i, Var d) <- zip [(0 :: Int) ..] shape
        ]

      pure [C.cty|$ty:ty*|]

-- | Produce the C parameters @out0@, @out1@, ... for an entry point's
-- results.  It also emits statements that store the internal result
-- variables through them.
--
-- Scalars are written through a plain pointer.  Arrays and opaques are
-- returned as freshly @malloc@ed structs, written through a double
-- pointer, and array components of an opaque are @malloc@ed as well.
-- Each allocation is a statement of its own; only the NULL check goes
-- inside @assert@.  The previous form, @assert((*out = malloc(..)) != NULL)@,
-- lost the allocation entirely when the C code was compiled with NDEBUG.
prepareEntryOutputs :: [ExternalValue] -> CompilerM op s [C.Param]
prepareEntryOutputs = zipWithM prepare [(0 :: Int) ..]
  where
    prepare pno (TransparentValue vd) = do
      let pname = "out" ++ show pno
      ty <- valueDescToCType vd

      case vd of
        ArrayValue {} -> do
          mallocInto [C.cexp|*$id:pname|] ty
          prepareValue [C.cexp|*$id:pname|] vd
          return [C.cparam|$ty:ty **$id:pname|]
        ScalarValue {} -> do
          prepareValue [C.cexp|*$id:pname|] vd
          return [C.cparam|$ty:ty *$id:pname|]
    prepare pno (OpaqueValue desc vds) = do
      let pname = "out" ++ show pno
      ty <- opaqueToCType desc vds
      vd_ts <- mapM valueDescToCType vds

      mallocInto [C.cexp|*$id:pname|] ty

      forM_ (zip3 [0 ..] vd_ts vds) $ \(i, ct, vd) -> do
        let field = [C.cexp|(*$id:pname)->$id:(tupleField i)|]
        case vd of
          -- Scalar components live inline in the opaque struct.
          ScalarValue {} -> return ()
          -- Array components are stored by pointer and need their own struct.
          _ -> mallocInto field ct
        prepareValue field vd

      return [C.cparam|$ty:ty **$id:pname|]

    -- Allocate one object of type 'ty' and store the pointer in 'dest'.
    -- The allocation must stay outside 'assert' so that it still happens
    -- when NDEBUG removes the assert expression.
    mallocInto :: C.Exp -> C.Type -> CompilerM op s ()
    mallocInto dest ty = do
      stm [C.cstm|$exp:dest = ($ty:ty*) malloc(sizeof($ty:ty));|]
      stm [C.cstm|assert($exp:dest != NULL);|]

    -- Store an internal value into 'dest'.
    prepareValue dest (ScalarValue _ _ name) =
      stm [C.cstm|$exp:dest = $id:name;|]
    prepareValue dest (ArrayValue mem _ _ _ shape) = do
      stm [C.cstm|$exp:dest->mem = $id:mem;|]

      -- Every dimension is written, whether constant or variable.
      let copyDim (Constant x) i =
            [C.cstm|$exp:dest->shape[$int:i] = $exp:x;|]
          copyDim (Var d) i =
            [C.cstm|$exp:dest->shape[$int:i] = $id:d;|]
      stms $ zipWith copyDim shape [(0 :: Int) ..]

-- | Generate the public C entry point for the Futhark function @fname@.
--
-- The result triple holds:
--
-- 1. the definition of the public @entry_<name>@ C function, which
--    unpacks its arguments, calls the internal function (wrapped by the
--    backend's 'criticalSection' operation) and packs the results only
--    if that call returned 0;
-- 2. the definition of the CLI wrapper produced by 'cliEntryPoint';
-- 3. the initializer that registers that wrapper.
--
-- As a side effect the prototype of the public function is emitted into
-- the 'EntryDecl' header section.
onEntryPoint ::
  Name ->
  Function op ->
  CompilerM op s (C.Definition, C.Definition, C.Initializer)
onEntryPoint fname function@(Function _ outputs inputs _ results args) = do
  -- Local C declarations for every parameter of the internal function.
  inputdecls <- collect $ mapM_ stubParam inputs
  outputdecls <- collect $ mapM_ stubParam outputs

  entry_point_function_name <- publicName $ "entry_" ++ nameToString fname

  (entry_point_input_params, unpack_entry_inputs) <-
    collect' $ prepareEntryInputs args
  (entry_point_output_params, pack_entry_outputs) <-
    collect' $ prepareEntryOutputs results

  (cli_entry_point, cli_init) <- cliEntryPoint fname function

  ctx_ty <- contextType

  headerDecl
    EntryDecl
    [C.cedecl|int $id:entry_point_function_name
                                     ($ty:ctx_ty *ctx,
                                      $params:entry_point_output_params,
                                      $params:entry_point_input_params);|]

  let call_internal =
        [C.citems|
         $items:unpack_entry_inputs

         int ret = $id:(funName fname)(ctx, $args:out_args, $args:in_args);

         if (ret == 0) {
           $items:pack_entry_outputs
         }
        |]

  ops <- asks envOperations

  let entry_definition =
        [C.cedecl|
       int $id:entry_point_function_name
           ($ty:ctx_ty *ctx,
            $params:entry_point_output_params,
            $params:entry_point_input_params) {
         $items:inputdecls
         $items:outputdecls

         $items:(criticalSection ops call_internal)

         return ret;
       }|]

  pure (entry_definition, cli_entry_point, cli_init)
  where
    -- Outputs are passed to the internal function by address, inputs by
    -- value.
    out_args = [[C.cexp|&$id:(paramName p)|] | p <- outputs]
    in_args = [[C.cexp|$id:(paramName p)|] | p <- inputs]

    -- Declare a local variable (or memory block) for one parameter.
    stubParam (MemParam name space) = declMem name space
    stubParam (ScalarParam name pt) =
      decl [C.cdecl|$ty:(primTypeToCType pt) $id:name;|]

--- CLI interface
--
-- Our strategy for CLI entry points is to parse everything into
-- host memory ('DefaultSpace') and copy the result into host memory
-- after the entry point has returned.  We have some ad-hoc frobbery
-- to copy the host-level memory blocks to another memory space if
-- necessary.  This will break if the Futhark entry point uses
-- non-trivial index functions for its input or output.
--
-- The idea here is to keep the nastiness in the wrapper, whilst not
-- messing up anything else.

-- | A C statement writing the scalar @value@, of the given primitive type
-- and signedness, to @stream@ via @write_scalar@, passing along the
-- CLI's @binary_output@ flag.
printPrimStm :: (C.ToExp a, C.ToExp b) => a -> b -> PrimType -> Signedness -> C.Stm
printPrimStm stream value pt signedness =
  [C.cstm|write_scalar($exp:stream, binary_output, &$exp:info, &$exp:value);|]
  where
    info = primTypeInfo pt signedness

-- | Return a statement printing the given external value to @stdout@.
--
-- Opaque values cannot be printed; a placeholder naming their type is
-- printed instead.  Scalars are written with 'printPrimStm'.  Arrays are
-- first copied into a freshly allocated host buffer via the public
-- @values_@ function for their type, then written with @write_array@
-- using the shape obtained from the public @shape_@ function.
printStm :: ExternalValue -> C.Exp -> CompilerM op s C.Stm
printStm (OpaqueValue desc _) _ =
  return [C.cstm|printf("#<opaque %s>", $string:desc);|]
printStm (TransparentValue (ScalarValue bt ept _)) e =
  return $ printPrimStm [C.cexp|stdout|] e bt ept
printStm (TransparentValue (ArrayValue _ _ bt ept shape)) e = do
  values_array <- publicName $ "values_" ++ name
  shape_array <- publicName $ "shape_" ++ name
  let num_elems = cproduct [[C.cexp|$id:shape_array(ctx, $exp:e)[$int:i]|] | i <- [0 .. rank - 1]]
  return
    [C.cstm|{
      $ty:bt' *arr = calloc($exp:num_elems, sizeof($ty:bt'));
      assert(arr != NULL);
      assert($id:values_array(ctx, $exp:e, arr) == 0);
      write_array(stdout, binary_output, &$exp:(primTypeInfo bt ept), arr,
                  $id:shape_array(ctx, $exp:e), $int:rank);
      free(arr);
    }|]
  where
    rank = length shape
    -- Element type of the host buffer.  This is the same signedness-aware
    -- type that 'readInput' uses for the buffer it passes to @new_@, so
    -- an unsigned array gets e.g. a @uint32_t *@ rather than an
    -- @int32_t *@.
    bt' = signedPrimTypeToCType ept bt
    name = arrayName bt ept rank

-- | A C statement reading a scalar of the given type into the lvalue
-- @lvalue@ via @read_scalar@.  If reading fails, it panics with a message
-- naming the input number @input_no@, the expected type and @errno@.
readPrimStm :: C.ToExp a => a -> Int -> PrimType -> Signedness -> C.Stm
readPrimStm lvalue input_no pt signedness =
  [C.cstm|if (read_scalar(&$exp:(primTypeInfo pt signedness),&$exp:lvalue) != 0) {
        futhark_panic(1, "Error when reading input #%d of type %s (errno: %s).\n",
              $int:input_no,
              $exp:(primTypeInfo pt signedness).type_name,
              strerror(errno));
      }|]

-- | Generate reading code for every CLI input, numbering inputs from 0.
-- See 'readInput' for the meaning of each tuple component.
readInputs :: [ExternalValue] -> CompilerM op s [(C.Stm, C.Stm, C.Stm, C.Exp)]
readInputs evs = mapM (uncurry readInput) (zip [0 ..] evs)

-- | Generate code for reading CLI input number @i@.
--
-- Variable declarations and the reading code itself are emitted into the
-- current block.  The returned tuple contains, in order:
--
-- * a statement turning the parsed data into the value passed to the
--   entry point;
-- * a statement freeing that value;
-- * a statement freeing the raw parsed data;
-- * the expression to pass as the entry point argument.
--
-- Scalars need no conversion or freeing, so their three statements are
-- empty.  Opaque values cannot be read; a panic is emitted for them.
readInput :: Int -> ExternalValue -> CompilerM op s (C.Stm, C.Stm, C.Stm, C.Exp)
readInput i ev =
  case ev of
    OpaqueValue desc _ -> do
      stm [C.cstm|futhark_panic(1, "Cannot read input #%d of type %s\n", $int:i, $string:desc);|]
      pure (noop, noop, noop, [C.cexp|NULL|])
    TransparentValue (ScalarValue t ept _) -> do
      dest <- newVName "read_value"
      item [C.citem|$ty:(primTypeToCType t) $id:dest;|]
      stm $ readPrimStm dest i t ept
      pure (noop, noop, noop, [C.cexp|$id:dest|])
    TransparentValue vd@(ArrayValue _ _ t ept dims) -> do
      dest <- newVName "read_value"
      shape <- newVName "read_shape"
      arr <- newVName "read_arr"
      ty <- valueDescToCType vd
      item [C.citem|$ty:ty *$id:dest;|]

      let rank = length dims
          elem_ty = signedPrimTypeToCType ept t
          name = arrayName t ept rank
          dims_exps = [[C.cexp|$id:shape[$int:j]|] | j <- [0 .. rank - 1]]
          -- Textual rank suffix used in the error message, e.g. "[][]".
          dims_s = concat (replicate rank "[]")

      new_array <- publicName $ "new_" ++ name
      free_array <- publicName $ "free_" ++ name

      items
        [C.citems|
         typename int64_t $id:shape[$int:rank];
         $ty:elem_ty *$id:arr = NULL;
         errno = 0;
         if (read_array(&$exp:(primTypeInfo t ept),
                        (void**) &$id:arr,
                        $id:shape,
                        $int:rank)
             != 0) {
           futhark_panic(1, "Cannot read input #%d of type %s%s (errno: %s).\n",
                     $int:i,
                     $string:dims_s,
                     $exp:(primTypeInfo t ept).type_name,
                     strerror(errno));
         }|]

      pure
        ( [C.cstm|assert(($exp:dest = $id:new_array(ctx, $id:arr, $args:dims_exps)) != 0);|],
          [C.cstm|assert($id:free_array(ctx, $exp:dest) == 0);|],
          [C.cstm|free($id:arr);|],
          [C.cexp|$id:dest|]
        )
  where
    noop = [C.cstm|;|]

-- | Declare a C variable to receive each entry point result.
--
-- For every result, returns the expression naming its variable and a
-- statement releasing it again.  The statement is empty for scalars.
-- Arrays and opaque values are declared as pointers and released through
-- the matching public @free_@ function.
prepareOutputs :: [ExternalValue] -> CompilerM op s [(C.Exp, C.Stm)]
prepareOutputs = mapM declareResult
  where
    declareResult ev = do
      ty <- externalValueToCType ev
      result <- newVName "result"
      let result_exp = [C.cexp|$id:result|]
          declare_ptr = item [C.citem|$ty:ty *$id:result;|]
          freeWith fn = [C.cstm|assert($id:fn(ctx, $exp:result) == 0);|]
      case ev of
        TransparentValue ScalarValue {} -> do
          item [C.citem|$ty:ty $id:result;|]
          pure (result_exp, [C.cstm|;|])
        TransparentValue (ArrayValue _ _ t ept dims) -> do
          free_array <- publicName $ "free_" ++ arrayName t ept (length dims)
          declare_ptr
          pure (result_exp, freeWith free_array)
        OpaqueValue desc vds -> do
          free_opaque <- publicName $ "free_" ++ opaqueName desc vds
          declare_ptr
          pure (result_exp, freeWith free_opaque)

-- | Statements printing each result value (see 'printStm'), each followed
-- by a newline.
printResult :: [(ExternalValue, C.Exp)] -> CompilerM op s [C.Stm]
printResult vs = concat <$> mapM printOne vs
  where
    printOne (v, e) = do
      p <- printStm v e
      pure [p, [C.cstm|printf("\n");|]]

-- | Generate the CLI wrapper for the entry point @fname@.
--
-- The wrapper is a static C function @futrts_cli_entry_<name>@.  It
-- reads all inputs and checks that stdin is exhausted.  It then runs the
-- public entry point: once as a warmup when @perform_warmup@ is set, then
-- @num_runs@ times, profiling only the last run and logging each run's
-- time when a runtime file is open.  Finally it frees the parsed input,
-- prints the results and frees them.  The returned initializer registers
-- the wrapper under the entry point's name.
cliEntryPoint ::
  Name ->
  FunctionT a ->
  CompilerM op s (C.Definition, C.Initializer)
cliEntryPoint fname (Function _ _ _ _ results args) = do
  ((pack_input, free_input, free_parsed, input_args), input_items) <-
    collect' $ fmap unzip4 (readInputs args)

  ((output_vals, free_outputs), output_decls) <-
    collect' $ fmap unzip (prepareOutputs results)
  print_results <- printResult $ zip results output_vals

  ctx_ty <- contextType
  sync_ctx <- publicName "context_sync"
  error_ctx <- publicName "context_get_error"

  entry_point_function_name <- publicName $ "entry_" ++ entry_point_name

  pause_profiling <- publicName "context_pause_profiling"
  unpause_profiling <- publicName "context_unpause_profiling"

  -- A single (possibly timed and profiled) invocation of the entry point.
  let run_once =
        [C.citems|
                  int r;
                  // Run the program once.
                  $stms:pack_input
                  if ($id:sync_ctx(ctx) != 0) {
                    futhark_panic(1, "%s", $id:error_ctx(ctx));
                  };
                  // Only profile last run.
                  if (profile_run) {
                    $id:unpause_profiling(ctx);
                  }
                  t_start = get_wall_time();
                  r = $id:entry_point_function_name(ctx,
                                                    $args:(map addrOf output_vals),
                                                    $args:input_args);
                  if (r != 0) {
                    futhark_panic(1, "%s", $id:error_ctx(ctx));
                  }
                  if ($id:sync_ctx(ctx) != 0) {
                    futhark_panic(1, "%s", $id:error_ctx(ctx));
                  };
                  if (profile_run) {
                    $id:pause_profiling(ctx);
                  }
                  t_end = get_wall_time();
                  long int elapsed_usec = t_end - t_start;
                  if (time_runs && runtime_file != NULL) {
                    fprintf(runtime_file, "%lld\n", (long long) elapsed_usec);
                    fflush(runtime_file);
                  }
                  $stms:free_input
                |]

      cli_definition =
        [C.cedecl|static void $id:cli_entry_point_function_name($ty:ctx_ty *ctx) {
    typename int64_t t_start, t_end;
    int time_runs = 0, profile_run = 0;

    // We do not want to profile all the initialisation.
    $id:pause_profiling(ctx);

    // Declare and read input.
    set_binary_mode(stdin);
    $items:input_items

    if (end_of_input() != 0) {
      futhark_panic(1, "Expected EOF on stdin after reading input for %s.\n", $string:(quote (pretty fname)));
    }

    $items:output_decls

    // Warmup run
    if (perform_warmup) {
      $items:run_once
      $stms:free_outputs
    }
    time_runs = 1;
    // Proper run.
    for (int run = 0; run < num_runs; run++) {
      // Only profile last run.
      profile_run = run == num_runs -1;
      $items:run_once
      if (run < num_runs-1) {
        $stms:free_outputs
      }
    }

    // Free the parsed input.
    $stms:free_parsed

    // Print the final result.
    if (binary_output) {
      set_binary_mode(stdout);
    }
    $stms:print_results

    $stms:free_outputs
  }
                |]

      registration =
        [C.cinit|{ .name = $string:entry_point_name,
                      .fun = $id:cli_entry_point_function_name }|]

  pure (cli_definition, registration)
  where
    entry_point_name = nameToString fname
    cli_entry_point_function_name = "futrts_cli_entry_" ++ entry_point_name

-- | Command-line options understood by every generated executable,
-- independent of the backend.  'compileProg' appends the
-- backend-specific options to this list before generating the parser.
genericOptions :: [Option]
genericOptions =
  [ opt
      "write-runtime-to"
      't'
      (RequiredArgument "FILE")
      "Print the time taken to execute the program to the indicated file, an integral number of microseconds."
      set_runtime_file,
    opt
      "runs"
      'r'
      (RequiredArgument "INT")
      "Perform NUM runs of the program."
      set_num_runs,
    opt
      "debugging"
      'D'
      NoArgument
      "Perform possibly expensive internal correctness checks and verbose logging."
      [C.cstm|futhark_context_config_set_debugging(cfg, 1);|],
    opt
      "log"
      'L'
      NoArgument
      "Print various low-overhead logging information to stderr while running."
      [C.cstm|futhark_context_config_set_logging(cfg, 1);|],
    opt
      "entry-point"
      'e'
      (RequiredArgument "NAME")
      "The entry point to run. Defaults to main."
      [C.cstm|if (entry_point != NULL) entry_point = optarg;|],
    opt
      "binary-output"
      'b'
      NoArgument
      "Print the program result in the binary output format."
      [C.cstm|binary_output = 1;|],
    opt
      "help"
      'h'
      NoArgument
      "Print help information and exit."
      print_help
  ]
  where
    -- Every generic option has a short form, so build the record here.
    opt long short arg desc action =
      Option
        { optionLongName = long,
          optionShortName = Just short,
          optionArgument = arg,
          optionDescription = desc,
          optionAction = action
        }

    print_help =
      [C.cstm|{
               printf("Usage: %s [OPTION]...\nOptions:\n\n%s\nFor more information, consult the Futhark User's Guide or the man pages.\n",
                      fut_progname, option_descriptions);
               exit(0);
              }|]

    set_runtime_file =
      [C.cstm|{
          runtime_file = fopen(optarg, "w");
          if (runtime_file == NULL) {
            futhark_panic(1, "Cannot open %s: %s\n", optarg, strerror(errno));
          }
        }|]

    set_num_runs =
      [C.cstm|{
          num_runs = atoi(optarg);
          perform_warmup = 1;
          if (num_runs <= 0) {
            futhark_panic(1, "Need a positive number of runs, not %s\n", optarg);
          }
        }|]

-- | The result of compilation to C is four parts, which can be put
-- together in various ways.  The obvious way is to concatenate all of
-- them, which yields a CLI program (see 'asExecutable').  Another is to
-- compile the library part by itself, and use the header file to call
-- into it (see 'asLibrary').
data CParts = CParts
  { -- | Contents of the public header file.
    cHeader :: String,
    -- | Utility definitions that must be visible
    -- to both CLI and library parts.
    cUtils :: String,
    -- | The command-line interface; only included in executables.
    cCLI :: String,
    -- | The library implementation.
    cLib :: String
  }

-- We may generate variables that are never used (e.g. for
-- certificates) or functions that are never called (e.g. unused
-- intrinsics), and generated code may have other cosmetic issues that
-- compilers warn about.  We disable these warnings, for both GCC and
-- Clang, to not clutter the compilation logs.
disableWarnings :: String
disableWarnings = pretty suppressions
  where
    -- One group of pragmas per compiler, each guarded by that compiler's
    -- identifying macro.
    suppressions =
      [C.cunit|
$esc:("#ifdef __GNUC__")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-function\"")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-variable\"")
$esc:("#pragma GCC diagnostic ignored \"-Wparentheses\"")
$esc:("#pragma GCC diagnostic ignored \"-Wunused-label\"")
$esc:("#endif")

$esc:("#ifdef __clang__")
$esc:("#pragma clang diagnostic ignored \"-Wunused-function\"")
$esc:("#pragma clang diagnostic ignored \"-Wunused-variable\"")
$esc:("#pragma clang diagnostic ignored \"-Wparentheses\"")
$esc:("#pragma clang diagnostic ignored \"-Wunused-label\"")
$esc:("#endif")
|]

-- | Produce header and implementation files.
--
-- The header is 'cHeader' preceded by @#pragma once@.  The
-- implementation starts with the warning-suppression pragmas and then
-- contains the header, the utilities and the library code.  The CLI part
-- is omitted.
asLibrary :: CParts -> (String, String)
asLibrary parts = (header, implementation)
  where
    header = "#pragma once\n\n" <> cHeader parts
    implementation =
      mconcat [disableWarnings, cHeader parts, cUtils parts, cLib parts]

-- | As executable with command-line interface: the warning-suppression
-- pragmas followed by all four parts, in declaration order.
asExecutable :: CParts -> String
asExecutable (CParts header utils cli lib) =
  concat [disableWarnings, header, utils, cli, lib]

-- | Compile imperative program to a C program.  Always uses the
-- function named "main" as entry point, so make sure it is defined.
compileProg ::
  MonadFreshNames m =>
  String ->
  Operations op () ->
  CompilerM op () () ->
  String ->
  [Space] ->
  [Option] ->
  Definitions op ->
  m CParts
compileProg :: String
-> Operations op ()
-> CompilerM op () ()
-> String
-> [Space]
-> [Option]
-> Definitions op
-> m CParts
compileProg String
backend Operations op ()
ops CompilerM op () ()
extra String
header_extra [Space]
spaces [Option]
options Definitions op
prog = do
  VNameSource
src <- m VNameSource
forall (m :: * -> *). MonadFreshNames m => m VNameSource
getNameSource
  let (([Definition]
prototypes, [Func]
definitions, [(Definition, Definition, Initializer)]
entry_points), CompilerState ()
endstate) =
        Operations op ()
-> VNameSource
-> ()
-> CompilerM
     op
     ()
     ([Definition], [Func], [(Definition, Definition, Initializer)])
-> (([Definition], [Func],
     [(Definition, Definition, Initializer)]),
    CompilerState ())
forall op s a.
Operations op s
-> VNameSource -> s -> CompilerM op s a -> (a, CompilerState s)
runCompilerM Operations op ()
ops VNameSource
src () CompilerM
  op
  ()
  ([Definition], [Func], [(Definition, Definition, Initializer)])
compileProg'
      ([Definition]
entry_point_decls, [Definition]
cli_entry_point_decls, [Initializer]
entry_point_inits) =
        [(Definition, Definition, Initializer)]
-> ([Definition], [Definition], [Initializer])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Definition, Definition, Initializer)]
entry_points
      option_parser :: Func
option_parser = String -> [Option] -> Func
generateOptionParser String
"parse_options" ([Option] -> Func) -> [Option] -> Func
forall a b. (a -> b) -> a -> b
$ [Option]
genericOptions [Option] -> [Option] -> [Option]
forall a. [a] -> [a] -> [a]
++ [Option]
options

  let headerdefs :: [Definition]
headerdefs =
        [C.cunit|
$esc:("// Headers\n")
/* We need to define _GNU_SOURCE before
   _any_ headers files are imported to get
   the usage statistics of a thread (i.e. have RUSAGE_THREAD) on GNU/Linux
   https://manpages.courier-mta.org/htmlman2/getrusage.2.html */
$esc:("#define _GNU_SOURCE")
$esc:("#include <stdint.h>")
$esc:("#include <stddef.h>")
$esc:("#include <stdbool.h>")
$esc:("#include <float.h>")
$esc:(header_extra)

$esc:("\n// Initialisation\n")
$edecls:(initDecls endstate)

$esc:("\n// Arrays\n")
$edecls:(arrayDecls endstate)

$esc:("\n// Opaque values\n")
$edecls:(opaqueDecls endstate)

$esc:("\n// Entry points\n")
$edecls:(entryDecls endstate)

$esc:("\n// Miscellaneous\n")
$edecls:(miscDecls endstate)
$esc:("#define FUTHARK_BACKEND_"++backend)
                           |]

  let utildefs :: [Definition]
utildefs =
        [C.cunit|
$esc:("#include <stdio.h>")
$esc:("#include <stdlib.h>")
$esc:("#include <stdbool.h>")
$esc:("#include <math.h>")
$esc:("#include <stdint.h>")
// If NDEBUG is set, the assert() macro will do nothing. Since Futhark
// (unfortunately) makes use of assert() for error detection (and even some
// side effects), we want to avoid that.
$esc:("#undef NDEBUG")
$esc:("#include <assert.h>")
$esc:("#include <stdarg.h>")

$esc:util_h

$esc:timing_h
|]

  let clidefs :: [Definition]
clidefs =
        [C.cunit|
$esc:("#include <string.h>")
$esc:("#include <inttypes.h>")
$esc:("#include <errno.h>")
$esc:("#include <ctype.h>")
$esc:("#include <errno.h>")
$esc:("#include <getopt.h>")

$esc:values_h

$esc:("#define __private")

static int binary_output = 0;
static typename FILE *runtime_file;
static int perform_warmup = 0;
static int num_runs = 1;
// If the entry point is NULL, the program will terminate after doing initialisation and such.
static const char *entry_point = "main";

$esc:tuning_h

$func:option_parser

$edecls:cli_entry_point_decls

typedef void entry_point_fun(struct futhark_context*);

struct entry_point_entry {
  const char *name;
  entry_point_fun *fun;
};

int main(int argc, char** argv) {
  fut_progname = argv[0];

  struct entry_point_entry entry_points[] = {
    $inits:entry_point_inits
  };

  struct futhark_context_config *cfg = futhark_context_config_new();
  assert(cfg != NULL);

  int parsed_options = parse_options(cfg, argc, argv);
  argc -= parsed_options;
  argv += parsed_options;

  if (argc != 0) {
    futhark_panic(1, "Excess non-option: %s\n", argv[0]);
  }

  struct futhark_context *ctx = futhark_context_new(cfg);
  assert (ctx != NULL);

  char* error = futhark_context_get_error(ctx);
  if (error != NULL) {
    futhark_panic(1, "%s", error);
  }

  if (entry_point != NULL) {
    int num_entry_points = sizeof(entry_points) / sizeof(entry_points[0]);
    entry_point_fun *entry_point_fun = NULL;
    for (int i = 0; i < num_entry_points; i++) {
      if (strcmp(entry_points[i].name, entry_point) == 0) {
        entry_point_fun = entry_points[i].fun;
        break;
      }
    }

    if (entry_point_fun == NULL) {
      fprintf(stderr, "No entry point '%s'.  Select another with --entry-point.  Options are:\n",
                      entry_point);
      for (int i = 0; i < num_entry_points; i++) {
        fprintf(stderr, "%s\n", entry_points[i].name);
      }
      return 1;
    }

    entry_point_fun(ctx);

    if (runtime_file != NULL) {
      fclose(runtime_file);
    }

    char *report = futhark_context_report(ctx);
    fputs(report, stderr);
    free(report);
  }

  futhark_context_free(ctx);
  futhark_context_config_free(cfg);
  return 0;
}
                        |]

  let early_decls :: [Definition]
early_decls = DList Definition -> [Definition]
forall a. DList a -> [a]
DL.toList (DList Definition -> [Definition])
-> DList Definition -> [Definition]
forall a b. (a -> b) -> a -> b
$ CompilerState () -> DList Definition
forall s. CompilerState s -> DList Definition
compEarlyDecls CompilerState ()
endstate
  let lib_decls :: [Definition]
lib_decls = DList Definition -> [Definition]
forall a. DList a -> [a]
DL.toList (DList Definition -> [Definition])
-> DList Definition -> [Definition]
forall a b. (a -> b) -> a -> b
$ CompilerState () -> DList Definition
forall s. CompilerState s -> DList Definition
compLibDecls CompilerState ()
endstate
  let libdefs :: [Definition]
libdefs =
        [C.cunit|
$esc:("#ifdef _MSC_VER\n#define inline __inline\n#endif")
$esc:("#include <string.h>")
$esc:("#include <inttypes.h>")
$esc:("#include <ctype.h>")
$esc:("#include <errno.h>")
$esc:("#include <assert.h>")

$esc:(header_extra)

$esc:lock_h

$edecls:builtin

$edecls:early_decls

$edecls:prototypes

$edecls:lib_decls

$edecls:(map funcToDef definitions)

$edecls:(arrayDefinitions endstate)

$edecls:(opaqueDefinitions endstate)

$edecls:entry_point_decls
  |]

  CParts -> m CParts
forall (m :: * -> *) a. Monad m => a -> m a
return (CParts -> m CParts) -> CParts -> m CParts
forall a b. (a -> b) -> a -> b
$ String -> String -> String -> String -> CParts
CParts ([Definition] -> String
forall a. Pretty a => a -> String
pretty [Definition]
headerdefs) ([Definition] -> String
forall a. Pretty a => a -> String
pretty [Definition]
utildefs) ([Definition] -> String
forall a. Pretty a => a -> String
pretty [Definition]
clidefs) ([Definition] -> String
forall a. Pretty a => a -> String
pretty [Definition]
libdefs)
  where
    compileProg' :: CompilerM
  op
  ()
  ([Definition], [Func], [(Definition, Definition, Initializer)])
compileProg' = do
      let Definitions Constants op
consts (Functions [(Name, Function op)]
funs) = Definitions op
prog

      ([Definition]
memstructs, [[Definition]]
memfuns, [BlockItem]
memreport) <- [(Definition, [Definition], BlockItem)]
-> ([Definition], [[Definition]], [BlockItem])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(Definition, [Definition], BlockItem)]
 -> ([Definition], [[Definition]], [BlockItem]))
-> CompilerM op () [(Definition, [Definition], BlockItem)]
-> CompilerM op () ([Definition], [[Definition]], [BlockItem])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Space -> CompilerM op () (Definition, [Definition], BlockItem))
-> [Space]
-> CompilerM op () [(Definition, [Definition], BlockItem)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Space -> CompilerM op () (Definition, [Definition], BlockItem)
forall op s.
Space -> CompilerM op s (Definition, [Definition], BlockItem)
defineMemorySpace [Space]
spaces

      [BlockItem]
get_consts <- Constants op -> CompilerM op () [BlockItem]
forall op s. Constants op -> CompilerM op s [BlockItem]
compileConstants Constants op
consts

      Type
ctx_ty <- CompilerM op () Type
forall op s. CompilerM op s Type
contextType

      ([Definition]
prototypes, [Func]
definitions) <-
        [(Definition, Func)] -> ([Definition], [Func])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Definition, Func)] -> ([Definition], [Func]))
-> CompilerM op () [(Definition, Func)]
-> CompilerM op () ([Definition], [Func])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((Name, Function op) -> CompilerM op () (Definition, Func))
-> [(Name, Function op)] -> CompilerM op () [(Definition, Func)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ([BlockItem]
-> [Param]
-> (Name, Function op)
-> CompilerM op () (Definition, Func)
forall op s.
[BlockItem]
-> [Param]
-> (Name, Function op)
-> CompilerM op s (Definition, Func)
compileFun [BlockItem]
get_consts [[C.cparam|$ty:ctx_ty *ctx|]]) [(Name, Function op)]
funs

      (Definition -> CompilerM op () ())
-> [Definition] -> CompilerM op () ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Definition -> CompilerM op () ()
forall op s. Definition -> CompilerM op s ()
earlyDecl [Definition]
memstructs
      [(Definition, Definition, Initializer)]
entry_points <-
        ((Name, Function op)
 -> CompilerM op () (Definition, Definition, Initializer))
-> [(Name, Function op)]
-> CompilerM op () [(Definition, Definition, Initializer)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((Name
 -> Function op
 -> CompilerM op () (Definition, Definition, Initializer))
-> (Name, Function op)
-> CompilerM op () (Definition, Definition, Initializer)
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Name
-> Function op
-> CompilerM op () (Definition, Definition, Initializer)
forall op s.
Name
-> Function op
-> CompilerM op s (Definition, Definition, Initializer)
onEntryPoint) ([(Name, Function op)]
 -> CompilerM op () [(Definition, Definition, Initializer)])
-> [(Name, Function op)]
-> CompilerM op () [(Definition, Definition, Initializer)]
forall a b. (a -> b) -> a -> b
$ ((Name, Function op) -> Bool)
-> [(Name, Function op)] -> [(Name, Function op)]
forall a. (a -> Bool) -> [a] -> [a]
filter (Function op -> Bool
forall a. FunctionT a -> Bool
functionEntry (Function op -> Bool)
-> ((Name, Function op) -> Function op)
-> (Name, Function op)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Name, Function op) -> Function op
forall a b. (a, b) -> b
snd) [(Name, Function op)]
funs

      CompilerM op () ()
extra

      (Definition -> CompilerM op () ())
-> [Definition] -> CompilerM op () ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Definition -> CompilerM op () ()
forall op s. Definition -> CompilerM op s ()
earlyDecl ([Definition] -> CompilerM op () ())
-> [Definition] -> CompilerM op () ()
forall a b. (a -> b) -> a -> b
$ [[Definition]] -> [Definition]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Definition]]
memfuns

      [BlockItem] -> CompilerM op () ()
forall op s. [BlockItem] -> CompilerM op s ()
commonLibFuns [BlockItem]
memreport

      ([Definition], [Func], [(Definition, Definition, Initializer)])
-> CompilerM
     op
     ()
     ([Definition], [Func], [(Definition, Definition, Initializer)])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Definition]
prototypes, [Func]
definitions, [(Definition, Definition, Initializer)]
entry_points)

    funcToDef :: Func -> Definition
funcToDef Func
func = Func -> SrcLoc -> Definition
C.FuncDef Func
func SrcLoc
loc
      where
        loc :: SrcLoc
loc = case Func
func of
          C.OldFunc _ _ _ _ _ _ l -> SrcLoc
l
          C.Func _ _ _ _ _ l -> SrcLoc
l

    builtin :: [Definition]
builtin =
      [Definition]
cIntOps [Definition] -> [Definition] -> [Definition]
forall a. [a] -> [a] -> [a]
++ [Definition]
cFloat32Ops [Definition] -> [Definition] -> [Definition]
forall a. [a] -> [a] -> [a]
++ [Definition]
cFloat64Ops [Definition] -> [Definition] -> [Definition]
forall a. [a] -> [a] -> [a]
++ [Definition]
cFloatConvOps
        [Definition] -> [Definition] -> [Definition]
forall a. [a] -> [a] -> [a]
++ [Definition]
cFloat32Funs
        [Definition] -> [Definition] -> [Definition]
forall a. [a] -> [a] -> [a]
++ [Definition]
cFloat64Funs

    util_h :: String
util_h = $(embedStringFile "rts/c/util.h")
    values_h :: String
values_h = $(embedStringFile "rts/c/values.h")
    timing_h :: String
timing_h = $(embedStringFile "rts/c/timing.h")
    lock_h :: String
lock_h = $(embedStringFile "rts/c/lock.h")
    tuning_h :: String
tuning_h = $(embedStringFile "rts/c/tuning.h")

-- | Define the public context functions shared by all backends:
-- @context_report@, @context_get_error@, @context_pause_profiling@ and
-- @context_unpause_profiling@.
--
-- The given block items form the memory part of the report and are
-- emitted when either @ctx->detail_memory@ or @ctx->profiling@ is set.
-- The profiling items accumulated in 'compProfileItems' are emitted
-- only when @ctx->profiling@ is set.  The report is built with a
-- @str_builder@ and its string is returned.
commonLibFuns :: [C.BlockItem] -> CompilerM op s ()
commonLibFuns memreport = do
  ctx <- contextType
  profilereport <- gets $ DL.toList . compProfileItems

  publicDef_ "context_report" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx *ctx);|],
      [C.cedecl|char* $id:s($ty:ctx *ctx) {
                 struct str_builder builder;
                 str_builder_init(&builder);
                 if (ctx->detail_memory || ctx->profiling) {
                   $items:memreport
                 }
                 if (ctx->profiling) {
                   $items:profilereport
                 }
                 return builder.str;
               }|]
    )

  -- Returns the pending error message (possibly NULL) and clears it
  -- from the context.
  publicDef_ "context_get_error" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx* ctx);|],
      [C.cedecl|char* $id:s($ty:ctx* ctx) {
                         char* error = ctx->error;
                         ctx->error = NULL;
                         return error;
                       }|]
    )

  -- Pausing and unpausing profiling differ only in the value stored in
  -- ctx->profiling_paused, so both come from the same template.
  let defProfilingToggle :: String -> Int -> CompilerM op s ()
      defProfilingToggle fname paused =
        publicDef_ fname MiscDecl $ \s ->
          ( [C.cedecl|void $id:s($ty:ctx* ctx);|],
            [C.cedecl|void $id:s($ty:ctx* ctx) {
                       ctx->profiling_paused = $int:paused;
                     }|]
          )

  defProfilingToggle "context_pause_profiling" 1
  defProfilingToggle "context_unpause_profiling" 0

-- | Compile the program constants.
--
-- Every constant becomes a field of a @constants@ struct stored in the
-- context.  Two library functions are emitted:
--
-- * @init_constants@ runs the constant-initialisation code with each
--   constant name @#define@d to its struct field, so assignments made
--   by that code land in the context;
--
-- * @free_constants@ drops the references held by memory-block
--   constants (scalar constants need no cleanup).
--
-- The returned block items declare a local copy of each constant, read
-- from @ctx->constants@.
compileConstants :: Constants op -> CompilerM op s [C.BlockItem]
compileConstants (Constants ps init_consts) = do
  ctx_ty <- contextType
  const_fields <- mapM constParamField ps
  -- Avoid an empty struct, as that is apparently undefined behaviour.
  let const_fields'
        | null const_fields = [[C.csdecl|int dummy;|]]
        | otherwise = const_fields
  contextField "constants" [C.cty|struct { $sdecls:const_fields' }|] Nothing
  earlyDecl [C.cedecl|static int init_constants($ty:ctx_ty*);|]
  earlyDecl [C.cedecl|static int free_constants($ty:ctx_ty*);|]

  -- We locally define macros for the constants, so that when we
  -- generate assignments to local variables, we actually assign into
  -- the constants struct.  This is not needed for functions, because
  -- they can only read constants, not write them.
  let (defs, undefs) = unzip $ map constMacro ps
  init_consts' <- blockScope $ do
    mapM_ resetMemConst ps
    compileCode init_consts
  libDecl
    [C.cedecl|static int init_constants($ty:ctx_ty *ctx) {
      (void)ctx;
      int err = 0;
      $items:defs
      $items:init_consts'
      $items:undefs
      cleanup:
      return err;
    }|]

  free_consts <- collect $ mapM_ freeConst ps
  libDecl
    [C.cedecl|static int free_constants($ty:ctx_ty *ctx) {
      (void)ctx;
      $items:free_consts
      return 0;
    }|]

  mapM getConst ps
  where
    -- The struct field declaration holding one constant.
    constParamField (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      return [C.csdecl|$ty:ctp $id:name;|]
    constParamField (MemParam name space) = do
      ty <- memToCType name space
      return [C.csdecl|$ty:ty $id:name;|]

    -- A (#define, #undef) pair redirecting the constant's C name to its
    -- field in ctx->constants.
    constMacro p = ([C.citem|$escstm:def|], [C.citem|$escstm:undef|])
      where
        p' = pretty (C.toIdent (paramName p) mempty)
        def = "#define " ++ p' ++ " (ctx->constants." ++ p' ++ ")"
        undef = "#undef " ++ p'

    -- Memory constants are reset before the initialisation code runs.
    resetMemConst ScalarParam {} = return ()
    resetMemConst (MemParam name space) = resetMem name space

    -- Only memory constants hold references that must be released.
    freeConst ScalarParam {} = return ()
    freeConst (MemParam name space) =
      unRefMem [C.cexp|ctx->constants.$id:name|] space

    -- A local variable initialised from the constant's struct field.
    getConst (ScalarParam name bt) = do
      let ctp = primTypeToCType bt
      return [C.citem|$ty:ctp $id:name = ctx->constants.$id:name;|]
    getConst (MemParam name space) = do
      ty <- memToCType name space
      return [C.citem|$ty:ty $id:name = ctx->constants.$id:name;|]

-- | Run a continuation with some lexical memory blocks cached.
--
-- Only the blocks in 'DefaultSpace' are cached.  Each one is paired
-- with a fresh size variable named after the block, and the cached
-- blocks are added to 'envCachedMem' while the continuation runs.  The
-- new entries take precedence over existing ones, because Map '<>' is
-- left-biased.
--
-- The continuation receives two things:
--
-- * the block items declaring every cached block (initialised to
--   @NULL@) and its size (initialised to @0@);
--
-- * the statements that @free@ every cached block.
cachingMemory ::
  M.Map VName Space ->
  ([C.BlockItem] -> [C.Stm] -> CompilerM op s a) ->
  CompilerM op s a
cachingMemory lexical f = do
  -- We only consider lexical 'DefaultSpace' memory blocks to be
  -- cached.  This is not a deep technical restriction, but merely a
  -- heuristic based on GPU memory usually involving larger
  -- allocations, that do not suffer from the overhead of reference
  -- counting.
  let cached = M.keys $ M.filter (== DefaultSpace) lexical

  cached' <- forM cached $ \mem ->
    (mem,) <$> newVName (pretty mem <> "_cached_size")

  let lexMem env =
        env
          { envCachedMem =
              M.fromList (map (first (`C.toExp` noLoc)) cached')
                <> envCachedMem env
          }

      declCached (mem, size) =
        [ [C.citem|size_t $id:size = 0;|],
          [C.citem|$ty:defaultMemBlockType $id:mem = NULL;|]
        ]

      freeCached (mem, _) =
        [C.cstm|free($id:mem);|]

  local lexMem $ f (concatMap declCached cached') (map freeCached cached')

-- | Compile a function to a C prototype and a C definition.
--
-- The C function is @static@ and returns the @int@ error code @err@.
-- Its parameters are, in order:
--
-- * the given boilerplate parameters;
-- * one pointer parameter per output;
-- * one parameter per input.
--
-- The body contains these pieces, in order:
--
-- * declarations of the function's cached memory blocks;
-- * the given block items;
-- * the compiled function body;
-- * after the @cleanup@ label, the statements freeing the cached
--   memory.
compileFun :: [C.BlockItem] -> [C.Param] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileFun get_constants extra (fname, func@(Function _ outputs inputs body _ _)) = do
  (outparams, out_ptrs) <- unzip <$> mapM compileOutput outputs
  inparams <- mapM compileInput inputs

  cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
    body' <- blockScope $ compileFunBody out_ptrs outputs body

    return
      ( [C.cedecl|static int $id:fun_name($params:extra, $params:outparams, $params:inparams);|],
        [C.cfun|static int $id:fun_name($params:extra, $params:outparams, $params:inparams) {
               $stms:ignores
               int err = 0;
               $items:decl_cached
               $items:get_constants
               $items:body'
              cleanup:
               {}
               $stms:free_cached
               return err;
  }|]
      )
  where
    fun_name = funName fname

    -- Ignore all the boilerplate parameters, just in case we don't
    -- actually need to use them.
    ignores = [[C.cstm|(void)$id:p;|] | C.Param (Just p) _ _ _ <- extra]

    -- The C type of a parameter: the primitive type for scalars, the
    -- memory-block type for memory.
    paramCType (ScalarParam _ bt) = return $ primTypeToCType bt
    paramCType (MemParam name space) = memToCType name space

    -- Inputs are passed by value, under their own names.
    compileInput p = do
      ty <- paramCType p
      return [C.cparam|$ty:ty $id:(paramName p)|]

    -- Outputs are passed through a pointer with a fresh name.  The
    -- expression naming that pointer is returned alongside it.
    compileOutput p = do
      ty <- paramCType p
      p_name <- newVName $ case p of
        ScalarParam name _ -> "out_" ++ baseString name
        MemParam name _ -> baseString name ++ "_p"
      return ([C.cparam|$ty:ty *$id:p_name|], [C.cexp|$id:p_name|])

-- | Compile a primitive value to a C constant expression.
--
-- * Integers become integer literals.
-- * Booleans become @1@ or @0@.
-- * Infinite and NaN floats use the C @INFINITY@ and @NAN@ macros.
-- * Other floats become @double@ or @float@ literals, matching their
--   precision.
-- * 'Checked' compiles to @0@.
compilePrimValue :: PrimValue -> C.Exp
compilePrimValue (IntValue (Int8Value k)) = [C.cexp|$int:k|]
compilePrimValue (IntValue (Int16Value k)) = [C.cexp|$int:k|]
compilePrimValue (IntValue (Int32Value k)) = [C.cexp|$int:k|]
compilePrimValue (IntValue (Int64Value k)) = [C.cexp|$int:k|]
compilePrimValue (FloatValue (Float64Value x)) =
  specialFloatOr x [C.cexp|$double:x|]
compilePrimValue (FloatValue (Float32Value x)) =
  specialFloatOr x [C.cexp|$float:x|]
compilePrimValue (BoolValue b) =
  [C.cexp|$int:b'|]
  where
    b' :: Int
    b' = if b then 1 else 0
compilePrimValue Checked =
  [C.cexp|0|]

-- | The C expression for a non-finite float (@INFINITY@, @-INFINITY@ or
-- @NAN@).  For a finite value, the given literal is returned instead.
specialFloatOr :: RealFloat a => a -> C.Exp -> C.Exp
specialFloatOr x finite
  | isInfinite x =
    if x > 0 then [C.cexp|INFINITY|] else [C.cexp|-INFINITY|]
  | isNaN x =
    [C.cexp|NAN|]
  | otherwise =
    finite

-- | Cast a pointer to the given pointer type and index it.  Given @ptr@,
-- @i@ and @res_t@, this produces the C expression @((res_t)ptr)[i]@.
derefPointer :: C.Exp -> C.Exp -> C.Type -> C.Exp
derefPointer ptr i res_t =
  [C.cexp|(($ty:res_t)$exp:ptr)[$exp:i]|]

-- | The C type qualifiers for a volatility: @volatile@ for 'Volatile',
-- and none for 'Nonvolatile'.
volQuals :: Volatility -> [C.TypeQual]
volQuals Volatile = [C.ctyquals|volatile|]
volQuals Nonvolatile = []

-- | Build a 'WriteScalar' that stores through a casted pointer.
--
-- The destination is cast to @elemtype*@ with these qualifiers:
--
-- * the volatility qualifier, if any;
-- * followed by the qualifiers that @quals_f@ returns for the memory
--   space.
--
-- The value is then assigned to element @i@.
writeScalarPointerWithQuals :: PointerQuals op s -> WriteScalar op s
writeScalarPointerWithQuals quals_f dest i elemtype space vol v = do
  space_quals <- quals_f space
  let quals = volQuals vol ++ space_quals
      ptr_t = [C.cty|$tyquals:quals $ty:elemtype*|]
      lhs = derefPointer dest i ptr_t
  stm [C.cstm|$exp:lhs = $exp:v;|]

-- | Build a 'ReadScalar' that loads a value through a raw pointer.
--
-- The source is cast to a pointer to @elemtype@. That pointer is
-- qualified first by the 'Volatility' qualifiers and then by whatever
-- @quals_f@ returns for the memory space. The result is the C
-- expression indexing element @i@; no statements are emitted here.
readScalarPointerWithQuals :: PointerQuals op s -> ReadScalar op s
readScalarPointerWithQuals quals_f dest i elemtype space vol =
  elementAt <$> quals_f space
  where
    elementAt space_quals =
      derefPointer dest i [C.cty|$tyquals:(volQuals vol ++ space_quals) $ty:elemtype*|]

-- | Make the value of an expression available under a variable name.
--
-- If the expression is already a bare scalar variable, that variable
-- is returned and nothing is emitted. Otherwise the function creates a
-- fresh name based on @desc@. It then emits a C declaration of type @t@
-- for that name, initialised with the compiled expression.
compileExpToName :: String -> PrimType -> Exp -> CompilerM op s VName
compileExpToName desc t e =
  case e of
    LeafExp (ScalarVar v) _ -> return v
    _ -> do
      var <- newVName desc
      e' <- compileExp e
      decl [C.cdecl|$ty:(primTypeToCType t) $id:var = $exp:e';|]
      return var

-- | Compile an ImpCode scalar expression to a C expression.
--
-- Leaves are scalar variables, indexing into memory, or @sizeof@
-- queries. Every other construct is handled by 'compilePrimExp'.
compileExp :: Exp -> CompilerM op s C.Exp
compileExp = compilePrimExp leaf
  where
    leaf (ScalarVar src) =
      return [C.cexp|$id:src|]
    leaf (Index src (Count iexp) restype DefaultSpace vol) = do
      -- Default memory: cast the raw pointer and index it directly.
      src' <- rawMem src
      iexp' <- compileExp $ untyped iexp
      let ptr_t = [C.cty|$tyquals:(volQuals vol) $ty:(primTypeToCType restype)*|]
      return $ derefPointer src' iexp' ptr_t
    leaf (Index src (Count iexp) restype (Space space) vol) = do
      -- Named spaces are read through the backend-supplied reader.
      readScalar <- asks envReadScalar
      src' <- rawMem src
      iexp' <- compileExp $ untyped iexp
      readScalar src' iexp' (primTypeToCType restype) space vol
    leaf (Index src (Count iexp) _ ScalarSpace {} _) = do
      -- Scalar-space memory is indexed by its own name.
      iexp' <- compileExp $ untyped iexp
      return [C.cexp|$id:src[$exp:iexp']|]
    leaf (SizeOf t) =
      return [C.cexp|(typename int64_t)sizeof($ty:(primTypeToCType t))|]

-- | Tell me how to compile a @v@, and I'll compile any @PrimExp v@ for you.
--
-- Constants go through 'compilePrimValue' and leaves through the
-- supplied function. Operators become either C operators or calls to
-- functions named by the operator's pretty-printed form.
compilePrimExp :: Monad m => (v -> m C.Exp) -> PrimExp v -> m C.Exp
compilePrimExp _ (ValueExp val) =
  return $ compilePrimValue val
compilePrimExp f (LeafExp v _) =
  f v
compilePrimExp f (UnOpExp Complement {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|~$exp:x'|]
compilePrimExp f (UnOpExp Not {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|!$exp:x'|]
compilePrimExp f (UnOpExp (Abs Int64) x) = do
  x' <- compilePrimExp f x
  -- C's abs() takes and returns an int, so it would truncate a 64-bit
  -- operand. Conditional negation works at the operand's own width.
  -- Like the signum cases below, this evaluates the operand twice.
  return [C.cexp|$exp:x' < 0 ? -$exp:x' : $exp:x'|]
compilePrimExp f (UnOpExp Abs {} x) = do
  -- Narrower integers promote to int, so abs() is adequate for them.
  x' <- compilePrimExp f x
  return [C.cexp|abs($exp:x')|]
compilePrimExp f (UnOpExp (FAbs Float32) x) = do
  x' <- compilePrimExp f x
  return [C.cexp|(float)fabs($exp:x')|]
compilePrimExp f (UnOpExp (FAbs Float64) x) = do
  x' <- compilePrimExp f x
  return [C.cexp|fabs($exp:x')|]
compilePrimExp f (UnOpExp SSignum {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|($exp:x' > 0) - ($exp:x' < 0)|]
compilePrimExp f (UnOpExp USignum {} x) = do
  x' <- compilePrimExp f x
  return [C.cexp|($exp:x' > 0) - ($exp:x' < 0) != 0|]
compilePrimExp f (CmpOpExp cmp x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  return $ case cmp of
    CmpEq {} -> [C.cexp|$exp:x' == $exp:y'|]
    FCmpLt {} -> [C.cexp|$exp:x' < $exp:y'|]
    FCmpLe {} -> [C.cexp|$exp:x' <= $exp:y'|]
    CmpLlt {} -> [C.cexp|$exp:x' < $exp:y'|]
    CmpLle {} -> [C.cexp|$exp:x' <= $exp:y'|]
    _ -> [C.cexp|$id:(pretty cmp)($exp:x', $exp:y')|]
compilePrimExp f (ConvOpExp conv x) = do
  x' <- compilePrimExp f x
  return [C.cexp|$id:(pretty conv)($exp:x')|]
compilePrimExp f (BinOpExp bop x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  -- Note that integer addition, subtraction, and multiplication with
  -- OverflowWrap are not handled by explicit operators, but rather by
  -- functions.  This is because we want to implicitly convert them to
  -- unsigned numbers, so we can do overflow without invoking
  -- undefined behaviour.
  return $ case bop of
    Add _ OverflowUndef -> [C.cexp|$exp:x' + $exp:y'|]
    Sub _ OverflowUndef -> [C.cexp|$exp:x' - $exp:y'|]
    Mul _ OverflowUndef -> [C.cexp|$exp:x' * $exp:y'|]
    FAdd {} -> [C.cexp|$exp:x' + $exp:y'|]
    FSub {} -> [C.cexp|$exp:x' - $exp:y'|]
    FMul {} -> [C.cexp|$exp:x' * $exp:y'|]
    FDiv {} -> [C.cexp|$exp:x' / $exp:y'|]
    Xor {} -> [C.cexp|$exp:x' ^ $exp:y'|]
    And {} -> [C.cexp|$exp:x' & $exp:y'|]
    Or {} -> [C.cexp|$exp:x' | $exp:y'|]
    Shl {} -> [C.cexp|$exp:x' << $exp:y'|]
    LogAnd {} -> [C.cexp|$exp:x' && $exp:y'|]
    LogOr {} -> [C.cexp|$exp:x' || $exp:y'|]
    _ -> [C.cexp|$id:(pretty bop)($exp:x', $exp:y')|]
compilePrimExp f (FunExp h args _) = do
  args' <- mapM (compilePrimExp f) args
  return [C.cexp|$id:(funName (nameFromString h))($args:args')|]

compileCode :: Code op -> CompilerM op s ()
compileCode :: Code op -> CompilerM op s ()
compileCode (Op op
op) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ (CompilerEnv op s -> OpCompiler op s)
-> CompilerM op s (OpCompiler op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> OpCompiler op s
forall op s. CompilerEnv op s -> OpCompiler op s
envOpCompiler CompilerM op s (OpCompiler op s)
-> CompilerM op s op -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> op -> CompilerM op s op
forall (f :: * -> *) a. Applicative f => a -> f a
pure op
op
compileCode Code op
Skip = () -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
compileCode (Comment String
s Code op
code) = do
  [BlockItem]
xs <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
code
  let comment :: String
comment = String
"// " String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
s
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|$comment:comment
              { $items:xs }
             |]
compileCode (DebugPrint String
s (Just Exp
e)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(stderr, $string:fmtstr, $exp:s, ($ty:ety)$exp:e', '\n');
       }|]
  where
    (String
fmt, Type
ety) = case Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
e of
      IntType IntType
_ -> (String
"llu", [C.cty|long long int|])
      FloatType FloatType
_ -> (String
"f", [C.cty|double|])
      PrimType
_ -> (String
"d", [C.cty|int|])
    fmtstr :: String
fmtstr = String
"%s: %" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
fmt String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"%c"
compileCode (DebugPrint String
s Maybe Exp
Nothing) =
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(stderr, "%s\n", $exp:s);
       }|]
compileCode Code op
c
  | Just (VName
name, Volatility
vol, PrimType
t, Exp
e, Code op
c') <- Code op -> Maybe (VName, Volatility, PrimType, Exp, Code op)
forall op.
Code op -> Maybe (VName, Volatility, PrimType, Exp, Code op)
declareAndSet Code op
c = do
    let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
    Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
    BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e';|]
    Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
c'
compileCode (Code op
c1 :>>: Code op
c2) = Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
c1 CompilerM op s () -> CompilerM op s () -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
c2
compileCode (Assert Exp
e ErrorMsg Exp
msg (SrcLoc
loc, [SrcLoc]
locs)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  [BlockItem]
err <-
    CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$
      CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
        (CompilerEnv op s -> ErrorCompiler op s)
-> CompilerM op s (ErrorCompiler op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Operations op s -> ErrorCompiler op s
forall op s. Operations op s -> ErrorCompiler op s
opsError (Operations op s -> ErrorCompiler op s)
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> ErrorCompiler op s
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations) CompilerM op s (ErrorCompiler op s)
-> CompilerM op s (ErrorMsg Exp)
-> CompilerM op s (String -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ErrorMsg Exp -> CompilerM op s (ErrorMsg Exp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ErrorMsg Exp
msg CompilerM op s (String -> CompilerM op s ())
-> CompilerM op s String -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String -> CompilerM op s String
forall (f :: * -> *) a. Applicative f => a -> f a
pure String
stacktrace
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|if (!$exp:e') { $items:err }|]
  where
    stacktrace :: String
stacktrace = Int -> [String] -> String
prettyStacktrace Int
0 ([String] -> String) -> [String] -> String
forall a b. (a -> b) -> a -> b
$ (SrcLoc -> String) -> [SrcLoc] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map SrcLoc -> String
forall a. Located a => a -> String
locStr ([SrcLoc] -> [String]) -> [SrcLoc] -> [String]
forall a b. (a -> b) -> a -> b
$ SrcLoc
loc SrcLoc -> [SrcLoc] -> [SrcLoc]
forall a. a -> [a] -> [a]
: [SrcLoc]
locs
compileCode (Allocate VName
_ Count Bytes (TExp Int64)
_ ScalarSpace {}) =
  -- Handled by the declaration of the memory block, which is
  -- translated to an actual array.
  () -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
compileCode (Allocate VName
name (Count (TPrimExp Exp
e)) Space
space) = do
  Exp
size <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Maybe VName
cached <- VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  case Maybe VName
cached of
    Just VName
cur_size ->
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
        [C.cstm|if ($exp:cur_size < (size_t)$exp:size) {
                    $exp:name = realloc($exp:name, $exp:size);
                    $exp:cur_size = $exp:size;
                  }|]
    Maybe VName
_ ->
      VName -> Exp -> Space -> Stm -> CompilerM op s ()
forall a b op s.
(ToExp a, ToExp b) =>
a -> b -> Space -> Stm -> CompilerM op s ()
allocMem VName
name Exp
size Space
space [C.cstm|{err = 1; goto cleanup;}|]
compileCode (Free VName
name Space
space) = do
  Bool
cached <- Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Maybe VName -> Bool)
-> CompilerM op s (Maybe VName) -> CompilerM op s Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  Bool -> CompilerM op s () -> CompilerM op s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
cached (CompilerM op s () -> CompilerM op s ())
-> CompilerM op s () -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> CompilerM op s ()
forall a op s. ToExp a => a -> Space -> CompilerM op s ()
unRefMem VName
name Space
space
compileCode (For VName
i Exp
bound Code op
body) = do
  let i' :: SrcLoc -> Id
i' = VName -> SrcLoc -> Id
forall a. ToIdent a => a -> SrcLoc -> Id
C.toIdent VName
i
      t :: Type
t = PrimType -> Type
primTypeToCType (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
bound
  Exp
bound' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
bound
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|for ($ty:t $id:i' = 0; $id:i' < $exp:bound'; $id:i'++) {
            $items:body'
          }|]
compileCode (While TExp Bool
cond Code op
body) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|while ($exp:cond') {
            $items:body'
          }|]
compileCode (If TExp Bool
cond Code op
tbranch Code op
fbranch) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
tbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
tbranch
  [BlockItem]
fbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
fbranch
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm (Stm -> CompilerM op s ()) -> Stm -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ case ([BlockItem]
tbranch', [BlockItem]
fbranch') of
    ([BlockItem]
_, []) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' }|]
    ([], [BlockItem]
_) ->
      [C.cstm|if (!($exp:cond')) { $items:fbranch' }|]
    ([BlockItem], [BlockItem])
_ ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else { $items:fbranch' }|]
compileCode (Copy VName
dest (Count TExp Int64
destoffset) Space
DefaultSpace VName
src (Count TExp Int64
srcoffset) Space
DefaultSpace (Count TExp Int64
size)) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ()
forall op s. Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ()
copyMemoryDefaultSpace
      (Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
      CompilerM op s (Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
destoffset)
      CompilerM op s (Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
src
      CompilerM op s (Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
srcoffset)
      CompilerM op s (Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
size)
compileCode (Copy VName
dest (Count TExp Int64
destoffset) Space
destspace VName
src (Count TExp Int64
srcoffset) Space
srcspace (Count TExp Int64
size)) = do
  Copy op s
copy <- (CompilerEnv op s -> Copy op s) -> CompilerM op s (Copy op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> Copy op s
forall op s. CompilerEnv op s -> Copy op s
envCopy
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    Copy op s
copy
      Copy op s
-> CompilerM op s Exp
-> CompilerM
     op
     s
     (Exp -> Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
      CompilerM
  op
  s
  (Exp -> Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM
     op s (Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
destoffset)
      CompilerM
  op s (Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Space
-> CompilerM op s (Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Space -> CompilerM op s Space
forall (f :: * -> *) a. Applicative f => a -> f a
pure Space
destspace
      CompilerM op s (Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
src
      CompilerM op s (Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
srcoffset)
      CompilerM op s (Space -> Exp -> CompilerM op s ())
-> CompilerM op s Space
-> CompilerM op s (Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Space -> CompilerM op s Space
forall (f :: * -> *) a. Applicative f => a -> f a
pure Space
srcspace
      CompilerM op s (Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
size)
compileCode (Write VName
dest (Count TExp Int64
idx) PrimType
elemtype Space
DefaultSpace Volatility
vol Exp
elemexp) = do
  Exp
dest' <- VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
  Exp
deref <-
    Exp -> Exp -> Type -> Exp
derefPointer Exp
dest'
      (Exp -> Type -> Exp)
-> CompilerM op s Exp -> CompilerM op s (Type -> Exp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
idx)
      CompilerM op s (Type -> Exp)
-> CompilerM op s Type -> CompilerM op s Exp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> CompilerM op s Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure [C.cty|$tyquals:(volQuals vol) $ty:(primTypeToCType elemtype)*|]
  Exp
elemexp' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
elemexp
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$exp:deref = $exp:elemexp';|]
-- A write into scalar space becomes a direct C array-element assignment
-- on the destination variable.
compileCode (Write dest (Count idx) _ ScalarSpace {} _ elemexp) = do
  idx_c <- compileExp (untyped idx)
  elem_c <- compileExp elemexp
  stm [C.cstm|$id:dest[$exp:idx_c] = $exp:elem_c;|]
-- A write into any other named memory space is delegated to the
-- backend-provided 'envWriteScalar' hook, after compiling the raw memory
-- reference, the index and the stored value (in that order).
compileCode (Write dest (Count idx) elemtype (Space space) vol elemexp) = do
  write_scalar <- asks envWriteScalar
  dest_mem <- rawMem dest
  idx_c <- compileExp (untyped idx)
  elem_c <- compileExp elemexp
  write_scalar dest_mem idx_c (primTypeToCType elemtype) space vol elem_c
-- Memory declarations are handled entirely by 'declMem'.
compileCode (DeclareMem name space) =
  declMem name space
-- A scalar declaration becomes a local C declaration of the corresponding
-- C type, carrying whatever type qualifiers 'volQuals' derives from the
-- volatility flag.
compileCode (DeclareScalar name vol t) =
  decl [C.cdecl|$tyquals:(volQuals vol) $ty:(primTypeToCType t) $id:name;|]
-- Arrays cannot be declared in scalar space; this is reported as an
-- internal error.
compileCode (DeclareArray name ScalarSpace {} _ _) =
  error $ concat ["Cannot declare array ", pretty name, " in scalar space."]
-- In the default space the contents are emitted as a static C array named
-- after the array with a "_realtype" suffix. A 'struct memblock' pointing at
-- that storage is stored as a context field and bound to a local variable
-- of the array's name.
compileCode (DeclareArray name DefaultSpace t vs) = do
  name_realtype <- newVName $ baseString name ++ "_realtype"
  let ct = primTypeToCType t
  earlyDecl $ case vs of
    ArrayValues vs' ->
      let initialisers = [[C.cinit|$exp:(compilePrimValue v)|] | v <- vs']
       in [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs')] = {$inits:initialisers};|]
    ArrayZeros n ->
      [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
  -- Fake a memory block wrapping the static array.
  contextField (C.toIdent name noLoc) [C.cty|struct memblock|] $
    Just [C.cexp|(struct memblock){NULL, (char*)$id:name_realtype, 0}|]
  item [C.citem|struct memblock $id:name = ctx->$id:name;|]
-- Arrays in any other named space are delegated to the backend-provided
-- 'envStaticArray' hook.
compileCode (DeclareArray name (Space space) t vs) = do
  static_array <- asks envStaticArray
  static_array name space t vs
-- For assignments of the form 'x = x OP e', we generate C compound
-- assignment operators (such as @x += e@) to make the resulting code
-- slightly nicer. This has no effect on performance. Every other scalar
-- assignment becomes a plain @dest = src;@ statement.
compileCode (SetScalar dest src)
  | BinOpExp op (LeafExp (ScalarVar x) _) y <- src,
    x == dest,
    Just compound <- assignmentOperator op = do
      y' <- compileExp y
      stm [C.cstm|$exp:(compound dest y');|]
  | otherwise = do
      src' <- compileExp src
      stm [C.cstm|$id:dest = $exp:src';|]
-- Assigning one memory block to another is delegated to 'setMem'.
compileCode (SetMem dest src space) =
  setMem dest src space
-- Function calls: each argument is turned into a C expression (memory
-- arguments by name, expression arguments by compiling them). The call
-- itself is then emitted by the 'opsCall' hook of the current operations.
compileCode (Call dests fname args) = do
  call_compiler <- asks (opsCall . envOperations)
  args' <- traverse compileArg args
  call_compiler dests fname args'
  where
    compileArg :: Arg -> CompilerM op s C.Exp
    compileArg (MemArg m) = pure [C.cexp|$exp:m|]
    compileArg (ExpArg e) = compileExp e

-- | Run an action via 'blockScope'' and return only the block items it
-- produced (including any memory releases 'blockScope'' appends).
blockScope :: CompilerM op s () -> CompilerM op s [C.BlockItem]
blockScope m = snd <$> blockScope' m

-- | Run an action and return its result together with the block items it
-- emitted. The action's writer output is captured here and replaced by
-- 'mempty', so it is not passed on to the enclosing computation.
--
-- Memory blocks that the action added to 'compDeclaredMem' get 'unRefMem'
-- statements appended after the captured items. 'compDeclaredMem' is then
-- restored to the value it had before the action ran.
blockScope' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
blockScope' m = do
  outer_mem <- gets compDeclaredMem
  -- 'censor (const mempty)' discards everything the action wrote, while
  -- 'listen' still lets us see it.
  (res, acc) <- censor (const mempty) (listen m)
  let body = DL.toList (accItems acc)
  -- Memory declared inside the action is whatever was not there before.
  inner_mem <- filter (`notElem` outer_mem) <$> gets compDeclaredMem
  modify $ \st -> st {compDeclaredMem = outer_mem}
  releases <- collect $ mapM_ (uncurry unRefMem) inner_mem
  pure (res, body ++ releases)

-- | Compile a function body in three steps:
--
-- 1. Declare a local variable for every output parameter.
-- 2. Compile the body code.
-- 3. Store each output through its output pointer. Memory outputs reset the
--    pointed-to block with 'resetMem' before 'setMem'; scalar outputs are
--    written with @*p = name;@.
--
-- Output pointers and parameters are paired positionally ('zipWithM_'), so
-- any surplus in the longer list is ignored.
compileFunBody :: [C.Exp] -> [Param] -> Code op -> CompilerM op s ()
compileFunBody output_ptrs outputs code = do
  mapM_ declareOutput outputs
  compileCode code
  zipWithM_ storeOutput output_ptrs outputs
  where
    declareOutput (MemParam name space) =
      declMem name space
    declareOutput (ScalarParam name pt) =
      decl [C.cdecl|$ty:(primTypeToCType pt) $id:name;|]

    storeOutput ptr (MemParam name space) = do
      let target = [C.cexp|*$exp:ptr|]
      resetMem target space
      setMem target name space
    storeOutput ptr (ScalarParam name _) =
      stm [C.cstm|*$exp:ptr = $id:name;|]

-- | Recognise a scalar declaration immediately followed (as determined by
-- 'nextCode') by an assignment to that same scalar.
--
-- On success, return the variable name, its volatility, its type, the
-- assigned expression, and the code that follows the assignment. Return
-- 'Nothing' otherwise.
declareAndSet :: Code op -> Maybe (VName, Volatility, PrimType, Exp, Code op)
declareAndSet code =
  case nextCode code of
    Just (DeclareScalar name vol t, after_decl)
      | Just (SetScalar dest e, after_set) <- nextCode after_decl,
        name == dest ->
          Just (name, vol, t, e, after_set)
    _ -> Nothing

-- | Split a sequence into its first element and the rest.
--
-- Only ':>>:' nodes are split. Left-nested sequences are descended into, so
-- the first element is the leftmost non-sequence piece of code. Any code
-- that is not a ':>>:' node, including a single statement, yields 'Nothing'.
nextCode :: Code op -> Maybe (Code op, Code op)
nextCode (x :>>: y) =
  case nextCode x of
    Just (first_part, rest) -> Just (first_part, rest <> y)
    Nothing -> Just (x, y)
nextCode _ = Nothing

-- | For 'Add', 'Sub' and 'Mul', return a builder for the matching C compound
-- assignment (@d += e@, @d -= e@, @d *= e@). Return 'Nothing' for every
-- other binary operator.
assignmentOperator :: BinOp -> Maybe (VName -> C.Exp -> C.Exp)
assignmentOperator op =
  case op of
    Add {} -> Just $ \d e -> [C.cexp|$id:d += $exp:e|]
    Sub {} -> Just $ \d e -> [C.cexp|$id:d -= $exp:e|]
    Mul {} -> Just $ \d e -> [C.cexp|$id:d *= $exp:e|]
    _ -> Nothing

-- | Return an expression multiplying together the given expressions.
-- The product is built left-nested, i.e. as @(e1 * e2) * e3@.
-- If an empty list is given, the expression @1@ is returned.
cproduct :: [C.Exp] -> C.Exp
cproduct [] = [C.cexp|1|]
cproduct es = foldl1 times es
  where
    times x y = [C.cexp|$exp:x * $exp:y|]