{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TupleSections #-}

-- | C code generator framework.
module Futhark.CodeGen.Backends.GenericC
  ( compileProg,
    CParts (..),
    asLibrary,
    asExecutable,
    asServer,

    -- * Pluggable compiler
    Operations (..),
    defaultOperations,
    OpCompiler,
    ErrorCompiler,
    CallCompiler,
    PointerQuals,
    MemoryType,
    WriteScalar,
    writeScalarPointerWithQuals,
    ReadScalar,
    readScalarPointerWithQuals,
    Allocate,
    Deallocate,
    Copy,
    StaticArray,

    -- * Monadic compiler interface
    CompilerM,
    CompilerState (compUserState, compNameSrc),
    getUserState,
    modifyUserState,
    contextContents,
    contextFinalInits,
    runCompilerM,
    inNewFunction,
    cachingMemory,
    blockScope,
    compileFun,
    compileCode,
    compileExp,
    compilePrimExp,
    compileExpToName,
    rawMem,
    item,
    items,
    stm,
    stms,
    decl,
    atInit,
    headerDecl,
    publicDef,
    publicDef_,
    profileReport,
    onClear,
    HeaderSection (..),
    libDecl,
    earlyDecl,
    publicName,
    contextType,
    contextField,
    memToCType,
    cacheMem,
    fatMemory,
    rawMemCType,
    cproduct,
    fatMemType,

    -- * Building Blocks
    primTypeToCType,
    intTypeToCType,
    copyMemoryDefaultSpace,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import qualified Data.DList as DL
import Data.Loc
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Text as T
import Futhark.CodeGen.Backends.GenericC.CLI (cliDefs)
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Server (serverDefs)
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.RTS.C (halfH, lockH, timingH, utilH)
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Futhark.Util.Pretty (prettyText)
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C
import NeatInterpolation (untrimming)

-- | How public an array type definition should be.  Public types show
-- up in the generated API, while private types are used only to
-- implement the members of opaques.
--
-- Constructor order gives @Private < Public@ under the derived 'Ord'.
data Publicness = Private | Public
  deriving (Eq, Ord, Show)

-- | Key identifying a generated array type: its memory space,
-- signedness, element type, and an 'Int' that is presumably the array
-- rank (NOTE(review): confirm against the code that builds these keys).
type ArrayType =
  ( Space,
    Signedness,
    PrimType,
    Int
  )

-- | State threaded through the C code generator.  The parameter @s@ is
-- backend-specific user state.
data CompilerState s = CompilerState
  { -- | Array types encountered so far, with how public each must be.
    compArrayTypes :: M.Map ArrayType Publicness,
    -- | Opaque types; the key is presumably the type name.
    compOpaqueTypes :: M.Map String [ValueDesc],
    -- | Definitions placed early in the output (presumably collected
    -- via 'earlyDecl').
    compEarlyDecls :: DL.DList C.Definition,
    -- | Initialisation statements (presumably collected via 'atInit').
    compInit :: [C.Stm],
    -- | Source of fresh names.
    compNameSrc :: VNameSource,
    -- | Backend-specific state.
    compUserState :: s,
    -- | Header declarations, grouped by the header section they belong
    -- to.
    compHeaderDecls :: M.Map HeaderSection (DL.DList C.Definition),
    -- | Library definitions (presumably collected via 'libDecl').
    compLibDecls :: DL.DList C.Definition,
    -- | Fields of the context struct: name, type, and an optional
    -- initialiser expression.
    compCtxFields :: DL.DList (C.Id, C.Type, Maybe C.Exp),
    -- | Items for profiling reports (presumably via 'profileReport').
    compProfileItems :: DL.DList C.BlockItem,
    -- | Items run on clearing (presumably via 'onClear').
    compClearItems :: DL.DList C.BlockItem,
    -- | Memory blocks, with their spaces, that the default error
    -- compiler unreferences before reporting an error.
    compDeclaredMem :: [(VName, Space)],
    -- | Block items accumulated for the code currently being compiled.
    compItems :: DL.DList C.BlockItem
  }

-- | A fresh compiler state: every collection starts empty, and only the
-- name source and the backend-specific user state are supplied.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src s =
  CompilerState
    { compArrayTypes = mempty,
      compOpaqueTypes = mempty,
      compEarlyDecls = mempty,
      compInit = [],
      compNameSrc = src,
      compUserState = s,
      compHeaderDecls = mempty,
      compLibDecls = mempty,
      compCtxFields = mempty,
      compProfileItems = mempty,
      compClearItems = mempty,
      compDeclaredMem = mempty,
      compItems = mempty
    }

-- | In which part of the header file we put the declaration.  This is
-- to ensure that the header file remains structured and readable.
--
-- 'Ord' is needed because sections are used as 'M.Map' keys for the
-- collected header declarations.
data HeaderSection
  = ArrayDecl String
  | OpaqueDecl String
  | EntryDecl
  | MiscDecl
  | InitDecl
  deriving (Eq, Ord)

-- | A substitute compiler for extended operations, tried before the
-- main compilation function.
type OpCompiler op s = op -> CompilerM op s ()

-- | Compile a run-time failure, given the error message and a textual
-- stack trace.
type ErrorCompiler op s = ErrorMsg Exp -> String -> CompilerM op s ()

-- | The address-space qualifiers for a pointer of the given type with
-- the given annotation.
type PointerQuals op s = String -> CompilerM op s [C.TypeQual]

-- | The C type of a memory block living in the given memory space.
type MemoryType op s = SpaceId -> CompilerM op s C.Type

-- | Store a scalar into the given memory block at the given element
-- index, in the given memory space.
type WriteScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  C.Exp ->
  CompilerM op s ()

-- | Load a scalar from the given memory block at the given element
-- index, in the given memory space.
type ReadScalar op s =
  C.Exp ->
  C.Exp ->
  C.Type ->
  SpaceId ->
  Volatility ->
  CompilerM op s C.Exp

-- | Allocate a memory block of the given size and with the given tag
-- in the given memory space, storing a reference in the given
-- variable.
type Allocate op s = C.Exp -> C.Exp -> C.Exp -> SpaceId -> CompilerM op s ()

-- | Release the given memory block, which carries the given tag and
-- lives in the given memory space.
type Deallocate op s =
  C.Exp ->
  C.Exp ->
  SpaceId ->
  CompilerM op s ()

-- | Create a static array of values, initialised at load time.
type StaticArray op s =
  VName ->
  SpaceId ->
  PrimType ->
  ArrayContents ->
  CompilerM op s ()

-- | Copy bytes from one memory block to another.
type Copy op s =
  C.Exp -> -- destination memory
  C.Exp -> -- destination offset
  Space -> -- destination space
  C.Exp -> -- source memory
  C.Exp -> -- source offset
  Space -> -- source space
  C.Exp -> -- size
  CompilerM op s ()

-- | Compile a call of the named function, binding its results to the
-- given destination variables.
type CallCompiler op s =
  [VName] ->
  Name ->
  [C.Exp] ->
  CompilerM op s ()

-- | The pluggable, backend-specific parts of the C code generator.
-- 'defaultOperations' supports only the default memory space.
data Operations op s = Operations
  { -- | Write a scalar into a memory block.
    opsWriteScalar :: WriteScalar op s,
    -- | Read a scalar from a memory block.
    opsReadScalar :: ReadScalar op s,
    -- | Allocate a memory block.
    opsAllocate :: Allocate op s,
    -- | Release a memory block.
    opsDeallocate :: Deallocate op s,
    -- | Copy between memory blocks.
    opsCopy :: Copy op s,
    -- | Create a static array initialised at load time.
    opsStaticArray :: StaticArray op s,
    -- | The C type of memory blocks in a given space.
    opsMemoryType :: MemoryType op s,
    -- | Compiler for the backend's extended operations.
    opsCompiler :: OpCompiler op s,
    -- | Compiler for run-time errors.
    opsError :: ErrorCompiler op s,
    -- | Compiler for function calls.
    opsCall :: CallCompiler op s,
    -- | If true, use reference counting.  Otherwise, bare
    -- pointers.
    opsFatMemory :: Bool,
    -- | Code to bracket critical sections.
    opsCritical :: ([C.BlockItem], [C.BlockItem])
  }

-- | Turn an 'ErrorMsg' into a @printf@-style format string and the C
-- argument expressions for its conversion specifiers.
--
-- Literal text is passed as a @%s@ argument rather than spliced into
-- the format string, so a @%@ in the message text is never interpreted
-- as a conversion.  Booleans are rendered as @"true"@/@"false"@, 64-bit
-- integers are cast to @long long int@ to match @%lld@, and 16/32-bit
-- floats are cast to @double@ for @%f@.
errorMsgString :: ErrorMsg Exp -> CompilerM op s (String, [C.Exp])
errorMsgString (ErrorMsg parts) = do
  let boolStr e = [C.cexp|($exp:e) ? "true" : "false"|]
      asLongLong e = [C.cexp|(long long int)$exp:e|]
      asDouble e = [C.cexp|(double)$exp:e|]
      onPart (ErrorString s) = return ("%s", [C.cexp|$string:s|])
      onPart (ErrorVal Bool x) = ("%s",) . boolStr <$> compileExp x
      onPart (ErrorVal Unit _) = pure ("%s", [C.cexp|"()"|])
      onPart (ErrorVal (IntType Int8) x) = ("%hhd",) <$> compileExp x
      onPart (ErrorVal (IntType Int16) x) = ("%hd",) <$> compileExp x
      onPart (ErrorVal (IntType Int32) x) = ("%d",) <$> compileExp x
      onPart (ErrorVal (IntType Int64) x) = ("%lld",) . asLongLong <$> compileExp x
      onPart (ErrorVal (FloatType Float16) x) = ("%f",) . asDouble <$> compileExp x
      onPart (ErrorVal (FloatType Float32) x) = ("%f",) . asDouble <$> compileExp x
      onPart (ErrorVal (FloatType Float64) x) = ("%f",) <$> compileExp x
  (formatstrs, formatargs) <- unzip <$> mapM onPart parts
  pure (mconcat formatstrs, formatargs)

-- | The default error compiler.  The generated code does three things:
--
--   1. It stores the formatted message, followed by the backtrace, in
--      @ctx->error@.
--   2. It unreferences every memory block listed in 'compDeclaredMem'.
--   3. It returns 1 from the enclosing C function.
defError :: ErrorCompiler op s
defError msg stacktrace = do
  free_all_mem <- collect $ mapM_ (uncurry unRefMem) =<< gets compDeclaredMem
  (formatstr, formatargs) <- errorMsgString msg
  -- The backtrace is supplied as a %s argument, so any '%' it contains
  -- is printed literally.
  let formatstr' = "Error: " <> formatstr <> "\n\nBacktrace:\n%s"
  items
    [C.citems|ctx->error = msgprintf($string:formatstr', $args:formatargs, $string:stacktrace);
              $items:free_all_mem
              return 1;|]

-- | The default call compiler.
--
-- A built-in function with exactly one destination is called as an
-- ordinary C expression, and its result is assigned to that
-- destination.  Any other call is made in the Futhark calling
-- convention:
--
--   * the arguments are @ctx@, pointers to the destinations, and then
--     the real arguments;
--   * a non-zero return value sets @err = 1@ and jumps to @cleanup@.
--
-- NOTE(review): a built-in call with zero or several destinations also
-- takes the error-code branch, but without @ctx@ and out-pointers.
-- This is presumably never generated; confirm.
defCall :: CallCompiler op s
defCall dests fname args = do
  let out_args = [[C.cexp|&$id:d|] | d <- dests]
      args'
        | isBuiltInFunction fname = args
        | otherwise = [C.cexp|ctx|] : out_args ++ args
  case dests of
    [dest]
      | isBuiltInFunction fname ->
          stm [C.cstm|$id:dest = $id:(funName fname)($args:args');|]
    _ ->
      item [C.citem|if ($id:(funName fname)($args:args') != 0) { err = 1; goto cleanup; }|]

-- | A set of operations that fail for every operation involving
-- non-default memory spaces.  Uses plain pointers and @malloc@ for
-- memory management.
--
-- Copies between two 'DefaultSpace' blocks are delegated to
-- 'copyMemoryDefaultSpace'.  Errors and calls use 'defError' and
-- 'defCall'.  Reference counting is enabled, and critical sections are
-- not bracketed.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsDeallocate = defDeallocate,
      opsCopy = defCopy,
      opsStaticArray = defStaticArray,
      opsMemoryType = defMemoryType,
      opsCompiler = defCompiler,
      opsFatMemory = True,
      opsError = defError,
      opsCall = defCall,
      opsCritical = mempty
    }
  where
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space because I am dumb"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defDeallocate _ _ =
      error "Cannot deallocate in non-default memory space"
    -- Only copies where both sides live in the default space are
    -- supported.
    defCopy destmem destoffset DefaultSpace srcmem srcoffset DefaultSpace size =
      copyMemoryDefaultSpace destmem destoffset srcmem srcoffset size
    defCopy _ _ _ _ _ _ _ =
      error "Cannot copy to or from non-default memory space"
    defStaticArray _ _ _ _ =
      error "Cannot create static array in non-default memory space"
    defMemoryType _ =
      error "Has no type for non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"

-- | Read-only environment of the C code generator.
data CompilerEnv op s = CompilerEnv
  { -- | The backend-specific operations in use.
    envOperations :: Operations op s,
    -- | Mapping memory blocks to sizes.  These memory blocks are CPU
    -- memory that we know are used in particularly simple ways (no
    -- reference counting necessary).  To cut down on allocator
    -- pressure, we keep these allocations around for a long time, and
    -- record their sizes so we can reuse them if possible (and
    -- realloc() when needed).
    envCachedMem :: M.Map C.Exp VName
  }

-- | Shorthand for the extended-operation compiler in 'envOperations'.
envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler = opsCompiler . envOperations

-- | Shorthand for the memory-type function in 'envOperations'.
envMemoryType :: CompilerEnv op s -> MemoryType op s
envMemoryType = opsMemoryType . envOperations

-- | Shorthand for the scalar reader in 'envOperations'.
envReadScalar :: CompilerEnv op s -> ReadScalar op s
envReadScalar = opsReadScalar . envOperations

-- | Shorthand for the scalar writer in 'envOperations'.
envWriteScalar :: CompilerEnv op s -> WriteScalar op s
envWriteScalar = opsWriteScalar . envOperations

-- | Shorthand for the allocator in 'envOperations'.
envAllocate :: CompilerEnv op s -> Allocate op s
envAllocate = opsAllocate . envOperations

-- | Shorthand for the deallocator in 'envOperations'.
envDeallocate :: CompilerEnv op s -> Deallocate op s
envDeallocate = opsDeallocate . envOperations

-- | Shorthand for the copy operation in 'envOperations'.
envCopy :: CompilerEnv op s -> Copy op s
envCopy = opsCopy . envOperations

-- | Shorthand for the static-array constructor in 'envOperations'.
envStaticArray :: CompilerEnv op s -> StaticArray op s
envStaticArray = opsStaticArray . envOperations

-- | Whether 'envOperations' requests reference-counted memory.
envFatMemory :: CompilerEnv op s -> Bool
envFatMemory = opsFatMemory . envOperations

-- | Pretty-print the header declarations from every section that
-- satisfies the predicate, one definition per line.
--
-- Sections are visited in ascending 'HeaderSection' order, because
-- 'M.toList' yields keys in ascending order.  Within a section,
-- definitions keep their insertion order.
declsCode :: (HeaderSection -> Bool) -> CompilerState s -> T.Text
declsCode p =
  T.unlines
    . map prettyText
    . concatMap (DL.toList . snd)
    . filter (p . fst)
    . M.toList
    . compHeaderDecls

-- | Rendered header declarations for one kind of 'HeaderSection'.
-- 'arrayDecls' and 'opaqueDecls' cover all array and opaque sections
-- respectively, regardless of their names.
initDecls, arrayDecls, opaqueDecls, entryDecls, miscDecls :: CompilerState s -> T.Text
initDecls = declsCode (== InitDecl)
arrayDecls = declsCode isArrayDecl
  where
    isArrayDecl ArrayDecl {} = True
    isArrayDecl _ = False
opaqueDecls = declsCode isOpaqueDecl
  where
    isOpaqueDecl OpaqueDecl {} = True
    isOpaqueDecl _ = False
entryDecls = declsCode (== EntryDecl)
miscDecls = declsCode (== MiscDecl)

-- | The fields registered through 'contextField', as C struct field
-- declarations.  Also returns assignment statements (through a
-- variable named @ctx@) for those fields that were given an initial
-- value.  Both lists follow registration order.
contextContents :: CompilerM op s ([C.FieldGroup], [C.Stm])
contextContents = do
  registered <- gets (DL.toList . compCtxFields)
  let fieldDecls =
        [ [C.csdecl|$ty:ty $id:name;|]
          | (name, ty, _) <- registered
        ]
      fieldInits =
        [ [C.cstm|ctx->$id:name = $exp:e;|]
          | (name, _, Just e) <- registered
        ]
  pure (fieldDecls, fieldInits)

-- | The statements registered with 'atInit', in registration order.
contextFinalInits :: CompilerM op s [C.Stm]
contextFinalInits = compInit <$> get

-- | The code generation monad: a read-only 'CompilerEnv' layered over
-- a mutable 'CompilerState'.  Every instance below is obtained through
-- generalised newtype deriving from that transformer stack.
newtype CompilerM op s a
  = CompilerM (ReaderT (CompilerEnv op s) (State (CompilerState s)) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState (CompilerState s),
      MonadReader (CompilerEnv op s)
    )

-- | The name source is stored in the 'compNameSrc' field of the
-- compiler state.
instance MonadFreshNames (CompilerM op s) where
  getNameSource = compNameSrc <$> get
  putNameSource src = modify (\st -> st {compNameSrc = src})

-- | Run a 'CompilerM' action with the given operations, name source
-- and initial user state.  The environment starts with an empty map as
-- its second 'CompilerEnv' field (the cached-memory table).  Returns
-- the action's result together with the final compiler state.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  (a, CompilerState s)
runCompilerM ops src userstate (CompilerM m) =
  runState (runReaderT m env) initial
  where
    env = CompilerEnv ops mempty
    initial = newCompilerState src userstate

-- | Retrieve the backend-specific user state.
getUserState :: CompilerM op s s
getUserState = compUserState <$> get

-- | Apply a function to the backend-specific user state.
modifyUserState :: (s -> s) -> CompilerM op s ()
modifyUserState f = modify update
  where
    update st = st {compUserState = f (compUserState st)}

-- | Append a statement to the end of the list later returned by
-- 'contextFinalInits'.
atInit :: C.Stm -> CompilerM op s ()
atInit x = modify $ \st -> st {compInit = compInit st <> [x]}

-- | Run an action and return only the block items it emitted.
collect :: CompilerM op s () -> CompilerM op s [C.BlockItem]
collect = fmap snd . collect'

-- | Run an action starting from an empty item buffer.  Returns the
-- action's result and the items it emitted.  The previously
-- accumulated items are restored afterwards, so the emitted items do
-- not end up in the enclosing buffer.
collect' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
collect' m = do
  saved <- gets compItems
  setItems mempty
  result <- m
  emitted <- gets compItems
  setItems saved
  return (result, DL.toList emitted)
  where
    setItems :: DL.DList C.BlockItem -> CompilerM op s ()
    setItems is = modify $ \st -> st {compItems = is}

-- | Used when, inside an existing 'CompilerM' action, we want to
-- generate code for a new function.  It hides the previously declared
-- memory from the action, so the compiler understands that this memory
-- does not need to be freed there.  The outer declared memory is
-- restored afterwards.  Unless @keep_cached@ is true, the
-- cached-memory table is also emptied while the action runs.
inNewFunction :: Bool -> CompilerM op s a -> CompilerM op s a
inNewFunction keep_cached m = do
  outer_mem <- gets compDeclaredMem
  modify $ \st -> st {compDeclaredMem = mempty}
  result <- local scrubCache m
  modify $ \st -> st {compDeclaredMem = outer_mem}
  pure result
  where
    scrubCache env =
      if keep_cached
        then env
        else env {envCachedMem = mempty}

-- | Append a single block item to the end of the current item buffer.
item :: C.BlockItem -> CompilerM op s ()
item x = modify $ \st -> st {compItems = compItems st `DL.snoc` x}

-- | Append several block items, in order, to the current item buffer.
items :: [C.BlockItem] -> CompilerM op s ()
items xs = modify $ \st ->
  st {compItems = compItems st `DL.append` DL.fromList xs}

-- | Whether memory in the given space uses the fat memory
-- representation.  Scalar spaces never do; for every other space the
-- answer comes from 'envFatMemory'.
fatMemory :: Space -> CompilerM op s Bool
fatMemory space = case space of
  ScalarSpace {} -> pure False
  _ -> asks envFatMemory

-- | Look up the C expression form of the argument in the cached-memory
-- table.  Returns the associated name if there is one.
cacheMem :: C.ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem a = do
  cached <- asks envCachedMem
  pure $ M.lookup (C.toExp a noLoc) cached

-- | Construct a publicly visible definition using the specified name
-- as a template.  The template function receives the public name.  The
-- first definition it returns goes into the given header section; the
-- second (the implementation) is added as an early declaration.
-- Returns the public name.
publicDef ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s String
publicDef s h f = do
  name <- publicName s
  let defs = f name
  headerDecl h (fst defs)
  earlyDecl (snd defs)
  pure name

-- | As 'publicDef', but discards the resulting public name.
publicDef_ ::
  String ->
  HeaderSection ->
  (String -> (C.Definition, C.Definition)) ->
  CompilerM op s ()
publicDef_ s h f = publicDef s h f >> pure ()

-- | Add a definition to the given header section.  Within a section,
-- definitions keep the order in which they were added.
headerDecl :: HeaderSection -> C.Definition -> CompilerM op s ()
headerDecl sec def = modify $ \st ->
  st
    { compHeaderDecls =
        -- insertWith passes (new, old); flip so the old definitions come first.
        M.insertWith (flip (<>)) sec (DL.singleton def) (compHeaderDecls st)
    }

-- | Append a definition to the library declarations.
libDecl :: C.Definition -> CompilerM op s ()
libDecl def = modify $ \st ->
  st {compLibDecls = compLibDecls st `DL.snoc` def}

-- | Append a definition to the early declarations.
earlyDecl :: C.Definition -> CompilerM op s ()
earlyDecl def = modify $ \st ->
  st {compEarlyDecls = compEarlyDecls st `DL.snoc` def}

-- | Register a context struct field with the given name, type and
-- optional initial value.  The registered fields are later consumed by
-- 'contextContents'.
contextField :: C.Id -> C.Type -> Maybe C.Exp -> CompilerM op s ()
contextField name ty initial = modify $ \st ->
  st {compCtxFields = compCtxFields st `DL.snoc` (name, ty, initial)}

-- | Append a block item to the profiling-report items.
profileReport :: C.BlockItem -> CompilerM op s ()
profileReport x = modify $ \st ->
  st {compProfileItems = compProfileItems st `DL.snoc` x}

-- | Append a block item to the clear items.
onClear :: C.BlockItem -> CompilerM op s ()
onClear x = modify $ \st ->
  st {compClearItems = compClearItems st `DL.snoc` x}

-- | Emit a single C statement as a block item.
stm :: C.Stm -> CompilerM op s ()
stm s = item asItem
  where
    asItem = [C.citem|$stm:s|]

-- | Emit each statement in order, via 'stm'.
stms :: [C.Stm] -> CompilerM op s ()
stms = foldr (\s rest -> stm s >> rest) (pure ())

-- | Emit a C declaration as a block item.
decl :: C.InitGroup -> CompilerM op s ()
decl x = item declItem
  where
    declItem = [C.citem|$decl:x;|]

-- | Public names must have a consistent prefix; this prepends
-- @futhark_@ to the given name.
publicName :: String -> CompilerM op s String
publicName s = pure ("futhark_" <> s)

-- | The generated code must define a struct with this name: the
-- public name for @context@.
contextType :: CompilerM op s C.Type
contextType = toStruct <$> publicName "context"
  where
    toStruct name = [C.cty|struct $id:name|]

-- | The C type for a memory variable in the given space.  It is the
-- fat memory type when the space uses fat memory and the variable is
-- absent from the cached-memory table, and the raw memory type
-- otherwise.
memToCType :: VName -> Space -> CompilerM op s C.Type
memToCType v space = do
  refcount <- fatMemory space
  cached <- cacheMem v
  case cached of
    Nothing | refcount -> pure $ fatMemType space
    _ -> rawMemCType space

-- | The C type of raw (non-fat) memory in the given space:
--
-- * the default space uses 'defaultMemBlockType';
-- * named spaces defer to the backend's 'envMemoryType';
-- * scalar spaces become a C array of the element type.  Its size is
--   the product of the dimensions, or 1 when there are none.
rawMemCType :: Space -> CompilerM op s C.Type
rawMemCType DefaultSpace = pure defaultMemBlockType
rawMemCType (Space sid) = do
  memType <- asks envMemoryType
  memType sid
rawMemCType (ScalarSpace [] t) =
  pure [C.cty|$ty:(primTypeToCType t)[1]|]
rawMemCType (ScalarSpace ds t) =
  pure [C.cty|$ty:(primTypeToCType t)[$exp:(cproduct dims)]|]
  where
    dims = [C.toExp d noLoc | d <- ds]

-- | The C struct type of fat memory in the given space:
-- @struct memblock_SID@ for a named space @SID@, and
-- @struct memblock@ for any other space.
fatMemType :: Space -> C.Type
fatMemType space = [C.cty|struct $id:structName|]
  where
    structName = case space of
      Space sid -> "memblock_" ++ sid
      _ -> "memblock"

-- | Name of the @memblock_set@ C helper for the given space.  Named
-- spaces get the space name as a suffix, as in @memblock_set_SID@.
fatMemSet :: Space -> String
fatMemSet space = case space of
  Space sid -> "memblock_set_" <> sid
  _ -> "memblock_set"

-- | Name of the @memblock_alloc@ C helper for the given space.  Named
-- spaces get the space name as a suffix, as in @memblock_alloc_SID@.
fatMemAlloc :: Space -> String
fatMemAlloc space = case space of
  Space sid -> "memblock_alloc_" <> sid
  _ -> "memblock_alloc"

-- | Name of the @memblock_unref@ C helper for the given space.  Named
-- spaces get the space name as a suffix, as in @memblock_unref_SID@.
fatMemUnRef :: Space -> String
fatMemUnRef space = case space of
  Space sid -> "memblock_unref_" <> sid
  _ -> "memblock_unref"

-- | A C expression for the raw memory of a memory variable.  It is
-- @v.mem@ when 'envFatMemory' is set and @v@ is absent from the
-- cached-memory table, and plain @v@ otherwise.
rawMem :: VName -> CompilerM op s C.Exp
rawMem v = do
  fatEnabled <- asks envFatMemory
  cached <- cacheMem v
  pure $ rawMem' (fatEnabled && isNothing cached) v

-- | Select the @mem@ field of the expression when the flag is set;
-- otherwise return the expression unchanged.
rawMem' :: C.ToExp a => Bool -> a -> C.Exp
rawMem' fat e
  | fat = [C.cexp|$exp:e.mem|]
  | otherwise = [C.cexp|$exp:e|]

-- | Emit code that allocates @size@ bytes of raw memory into @dest@.
-- For a named space, the backend's 'envAllocate' is called with
-- @dest@, @size@, @desc@ and the space name.  Every other space gets a
-- plain @malloc@.
-- NOTE(review): the result of the default-space @malloc@ is not
-- checked for NULL here; confirm callers handle allocation failure.
allocRawMem ::
  (C.ToExp a, C.ToExp b, C.ToExp c) =>
  a ->
  b ->
  Space ->
  c ->
  CompilerM op s ()
allocRawMem dest size space desc = case space of
  Space sid -> do
    allocate <- asks envAllocate
    allocate [C.cexp|$exp:dest|] [C.cexp|$exp:size|] [C.cexp|$exp:desc|] sid
  _ ->
    stm [C.cstm|$exp:dest = (unsigned char*) malloc((size_t)$exp:size);|]

-- | Emit code that releases the raw memory @mem@.
--
-- For a named 'Space' the backend-provided 'envDeallocate' operation is
-- invoked with @mem@, @desc@ and the space identifier; for any other
-- space a plain @free@ is emitted and @desc@ is unused.
freeRawMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  Space ->
  b ->
  CompilerM op s ()
freeRawMem mem space desc
  | Space sid <- space = do
      deallocate <- asks envDeallocate
      deallocate [C.cexp|$exp:mem|] [C.cexp|$exp:desc|] sid
  | otherwise =
      item [C.citem|free($exp:mem);|]

-- | Generate the C definitions needed for reference-counted ("fat")
-- memory blocks in the given space.
--
-- Returns a triple of:
--
--   * the @struct@ definition for a memory block (reference counter
--     pointer, raw memory, size and description);
--   * the unreference, allocation and assignment functions for that
--     block type;
--   * a block item that appends the peak memory usage to the report
--     builder (an empty block for 'DefaultSpace').
--
-- As a side effect this registers two @int64_t@ context fields (peak
-- and current usage, both initialised to 0) and a clear action that
-- resets the peak.
defineMemorySpace :: Space -> CompilerM op s (C.Definition, [C.Definition], C.BlockItem)
defineMemorySpace space = do
  rm <- rawMemCType space
  let structdef =
        [C.cedecl|struct $id:sname { int *references;
                                     $ty:rm mem;
                                     typename int64_t size;
                                     const char *desc; };|]

  contextField peakname [C.cty|typename int64_t|] $ Just [C.cexp|0|]
  contextField usagename [C.cty|typename int64_t|] $ Just [C.cexp|0|]

  -- Unreferencing a memory block consists of decreasing its reference
  -- count and freeing the corresponding memory if the count reaches
  -- zero.
  free <- collect $ freeRawMem [C.cexp|block->mem|] space [C.cexp|desc|]
  ctx_ty <- contextType
  let unrefdef =
        [C.cedecl|static int $id:(fatMemUnRef space) ($ty:ctx_ty *ctx, $ty:mty *block, const char *desc) {
  if (block->references != NULL) {
    *(block->references) -= 1;
    if (ctx->detail_memory) {
      fprintf(ctx->log, "Unreferencing block %s (allocated as %s) in %s: %d references remaining.\n",
                      desc, block->desc, $string:spacedesc, *(block->references));
    }
    if (*(block->references) == 0) {
      ctx->$id:usagename -= block->size;
      $items:free
      free(block->references);
      if (ctx->detail_memory) {
        fprintf(ctx->log, "%lld bytes freed (now allocated: %lld bytes)\n",
                (long long) block->size, (long long) ctx->$id:usagename);
      }
    }
    block->references = NULL;
  }
  return 0;
}|]

  -- When allocating a memory block we initialise the reference count to 1.
  alloc <-
    collect $
      allocRawMem [C.cexp|block->mem|] [C.cexp|size|] space [C.cexp|desc|]
  -- The panic message has three conversion specifiers, so exactly three
  -- arguments follow the format string.
  let allocdef =
        [C.cedecl|static int $id:(fatMemAlloc space) ($ty:ctx_ty *ctx, $ty:mty *block, typename int64_t size, const char *desc) {
  if (size < 0) {
    futhark_panic(1, "Negative allocation of %lld bytes attempted for %s in %s.\n",
          (long long)size, desc, $string:spacedesc);
  }
  int ret = $id:(fatMemUnRef space)(ctx, block, desc);

  ctx->$id:usagename += size;
  if (ctx->detail_memory) {
    fprintf(ctx->log, "Allocating %lld bytes for %s in %s (then allocated: %lld bytes)",
            (long long) size,
            desc, $string:spacedesc,
            (long long) ctx->$id:usagename);
  }
  if (ctx->$id:usagename > ctx->$id:peakname) {
    ctx->$id:peakname = ctx->$id:usagename;
    if (ctx->detail_memory) {
      fprintf(ctx->log, " (new peak).\n");
    }
  } else if (ctx->detail_memory) {
    fprintf(ctx->log, ".\n");
  }

  $items:alloc
  block->references = (int*) malloc(sizeof(int));
  *(block->references) = 1;
  block->size = size;
  block->desc = desc;
  return ret;
}|]

  -- Memory setting: take a new reference on the source, then unreference
  -- the destination.  The source struct is copied first and its count is
  -- bumped before the destination is released.  Without that ordering,
  -- assigning a block to itself (lhs == rhs) while it holds the only
  -- reference would free the memory and then copy the dangling struct
  -- back.  For distinct lhs/rhs the end result is unchanged, because
  -- unreferencing lhs never modifies the *rhs struct.
  let setdef =
        [C.cedecl|static int $id:(fatMemSet space) ($ty:ctx_ty *ctx, $ty:mty *lhs, $ty:mty *rhs, const char *lhs_desc) {
  $ty:mty new_block = *rhs;
  if (new_block.references != NULL) {
    (*(new_block.references))++;
  }
  int ret = $id:(fatMemUnRef space)(ctx, lhs, lhs_desc);
  *lhs = new_block;
  return ret;
}
|]

  onClear [C.citem|ctx->$id:peakname = 0;|]

  let peakmsg = "Peak memory usage for " ++ spacedesc ++ ": %lld bytes.\n"
  return
    ( structdef,
      [unrefdef, allocdef, setdef],
      -- Do not report memory usage for DefaultSpace (CPU memory),
      -- because it would not be accurate anyway.  This whole
      -- tracking probably needs to be rethought.
      if space == DefaultSpace
        then [C.citem|{}|]
        else [C.citem|str_builder(&builder, $string:peakmsg, (long long) ctx->$id:peakname);|]
    )
  where
    mty = fatMemType space
    -- Per-space C names: peak-usage field, current-usage field, block
    -- struct name, and a human-readable space description for logs.
    (peakname, usagename, sname, spacedesc) = case space of
      Space sid ->
        ( C.toIdent ("peak_mem_usage_" ++ sid) noLoc,
          C.toIdent ("cur_mem_usage_" ++ sid) noLoc,
          C.toIdent ("memblock_" ++ sid) noLoc,
          "space '" ++ sid ++ "'"
        )
      _ ->
        ( "peak_mem_usage_default",
          "cur_mem_usage_default",
          "memblock",
          "default space"
        )

-- | Declare a local memory variable @name@ in the given space, unless it
-- is a cached memory block.  A freshly declared variable is reset via
-- 'resetMem' and recorded in 'compDeclaredMem'.
declMem :: VName -> Space -> CompilerM op s ()
declMem name space =
  cacheMem name >>= \cacheEntry ->
    when (isNothing cacheEntry) $ do
      memTy <- memToCType name space
      decl [C.cdecl|$ty:memTy $id:name;|]
      resetMem name space
      modify $ \st ->
        st {compDeclaredMem = (name, space) : compDeclaredMem st}

-- | Emit code that puts the memory variable @mem@ into its "empty" state.
--
-- A cached memory block is set to @NULL@.  Otherwise, if the space uses
-- fat memory, the block's @references@ pointer is set to @NULL@.  In
-- all other cases nothing is emitted.
resetMem :: C.ToExp a => a -> Space -> CompilerM op s ()
resetMem mem space = do
  refcount <- fatMemory space
  cached <- isJust <$> cacheMem mem
  case (cached, refcount) of
    (True, _) -> stm [C.cstm|$exp:mem = NULL;|]
    (False, True) -> stm [C.cstm|$exp:mem.references = NULL;|]
    (False, False) -> pure ()

-- | Emit code that makes @dest@ refer to the memory of @src@.
--
-- * Fat memory: calls the generated set function, which unreferences
--   @dest@ and takes a reference on @src@, returning 1 on failure.
-- * 'ScalarSpace': copies the elements one by one (the element count
--   is the product of the space's dimensions).
-- * Otherwise: a plain pointer assignment.
setMem :: (C.ToExp a, C.ToExp b) => a -> b -> Space -> CompilerM op s ()
setMem dest src space = do
  refcount <- fatMemory space
  -- The generated set function receives this string as @lhs_desc@ and
  -- uses it when unreferencing the *destination* block (it appears in
  -- the "Unreferencing block %s" log line), so it must name @dest@.
  let dest_s = pretty $ C.toExp dest noLoc
  if refcount
    then
      stm
        [C.cstm|if ($id:(fatMemSet space)(ctx, &$exp:dest, &$exp:src,
                                               $string:dest_s) != 0) {
                       return 1;
                     }|]
    else case space of
      ScalarSpace ds _ -> do
        i' <- newVName "i"
        let i = C.toIdent i'
            it = primTypeToCType $ IntType Int32
            ds' = map (`C.toExp` noLoc) ds
            bound = cproduct ds'
        stm
          [C.cstm|for ($ty:it $id:i = 0; $id:i < $exp:bound; $id:i++) {
                            $exp:dest[$id:i] = $exp:src[$id:i];
                  }|]
      _ -> stm [C.cstm|$exp:dest = $exp:src;|]

-- | Emit code that drops one reference to the memory block @mem@.
--
-- Only done for fat (reference-counted) memory that is not a cached
-- block; the generated call makes the enclosing C function return 1 if
-- unreferencing fails.
unRefMem :: C.ToExp a => a -> Space -> CompilerM op s ()
unRefMem mem space = do
  refcounted <- fatMemory space
  uncached <- isNothing <$> cacheMem mem
  when (refcounted && uncached) $
    stm
      [C.cstm|if ($id:(fatMemUnRef space)(ctx, &$exp:mem, $string:mem_desc) != 0) {
                  return 1;
                }|]
  where
    -- Human-readable name of the block, used in debug logging.
    mem_desc = pretty $ C.toExp mem noLoc

-- | Emit code that allocates @size@ bytes into the memory variable @mem@.
--
-- With fat memory, the generated allocation function is called and
-- @on_failure@ runs if it returns nonzero.  Without fat memory, the
-- current raw memory is freed and fresh raw memory is allocated.  In
-- that path @on_failure@ is not used.
allocMem ::
  (C.ToExp a, C.ToExp b) =>
  a ->
  b ->
  Space ->
  C.Stm ->
  CompilerM op s ()
allocMem mem size space on_failure = do
  refcount <- fatMemory space
  -- Textual name of the block, used as its description in the
  -- generated calls (a C string literal).
  let mem_s = pretty $ C.toExp mem noLoc
  if refcount
    then
      stm
        [C.cstm|if ($id:(fatMemAlloc space)(ctx, &$exp:mem, $exp:size,
                                                 $string:mem_s)) {
                       $stm:on_failure
                     }|]
    else do
      freeRawMem mem space mem_s
      -- Use the same string-literal description as for freeing.  The
      -- C identifier 'desc' only exists inside the generated memblock
      -- functions, not at ordinary call sites.
      allocRawMem mem size space mem_s

-- | Emit a @memmove@ of @nbytes@ bytes from @srcmem + srcidx@ to
-- @destmem + destidx@, guarded so nothing happens when @nbytes@ is not
-- positive.
copyMemoryDefaultSpace ::
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  C.Exp ->
  CompilerM op s ()
copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
  stm
    [C.cstm|if ($exp:nbytes > 0) {
              memmove($exp:dstPtr,
                      $exp:srcPtr,
                      $exp:nbytes);
            }|]
  where
    dstPtr = [C.cexp|$exp:destmem + $exp:destidx|]
    srcPtr = [C.cexp|$exp:srcmem + $exp:srcidx|]

--- Entry points.

-- | Wrap block items in the context lock.  The backend's critical-section
-- prologue and epilogue (from 'opsCritical') are placed immediately
-- inside the lock/unlock pair.
criticalSection :: Operations op s -> [C.BlockItem] -> [C.BlockItem]
criticalSection ops x =
  let (enterItems, leaveItems) = opsCritical ops
   in [C.citems|lock_lock(&ctx->lock);
                $items:enterItems
                $items:x
                $items:leaveItems
                lock_unlock(&ctx->lock);
               |]

-- | Generate the C API for an array type with the given element type
-- and rank.
--
-- The API consists of the functions @new_@, @new_raw_@, @free_@,
-- @values_@, @values_raw_@ and @shape_@.  Their prototypes go into the
-- public header for 'Public' arrays and into the library declarations
-- otherwise.  The definitions are returned.
arrayLibraryFunctions ::
  Publicness ->
  Space ->
  PrimType ->
  Signedness ->
  Int ->
  CompilerM op s [C.Definition]
arrayLibraryFunctions pub space pt signed rank = do
  let pt' = primAPIType signed pt
      name = arrayName pt signed rank
      arr_name = "futhark_" ++ name
      array_type = [C.cty|struct $id:arr_name|]

  new_array <- publicName $ "new_" ++ name
  new_raw_array <- publicName $ "new_raw_" ++ name
  free_array <- publicName $ "free_" ++ name
  values_array <- publicName $ "values_" ++ name
  values_raw_array <- publicName $ "values_raw_" ++ name
  shape_array <- publicName $ "shape_" ++ name

  let shape_names = ["dim" ++ show i | i <- [0 .. rank - 1]]
      shape_params = [[C.cparam|typename int64_t $id:k|] | k <- shape_names]
      arr_size = cproduct [[C.cexp|$id:k|] | k <- shape_names]
      arr_size_array = cproduct [[C.cexp|arr->shape[$int:i]|] | i <- [0 .. rank - 1]]
  copy <- asks envCopy

  memty <- rawMemCType space

  -- An allocation failure must not 'return' from here.  This code runs
  -- inside 'criticalSection', so returning would leave ctx->lock held
  -- and leak 'arr'.  Instead the failure is recorded in the C variable
  -- 'err', which the new_* functions declare and check after unlocking.
  let prepare_new = do
        resetMem [C.cexp|arr->mem|] space
        allocMem
          [C.cexp|arr->mem|]
          [C.cexp|$exp:arr_size * $int:(primByteSize pt::Int)|]
          space
          [C.cstm|err = 1;|]
        forM_ [0 .. rank - 1] $ \i ->
          let dim_s = "dim" ++ show i
           in stm [C.cstm|arr->shape[$int:i] = $id:dim_s;|]

      -- Run the generated copy only if the allocation above succeeded.
      whenAllocated m = do
        copy_items <- collect m
        stm [C.cstm|if (err == 0) { $items:copy_items }|]

  new_body <- collect $ do
    prepare_new
    whenAllocated $
      copy
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        [C.cexp|data|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|((size_t)$exp:arr_size) * $int:(primByteSize pt::Int)|]

  new_raw_body <- collect $ do
    prepare_new
    whenAllocated $
      copy
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        [C.cexp|data|]
        [C.cexp|offset|]
        space
        [C.cexp|((size_t)$exp:arr_size) * $int:(primByteSize pt::Int)|]

  free_body <- collect $ unRefMem [C.cexp|arr->mem|] space

  values_body <-
    collect $
      copy
        [C.cexp|data|]
        [C.cexp|0|]
        DefaultSpace
        [C.cexp|arr->mem.mem|]
        [C.cexp|0|]
        space
        [C.cexp|((size_t)$exp:arr_size_array) * $int:(primByteSize pt::Int)|]

  ctx_ty <- contextType
  ops <- asks envOperations

  let proto = case pub of
        Public -> headerDecl (ArrayDecl name)
        Private -> libDecl

  proto
    [C.cedecl|struct $id:arr_name;|]
  proto
    [C.cedecl|$ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params);|]
  proto
    [C.cedecl|$ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, typename int64_t offset, $params:shape_params);|]
  proto
    [C.cedecl|int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data);|]
  proto
    [C.cedecl|$ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]
  proto
    [C.cedecl|const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr);|]

  return
    [C.cunit|
          $ty:array_type* $id:new_array($ty:ctx_ty *ctx, const $ty:pt' *data, $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_body)
            if (err != 0) {
              free(arr);
              return bad;
            }
            return arr;
          }

          $ty:array_type* $id:new_raw_array($ty:ctx_ty *ctx, const $ty:memty data, typename int64_t offset,
                                            $params:shape_params) {
            int err = 0;
            $ty:array_type* bad = NULL;
            $ty:array_type *arr = ($ty:array_type*) malloc(sizeof($ty:array_type));
            if (arr == NULL) {
              return bad;
            }
            $items:(criticalSection ops new_raw_body)
            if (err != 0) {
              free(arr);
              return bad;
            }
            return arr;
          }

          int $id:free_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            $items:(criticalSection ops free_body)
            free(arr);
            return 0;
          }

          int $id:values_array($ty:ctx_ty *ctx, $ty:array_type *arr, $ty:pt' *data) {
            $items:(criticalSection ops values_body)
            return 0;
          }

          $ty:memty $id:values_raw_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->mem.mem;
          }

          const typename int64_t* $id:shape_array($ty:ctx_ty *ctx, $ty:array_type *arr) {
            (void)ctx;
            return arr->shape;
          }
          |]

opaqueLibraryFunctions ::
  String ->
  [ValueDesc] ->
  CompilerM op s [C.Definition]
opaqueLibraryFunctions :: String -> [ValueDesc] -> CompilerM op s [Definition]
opaqueLibraryFunctions String
desc [ValueDesc]
vds = do
  String
name <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String -> [ValueDesc] -> String
opaqueName String
desc [ValueDesc]
vds
  String
free_opaque <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String
"free_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String -> [ValueDesc] -> String
opaqueName String
desc [ValueDesc]
vds
  String
store_opaque <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String
"store_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String -> [ValueDesc] -> String
opaqueName String
desc [ValueDesc]
vds
  String
restore_opaque <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String
"restore_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String -> [ValueDesc] -> String
opaqueName String
desc [ValueDesc]
vds

  let opaque_type :: Type
opaque_type = [C.cty|struct $id:name|]

      freeComponent :: Int -> ValueDesc -> CompilerM op s ()
freeComponent Int
_ ScalarValue {} =
        () -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
      freeComponent Int
i (ArrayValue VName
_ Space
_ PrimType
pt Signedness
signed [SubExp]
shape) = do
        let rank :: Int
rank = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
shape
            field :: String
field = Int -> String
tupleField Int
i
        String
free_array <- String -> CompilerM op s String
forall op s. String -> CompilerM op s String
publicName (String -> CompilerM op s String)
-> String -> CompilerM op s String
forall a b. (a -> b) -> a -> b
$ String
"free_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ PrimType -> Signedness -> Int -> String
arrayName PrimType
pt Signedness
signed Int
rank
        -- Protect against NULL here, because we also want to use this
        -- to free partially loaded opaques.
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
          [C.cstm|if (obj->$id:field != NULL && (tmp = $id:free_array(ctx, obj->$id:field)) != 0) {
                ret = tmp;
             }|]

      storeComponent :: Int -> ValueDesc -> (Exp, [Stm])
storeComponent Int
i (ScalarValue PrimType
pt Signedness
sign VName
_) =
        let field :: String
field = Int -> String
tupleField Int
i
         in ( PrimType -> Int -> Exp -> Exp
storageSize PrimType
pt Int
0 [C.cexp|NULL|],
              Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
storeValueHeader Signedness
sign PrimType
pt Int
0 [C.cexp|NULL|] [C.cexp|out|]
                [Stm] -> [Stm] -> [Stm]
forall a. [a] -> [a] -> [a]
++ [C.cstms|memcpy(out, &obj->$id:field, sizeof(obj->$id:field));
                            out += sizeof(obj->$id:field);|]
            )
      storeComponent Int
i (ArrayValue VName
_ Space
_ PrimType
pt Signedness
sign [SubExp]
shape) =
        let rank :: Int
rank = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
shape
            arr_name :: String
arr_name = PrimType -> Signedness -> Int -> String
arrayName PrimType
pt Signedness
sign Int
rank
            field :: String
field = Int -> String
tupleField Int
i
            shape_array :: String
shape_array = String
"futhark_shape_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
arr_name
            values_array :: String
values_array = String
"futhark_values_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
arr_name
            shape' :: Exp
shape' = [C.cexp|$id:shape_array(ctx, obj->$id:field)|]
            num_elems :: Exp
num_elems = [Exp] -> Exp
cproduct [[C.cexp|$exp:shape'[$int:j]|] | Int
j <- [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]]
         in ( PrimType -> Int -> Exp -> Exp
storageSize PrimType
pt Int
rank Exp
shape',
              Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
storeValueHeader Signedness
sign PrimType
pt Int
rank Exp
shape' [C.cexp|out|]
                [Stm] -> [Stm] -> [Stm]
forall a. [a] -> [a] -> [a]
++ [C.cstms|ret |= $id:values_array(ctx, obj->$id:field, (void*)out);
                            out += $exp:num_elems * $int:(primByteSize pt::Int);|]
            )

  Type
ctx_ty <- CompilerM op s Type
forall op s. CompilerM op s Type
contextType

  [BlockItem]
free_body <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ (Int -> ValueDesc -> CompilerM op s ())
-> [Int] -> [ValueDesc] -> CompilerM op s ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Int -> ValueDesc -> CompilerM op s ()
forall op s. Int -> ValueDesc -> CompilerM op s ()
freeComponent [Int
0 ..] [ValueDesc]
vds

  [BlockItem]
store_body <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ do
    let ([Exp]
sizes, [[Stm]]
stores) = [(Exp, [Stm])] -> ([Exp], [[Stm]])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Exp, [Stm])] -> ([Exp], [[Stm]]))
-> [(Exp, [Stm])] -> ([Exp], [[Stm]])
forall a b. (a -> b) -> a -> b
$ (Int -> ValueDesc -> (Exp, [Stm]))
-> [Int] -> [ValueDesc] -> [(Exp, [Stm])]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> ValueDesc -> (Exp, [Stm])
storeComponent [Int
0 ..] [ValueDesc]
vds
        size_vars :: [String]
size_vars = (Int -> String) -> [Int] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map ((String
"size_" String -> ShowS
forall a. [a] -> [a] -> [a]
++) ShowS -> (Int -> String) -> Int -> String
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> String
forall a. Show a => a -> String
show) [Int
0 .. [Exp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Exp]
sizes Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
        size_sum :: Exp
size_sum = [Exp] -> Exp
csum [[C.cexp|$id:size|] | String
size <- [String]
size_vars]
    [(String, Exp)]
-> ((String, Exp) -> CompilerM op s ()) -> CompilerM op s ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([String] -> [Exp] -> [(String, Exp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [String]
size_vars [Exp]
sizes) (((String, Exp) -> CompilerM op s ()) -> CompilerM op s ())
-> ((String, Exp) -> CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ \(String
v, Exp
e) ->
      BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|typename int64_t $id:v = $exp:e;|]
    Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|*n = $exp:size_sum;|]
    Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|if (p != NULL && *p == NULL) { *p = malloc(*n); }|]
    Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|if (p != NULL) { unsigned char *out = *p; $stms:(concat stores) }|]

  let restoreComponent :: Int -> ValueDesc -> CompilerM op s [Stm]
restoreComponent Int
i (ScalarValue PrimType
pt Signedness
sign VName
_) = do
        let field :: String
field = Int -> String
tupleField Int
i
            dataptr :: String
dataptr = String
"data_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i
        [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
loadValueHeader Signedness
sign PrimType
pt Int
0 [C.cexp|NULL|] [C.cexp|src|]
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|const void* $id:dataptr = src;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|src += sizeof(obj->$id:field);|]
        [Stm] -> CompilerM op s [Stm]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [C.cstms|memcpy(&obj->$id:field, $id:dataptr, sizeof(obj->$id:field));|]
      restoreComponent Int
i (ArrayValue VName
_ Space
_ PrimType
pt Signedness
sign [SubExp]
shape) = do
        let field :: String
field = Int -> String
tupleField Int
i
            rank :: Int
rank = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
shape
            arr_name :: String
arr_name = PrimType -> Signedness -> Int -> String
arrayName PrimType
pt Signedness
sign Int
rank
            new_array :: String
new_array = String
"futhark_new_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
arr_name
            dataptr :: String
dataptr = String
"data_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i
            shapearr :: String
shapearr = String
"shape_" String -> ShowS
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show Int
i
            dims :: [Exp]
dims = [[C.cexp|$id:shapearr[$int:j]|] | Int
j <- [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]]
            num_elems :: Exp
num_elems = [Exp] -> Exp
cproduct [Exp]
dims
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|typename int64_t $id:shapearr[$int:rank];|]
        [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ Signedness -> PrimType -> Int -> Exp -> Exp -> [Stm]
loadValueHeader Signedness
sign PrimType
pt Int
rank [C.cexp|$id:shapearr|] [C.cexp|src|]
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|const void* $id:dataptr = src;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|obj->$id:field = NULL;|]
        Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|src += $exp:num_elems * $int:(primByteSize pt::Int);|]
        [Stm] -> CompilerM op s [Stm]
forall (f :: * -> *) a. Applicative f => a -> f a
pure
          [C.cstms|
             obj->$id:field = $id:new_array(ctx, $id:dataptr, $args:dims);
             if (obj->$id:field == NULL) { err = 1; }|]

  [BlockItem]
load_body <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ do
    [Stm]
loads <- [[Stm]] -> [Stm]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[Stm]] -> [Stm])
-> CompilerM op s [[Stm]] -> CompilerM op s [Stm]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Int -> ValueDesc -> CompilerM op s [Stm])
-> [Int] -> [ValueDesc] -> CompilerM op s [[Stm]]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM Int -> ValueDesc -> CompilerM op s [Stm]
forall op s. Int -> ValueDesc -> CompilerM op s [Stm]
restoreComponent [Int
0 ..] [ValueDesc]
vds
    Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
      [C.cstm|if (err == 0) {
                $stms:loads
              }|]

  HeaderSection -> Definition -> CompilerM op s ()
forall op s. HeaderSection -> Definition -> CompilerM op s ()
headerDecl
    (String -> HeaderSection
OpaqueDecl String
desc)
    [C.cedecl|struct $id:name;|]
  HeaderSection -> Definition -> CompilerM op s ()
forall op s. HeaderSection -> Definition -> CompilerM op s ()
headerDecl
    (String -> HeaderSection
OpaqueDecl String
desc)
    [C.cedecl|int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj);|]
  HeaderSection -> Definition -> CompilerM op s ()
forall op s. HeaderSection -> Definition -> CompilerM op s ()
headerDecl
    (String -> HeaderSection
OpaqueDecl String
desc)
    [C.cedecl|int $id:store_opaque($ty:ctx_ty *ctx, const $ty:opaque_type *obj, void **p, size_t *n);|]
  HeaderSection -> Definition -> CompilerM op s ()
forall op s. HeaderSection -> Definition -> CompilerM op s ()
headerDecl
    (String -> HeaderSection
OpaqueDecl String
desc)
    [C.cedecl|$ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx, const void *p);|]

  -- We do not need to enclose the body in a critical section, because
  -- when we operate on the components of the opaque, we are calling
  -- public API functions that do their own locking.
  [Definition] -> CompilerM op s [Definition]
forall (m :: * -> *) a. Monad m => a -> m a
return
    [C.cunit|
          int $id:free_opaque($ty:ctx_ty *ctx, $ty:opaque_type *obj) {
            int ret = 0, tmp;
            $items:free_body
            free(obj);
            return ret;
          }

          int $id:store_opaque($ty:ctx_ty *ctx,
                               const $ty:opaque_type *obj, void **p, size_t *n) {
            int ret = 0;
            $items:store_body
            return ret;
          }

          $ty:opaque_type* $id:restore_opaque($ty:ctx_ty *ctx,
                                              const void *p) {
            int err = 0;
            const unsigned char *src = p;
            $ty:opaque_type* obj = malloc(sizeof($ty:opaque_type));
            $items:load_body
            if (err != 0) {
              int ret = 0, tmp;
              $items:free_body
              free(obj);
              obj = NULL;
            }
            return obj;
          }
    |]

-- | The C type used for a 'ValueDesc' in the generated API.
--
-- * A scalar maps directly to its 'primAPIType'.
-- * An array maps to @struct <name>@, where the name is the
--   'publicName' of the array's 'arrayName'.
--
-- For arrays there is a side effect: the key
-- @(space, signedness, primtype, rank)@ is recorded in
-- 'compArrayTypes' together with the given 'Publicness'. If the key is
-- already present, the stored value becomes the 'max' of the old and
-- new publicness. The recorded types are later consumed by
-- 'generateAPITypes'.
valueDescToCType :: Publicness -> ValueDesc -> CompilerM op s C.Type
valueDescToCType pub vd =
  case vd of
    ScalarValue pt signed _ ->
      pure $ primAPIType signed pt
    ArrayValue _ space pt signed shape -> do
      let r = length shape
          key = (space, signed, pt, r)
      structName <- publicName (arrayName pt signed r)
      modify $ \st ->
        st {compArrayTypes = M.insertWith max key pub (compArrayTypes st)}
      pure [C.cty|struct $id:structName|]

-- | The C type (@struct <name>@) for an opaque value with description
-- @desc@ and components @vds@. The name is the 'publicName' of
-- 'opaqueName' applied to @desc@ and @vds@.
--
-- Side effects:
--
-- * @desc@ is mapped to @vds@ in 'compOpaqueTypes'. An existing entry
--   under the same @desc@ is replaced.
-- * Every component is passed through @'valueDescToCType' 'Private'@,
--   so the array types the opaque is built from get registered too.
opaqueToCType :: String -> [ValueDesc] -> CompilerM op s C.Type
opaqueToCType desc vds = do
  structName <- publicName (opaqueName desc vds)
  modify $ \st ->
    st {compOpaqueTypes = M.insert desc vds (compOpaqueTypes st)}
  -- Make sure the constituent array types will exist as well.
  forM_ vds (valueDescToCType Private)
  pure [C.cty|struct $id:structName|]

-- | Emit, via 'libDecl', the C definitions for every array type in
-- 'compArrayTypes' and then every opaque type in 'compOpaqueTypes'.
--
-- For each array type, this emits:
--
-- * a @struct@ with a @mem@ field, whose type is 'fatMemType' of the
--   array's space;
-- * a @shape@ field holding @rank@ @int64_t@ values;
-- * the definitions returned by 'arrayLibraryFunctions'.
--
-- For each opaque type, this emits a @struct@ with one member per
-- component, followed by the definitions returned by
-- 'opaqueLibraryFunctions'.
generateAPITypes :: CompilerM op s ()
generateAPITypes = do
  mapM_ generateArray . M.toList =<< gets compArrayTypes
  mapM_ generateOpaque . M.toList =<< gets compOpaqueTypes
  where
    -- Struct and library functions for one registered array type.
    generateArray ((space, signed, pt, rank), pub) = do
      name <- publicName $ arrayName pt signed rank
      let memty = fatMemType space
      libDecl [C.cedecl|struct $id:name { $ty:memty mem; typename int64_t shape[$int:rank]; };|]
      -- Only the side effect of 'libDecl' matters, so use 'mapM_'
      -- rather than building and discarding a list of units.
      mapM_ libDecl =<< arrayLibraryFunctions pub space pt signed rank

    -- Struct and library functions for one registered opaque type.
    generateOpaque (desc, vds) = do
      name <- publicName $ opaqueName desc vds
      members <- zipWithM field vds [(0 :: Int) ..]
      libDecl [C.cedecl|struct $id:name { $sdecls:members };|]
      mapM_ libDecl =<< opaqueLibraryFunctions desc vds

    -- Struct member for component @i@ of an opaque. Scalars are stored
    -- inline; arrays are stored as a pointer to their array struct.
    field vd i = do
      ct <- valueDescToCType Private vd
      pure $ case vd of
        ScalarValue {} -> [C.csdecl|$ty:ct $id:(tupleField i);|]
        ArrayValue {} -> [C.csdecl|$ty:ct *$id:(tupleField i);|]

-- | Conjoin a list of C boolean expressions with @&&@.
--
-- * The empty list yields the literal @true@.
-- * A single expression is returned unchanged.
-- * Longer lists are combined right-associatively, producing the
--   expression tree @e1 && (e2 && (... && en))@.
allTrue :: [C.Exp] -> C.Exp
allTrue [] = [C.cexp|true|]
allTrue (e : es) = conj e es
  where
    conj x [] = x
    conj x (y : ys) = [C.cexp|$exp:x && $exp:(conj y ys)|]

prepareEntryInputs ::
  [ExternalValue] ->
  CompilerM op s ([(C.Param, Maybe C.Exp)], [C.BlockItem])
prepareEntryInputs :: [ExternalValue]
-> CompilerM op s ([(Param, Maybe Exp)], [BlockItem])
prepareEntryInputs [ExternalValue]
args = CompilerM op s [(Param, Maybe Exp)]
-> CompilerM op s ([(Param, Maybe Exp)], [BlockItem])
forall op s a. CompilerM op s a -> CompilerM op s (a, [BlockItem])
collect' (CompilerM op s [(Param, Maybe Exp)]
 -> CompilerM op s ([(Param, Maybe Exp)], [BlockItem]))
-> CompilerM op s [(Param, Maybe Exp)]
-> CompilerM op s ([(Param, Maybe Exp)], [BlockItem])
forall a b. (a -> b) -> a -> b
$ (Int -> ExternalValue -> CompilerM op s (Param, Maybe Exp))
-> [Int] -> [ExternalValue] -> CompilerM op s [(Param, Maybe Exp)]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM Int -> ExternalValue -> CompilerM op s (Param, Maybe Exp)
forall a op s.
Show a =>
a -> ExternalValue -> CompilerM op s (Param, Maybe Exp)
prepare [(Int
0 :: Int) ..] [ExternalValue]
args
  where
    arg_names :: Names
arg_names = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (ExternalValue -> [VName]) -> [ExternalValue] -> [VName]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap ExternalValue -> [VName]
evNames [ExternalValue]
args
    evNames :: ExternalValue -> [VName]
evNames (OpaqueValue Uniqueness
_ String
_ [ValueDesc]
vds) = (ValueDesc -> VName) -> [ValueDesc] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map ValueDesc -> VName
vdName [ValueDesc]
vds
    evNames (TransparentValue Uniqueness
_ ValueDesc
vd) = [ValueDesc -> VName
vdName ValueDesc
vd]
    vdName :: ValueDesc -> VName
vdName (ArrayValue VName
v Space
_ PrimType
_ Signedness
_ [SubExp]
_) = VName
v
    vdName (ScalarValue PrimType
_ Signedness
_ VName
v) = VName
v

    prepare :: a -> ExternalValue -> CompilerM op s (Param, Maybe Exp)
prepare a
pno (TransparentValue Uniqueness
_ ValueDesc
vd) = do
      let pname :: String
pname = String
"in" String -> ShowS
forall a. [a] -> [a] -> [a]
++ a -> String
forall a. Show a => a -> String
show a
pno
      (Type
ty, [Exp]
check) <- Publicness -> Exp -> ValueDesc -> CompilerM op s (Type, [Exp])
forall a op s.
ToExp a =>
Publicness -> a -> ValueDesc -> CompilerM op s (Type, [Exp])
prepareValue Publicness
Public [C.cexp|$id:pname|] ValueDesc
vd
      (Param, Maybe Exp) -> CompilerM op s (Param, Maybe Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( [C.cparam|const $ty:ty $id:pname|],
          if [Exp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Exp]
check then Maybe Exp
forall a. Maybe a
Nothing else Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
allTrue [Exp]
check
        )
    prepare a
pno (OpaqueValue Uniqueness
_ String
desc [ValueDesc]
vds) = do
      Type
ty <- String -> [ValueDesc] -> CompilerM op s Type
forall op s. String -> [ValueDesc] -> CompilerM op s Type
opaqueToCType String
desc [ValueDesc]
vds
      let pname :: String
pname = String
"in" String -> ShowS
forall a. [a] -> [a] -> [a]
++ a -> String
forall a. Show a => a -> String
show a
pno
          field :: Int -> ValueDesc -> Exp
field Int
i ScalarValue {} = [C.cexp|$id:pname->$id:(tupleField i)|]
          field Int
i ArrayValue {} = [C.cexp|$id:pname->$id:(tupleField i)|]
      [[Exp]]
checks <- ((Type, [Exp]) -> [Exp]) -> [(Type, [Exp])] -> [[Exp]]
forall a b. (a -> b) -> [a] -> [b]
map (Type, [Exp]) -> [Exp]
forall a b. (a, b) -> b
snd ([(Type, [Exp])] -> [[Exp]])
-> CompilerM op s [(Type, [Exp])] -> CompilerM op s [[Exp]]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Exp -> ValueDesc -> CompilerM op s (Type, [Exp]))
-> [Exp] -> [ValueDesc] -> CompilerM op s [(Type, [Exp])]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM (Publicness -> Exp -> ValueDesc -> CompilerM op s (Type, [Exp])
forall a op s.
ToExp a =>
Publicness -> a -> ValueDesc -> CompilerM op s (Type, [Exp])
prepareValue Publicness
Private) ((Int -> ValueDesc -> Exp) -> [Int] -> [ValueDesc] -> [Exp]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith Int -> ValueDesc -> Exp
field [Int
0 ..] [ValueDesc]
vds) [ValueDesc]
vds
      (Param, Maybe Exp) -> CompilerM op s (Param, Maybe Exp)
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( [C.cparam|const $ty:ty *$id:pname|],
          if [Exp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([Exp] -> Bool) -> [Exp] -> Bool
forall a b. (a -> b) -> a -> b
$ [[Exp]] -> [Exp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Exp]]
checks
            then Maybe Exp
forall a. Maybe a
Nothing
            else Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ [Exp] -> Exp
allTrue ([Exp] -> Exp) -> [Exp] -> Exp
forall a b. (a -> b) -> a -> b
$ [[Exp]] -> [Exp]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[Exp]]
checks
        )

    prepareValue :: Publicness -> a -> ValueDesc -> CompilerM op s (Type, [Exp])
prepareValue Publicness
_ a
src (ScalarValue PrimType
pt Signedness
signed VName
name) = do
      let pt' :: Type
pt' = Signedness -> PrimType -> Type
primAPIType Signedness
signed PrimType
pt
          src' :: Exp
src' = PrimType -> Exp -> Exp
fromStorage PrimType
pt (Exp -> Exp) -> Exp -> Exp
forall a b. (a -> b) -> a -> b
$ a -> SrcLoc -> Exp
forall a. ToExp a => a -> SrcLoc -> Exp
C.toExp a
src SrcLoc
forall a. Monoid a => a
mempty
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$id:name = $exp:src';|]
      (Type, [Exp]) -> CompilerM op s (Type, [Exp])
forall (m :: * -> *) a. Monad m => a -> m a
return (Type
pt', [])
    prepareValue Publicness
pub a
src vd :: ValueDesc
vd@(ArrayValue VName
mem Space
_ PrimType
_ Signedness
_ [SubExp]
shape) = do
      Type
ty <- Publicness -> ValueDesc -> CompilerM op s Type
forall op s. Publicness -> ValueDesc -> CompilerM op s Type
valueDescToCType Publicness
pub ValueDesc
vd

      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$exp:mem = $exp:src->mem;|]

      let rank :: Int
rank = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
shape
          maybeCopyDim :: SubExp -> a -> (Maybe Stm, Exp)
maybeCopyDim (Var VName
d) a
i
            | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName
d VName -> Names -> Bool
`nameIn` Names
arg_names =
              ( Stm -> Maybe Stm
forall a. a -> Maybe a
Just [C.cstm|$id:d = $exp:src->shape[$int:i];|],
                [C.cexp|$id:d == $exp:src->shape[$int:i]|]
              )
          maybeCopyDim SubExp
x a
i =
            ( Maybe Stm
forall a. Maybe a
Nothing,
              [C.cexp|$exp:x == $exp:src->shape[$int:i]|]
            )

      let ([Maybe Stm]
sets, [Exp]
checks) =
            [(Maybe Stm, Exp)] -> ([Maybe Stm], [Exp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Maybe Stm, Exp)] -> ([Maybe Stm], [Exp]))
-> [(Maybe Stm, Exp)] -> ([Maybe Stm], [Exp])
forall a b. (a -> b) -> a -> b
$ (SubExp -> Int -> (Maybe Stm, Exp))
-> [SubExp] -> [Int] -> [(Maybe Stm, Exp)]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> Int -> (Maybe Stm, Exp)
forall a. (Show a, Integral a) => SubExp -> a -> (Maybe Stm, Exp)
maybeCopyDim [SubExp]
shape [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
      [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ [Maybe Stm] -> [Stm]
forall a. [Maybe a] -> [a]
catMaybes [Maybe Stm]
sets

      (Type, [Exp]) -> CompilerM op s (Type, [Exp])
forall (m :: * -> *) a. Monad m => a -> m a
return ([C.cty|$ty:ty*|], [Exp]
checks)

-- | Generate the C code that hands the results of an entry point back
-- to the caller.
--
-- Output number @i@ becomes a C parameter named @out\<i\>@.  Returns the
-- C parameters (one per 'ExternalValue', in order) together with the
-- collected statements that fill them in.
--
-- * Scalar outputs are stored directly through the @out\<i\>@ pointer.
-- * Array and opaque outputs are returned as a freshly @malloc@ed
--   struct, so their parameter is a pointer-to-pointer.
prepareEntryOutputs :: [ExternalValue] -> CompilerM op s ([C.Param], [C.BlockItem])
prepareEntryOutputs = collect' . zipWithM prepare [(0 :: Int) ..]
  where
    -- Name of the C output parameter for output number @pno@.
    outName pno = "out" ++ show pno

    prepare pno (TransparentValue _ vd) = do
      let pname = outName pno
      ty <- valueDescToCType Public vd

      case vd of
        ArrayValue {} -> do
          -- The allocation is performed inside assert(); the generated
          -- utility code undefines NDEBUG so this side effect survives.
          stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]
          prepareValue [C.cexp|*$id:pname|] vd
          return [C.cparam|$ty:ty **$id:pname|]
        ScalarValue {} -> do
          prepareValue [C.cexp|*$id:pname|] vd
          return [C.cparam|$ty:ty *$id:pname|]
    prepare pno (OpaqueValue _ desc vds) = do
      let pname = outName pno
      ty <- opaqueToCType desc vds
      vd_ts <- mapM (valueDescToCType Private) vds

      stm [C.cstm|assert((*$id:pname = ($ty:ty*) malloc(sizeof($ty:ty))) != NULL);|]

      forM_ (zip3 [0 ..] vd_ts vds) $ \(i, ct, vd) -> do
        let field = [C.cexp|(*$id:pname)->$id:(tupleField i)|]
        -- Scalar fields are assigned directly; array fields are
        -- pointers that must be allocated before they are filled in.
        case vd of
          ScalarValue {} -> return ()
          _ -> stm [C.cstm|assert(($exp:field = ($ty:ct*) malloc(sizeof($ty:ct))) != NULL);|]
        prepareValue field vd

      return [C.cparam|$ty:ty **$id:pname|]

    -- Store one value description into the C lvalue @dest@.  Scalars
    -- are converted with 'toStorage'; arrays get their memory block and
    -- every shape component copied.
    prepareValue dest (ScalarValue t _ name) =
      let name' = toStorage t $ C.toExp name mempty
       in stm [C.cstm|$exp:dest = $exp:name';|]
    prepareValue dest (ArrayValue mem _ _ _ shape) = do
      stm [C.cstm|$exp:dest->mem = $id:mem;|]

      let copyDim (Constant x) i =
            [C.cstm|$exp:dest->shape[$int:i] = $exp:x;|]
          copyDim (Var d) i =
            [C.cstm|$exp:dest->shape[$int:i] = $id:d;|]
      -- zipWith stops at the end of 'shape', so one index per dimension.
      stms $ zipWith copyDim shape [(0 :: Int) ..]

-- | Generate the public C wrapper for an entry point.
--
-- A function without an entry-point name (its 'functionEntry' is
-- 'Nothing') produces 'Nothing'.  Otherwise this emits a header
-- declaration for @entry_\<name\>@ (passed through 'publicName') and
-- returns the wrapper's definition.  The wrapper:
--
--   * declares local variables for the internal function's inputs and
--     outputs;
--   * unpacks the API-level arguments and checks their sizes, setting
--     @ctx->error@ and returning 1 if they do not match;
--   * otherwise calls the internal function and, if it succeeds, runs
--     @get_consts@ and copies the results into the output parameters.
--
-- Everything after the local declarations is wrapped in the backend's
-- 'criticalSection'.
onEntryPoint ::
  [C.BlockItem] ->
  Name ->
  Function op ->
  CompilerM op s (Maybe C.Definition)
onEntryPoint _ _ (Function Nothing _ _ _ _ _) = pure Nothing
onEntryPoint get_consts fname (Function (Just ename) outputs inputs _ results args) = do
  let out_args = map (\p -> [C.cexp|&$id:(paramName p)|]) outputs
      in_args = map (\p -> [C.cexp|$id:(paramName p)|]) inputs

  inputdecls <- collect $ mapM_ stubParam inputs
  outputdecls <- collect $ mapM_ stubParam outputs

  entry_point_function_name <- publicName $ "entry_" ++ nameToString ename

  (inputs', unpack_entry_inputs) <- prepareEntryInputs args
  let (entry_point_input_params, entry_point_input_checks) = unzip inputs'

  (entry_point_output_params, pack_entry_outputs) <-
    prepareEntryOutputs results

  ctx_ty <- contextType

  headerDecl
    EntryDecl
    [C.cedecl|int $id:entry_point_function_name
                                     ($ty:ctx_ty *ctx,
                                      $params:entry_point_output_params,
                                      $params:entry_point_input_params);|]

  -- Only inputs whose sizes must be checked contribute a condition.
  let checks = catMaybes entry_point_input_checks
      check_input
        | null checks = []
        | otherwise =
            [C.citems|
         if (!($exp:(allTrue checks))) {
           ret = 1;
           if (!ctx->error) {
             ctx->error = msgprintf("Error: entry point arguments have invalid sizes.\n");
           }
         }|]

      critical =
        [C.citems|
         $items:unpack_entry_inputs
         $items:check_input
         if (ret == 0) {
           ret = $id:(funName fname)(ctx, $args:out_args, $args:in_args);
           if (ret == 0) {
             $items:get_consts

             $items:pack_entry_outputs
           }
         }
        |]

  ops <- asks envOperations

  pure . Just $
    [C.cedecl|
       int $id:entry_point_function_name
           ($ty:ctx_ty *ctx,
            $params:entry_point_output_params,
            $params:entry_point_input_params) {
         $items:inputdecls
         $items:outputdecls

         int ret = 0;

         $items:(criticalSection ops critical)

         return ret;
       }|]
  where
    -- Declare a local C variable for one parameter of the internal
    -- function: memory blocks via 'declMem', scalars as plain locals.
    stubParam (MemParam name space) =
      declMem name space
    stubParam (ScalarParam name ty) = do
      let ty' = primTypeToCType ty
      decl [C.cdecl|$ty:ty' $id:name;|]

-- | The result of compilation to C is multiple parts, which can be
-- put together in various ways.  Concatenating the header, utility,
-- CLI and library parts yields a CLI program (see 'asExecutable');
-- swapping the CLI part for the server part yields a server program
-- (see 'asServer').  Alternatively the library part can be compiled by
-- itself, with the header file used to call into it (see 'asLibrary').
data CParts = CParts
  { -- | The public header: standard includes plus the declarations of
    -- the initialisation, array, opaque-value, entry-point and
    -- miscellaneous API functions.
    cHeader :: T.Text,
    -- | Utility definitions that must be visible
    -- to both CLI and library parts.
    cUtils :: T.Text,
    -- | The command-line interface front-end.
    cCLI :: T.Text,
    -- | The server-mode front-end.
    cServer :: T.Text,
    -- | The implementation of the library itself.
    cLib :: T.Text
  }

-- | Preamble defining @_GNU_SOURCE@.  It must be emitted before any
-- header is included, because the macro only takes effect if it is
-- defined before the first system header is processed.
gnuSource :: T.Text
gnuSource =
  [untrimming|
// We need to define _GNU_SOURCE before
// _any_ headers files are imported to get
// the usage statistics of a thread (i.e. have RUSAGE_THREAD) on GNU/Linux
// https://manpages.courier-mta.org/htmlman2/getrusage.2.html
#ifndef _GNU_SOURCE // Avoid possible double-definition warning.
#define _GNU_SOURCE
#endif
|]

-- | Pragmas that silence cosmetic C compiler warnings.
--
-- We may generate variables that are never used (e.g. for
-- certificates) or functions that are never called (e.g. unused
-- intrinsics), and generated code may have other cosmetic issues that
-- compilers warn about.  We disable these warnings to not clutter the
-- compilation logs.  Clang is checked first because it also defines
-- @__GNUC__@.
disableWarnings :: T.Text
disableWarnings =
  [untrimming|
#ifdef __clang__
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wparentheses"
#pragma clang diagnostic ignored "-Wunused-label"
#elif __GNUC__
#pragma GCC diagnostic ignored "-Wunused-function"
#pragma GCC diagnostic ignored "-Wunused-variable"
#pragma GCC diagnostic ignored "-Wparentheses"
#pragma GCC diagnostic ignored "-Wunused-label"
#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
#endif
|]

-- | Produce header and implementation files.
--
-- The header is 'cHeader' guarded by @#pragma once@.  The
-- implementation file is self-contained: after the @_GNU_SOURCE@ and
-- warning-suppression preambles it repeats the header contents,
-- followed by the utility and library definitions.
asLibrary :: CParts -> (T.Text, T.Text)
asLibrary parts =
  ( "#pragma once\n\n" <> cHeader parts,
    mconcat
      [ gnuSource,
        disableWarnings,
        cHeader parts,
        cUtils parts,
        cLib parts
      ]
  )

-- | As executable with command-line interface.
--
-- A single C file: preambles, header, utilities, the CLI front-end and
-- finally the library implementation.
asExecutable :: CParts -> T.Text
asExecutable parts =
  mconcat
    [ gnuSource,
      disableWarnings,
      cHeader parts,
      cUtils parts,
      cCLI parts,
      cLib parts
    ]

-- | As server executable.
--
-- Identical to 'asExecutable' except that the server front-end
-- ('cServer') takes the place of the CLI front-end.
asServer :: CParts -> T.Text
asServer parts =
  mconcat
    [ gnuSource,
      disableWarnings,
      cHeader parts,
      cUtils parts,
      cServer parts,
      cLib parts
    ]

-- | Compile an imperative program to C, split into the parts described
-- by 'CParts'.
--
-- Every function whose 'functionEntry' is set is exposed as an entry
-- point in the library, and is also handed to the generated CLI and
-- server front-ends.
--
-- Parameters, in order: the backend name (emitted as
-- @FUTHARK_BACKEND_\<backend\>@ in the header), the backend operations,
-- extra backend-specific compilation run after the entry points, extra
-- text spliced into the header and library, the memory spaces to
-- define, the command-line options, and the program itself.
compileProg ::
  MonadFreshNames m =>
  T.Text ->
  Operations op () ->
  CompilerM op () () ->
  T.Text ->
  [Space] ->
  [Option] ->
  Definitions op ->
  m CParts
compileProg backend ops extra header_extra spaces options prog = do
  src <- getNameSource
  let ((prototypes, definitions, entry_point_decls), endstate) =
        runCompilerM ops src () compileProg'
      initdecls = initDecls endstate
      entrydecls = entryDecls endstate
      arraydecls = arrayDecls endstate
      opaquedecls = opaqueDecls endstate
      miscdecls = miscDecls endstate

  let headerdefs =
        [untrimming|
// Headers
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#include <stdio.h>
#include <float.h>
$header_extra
#ifdef __cplusplus
extern "C" {
#endif

// Initialisation
$initdecls

// Arrays
$arraydecls

// Opaque values
$opaquedecls

// Entry points
$entrydecls

// Miscellaneous
$miscdecls
#define FUTHARK_BACKEND_$backend

#ifdef __cplusplus
}
#endif
|]

  let utildefs =
        [untrimming|
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <math.h>
#include <stdint.h>
// If NDEBUG is set, the assert() macro will do nothing. Since Futhark
// (unfortunately) makes use of assert() for error detection (and even some
// side effects), we want to avoid that.
#undef NDEBUG
#include <assert.h>
#include <stdarg.h>
$utilH
$halfH
$timingH
|]

  let early_decls = prettyDefs $ DL.toList $ compEarlyDecls endstate
      lib_decls = prettyDefs $ DL.toList $ compLibDecls endstate
      clidefs = cliDefs options $ Functions entry_funs
      serverdefs = serverDefs options $ Functions entry_funs
      libdefs =
        [untrimming|
#ifdef _MSC_VER
#define inline __inline
#endif
#include <string.h>
#include <errno.h>
#include <assert.h>
#include <ctype.h>

$header_extra

$lockH

#define FUTHARK_F64_ENABLED

$cScalarDefs

$early_decls

$prototypes

$lib_decls

$definitions

$entry_point_decls
  |]

  pure
    CParts
      { cHeader = headerdefs,
        cUtils = utildefs,
        cCLI = clidefs,
        cServer = serverdefs,
        cLib = libdefs
      }
  where
    Definitions consts (Functions funs) = prog

    -- Only functions with an entry-point name are exposed externally.
    entry_funs = filter (isJust . functionEntry . snd) funs

    -- Pretty-print a list of C definitions, one per line.
    prettyDefs defs = T.unlines $ map prettyText defs

    -- Returns (function prototypes, function definitions, entry points).
    -- The order of the statements below determines the order in which
    -- early declarations are accumulated, so it must not be changed.
    compileProg' = do
      (memstructs, memfuns, memreport) <- unzip3 <$> mapM defineMemorySpace spaces

      get_consts <- compileConstants consts

      ctx_ty <- contextType

      (prototypes, functions) <-
        unzip <$> mapM (compileFun get_consts [[C.cparam|$ty:ctx_ty *ctx|]]) funs

      mapM_ earlyDecl memstructs
      entry_points <-
        catMaybes <$> mapM (uncurry (onEntryPoint get_consts)) funs

      extra

      mapM_ earlyDecl $ concat memfuns

      commonLibFuns memreport

      pure
        ( prettyDefs prototypes,
          prettyDefs $ map funcToDef functions,
          prettyDefs entry_points
        )

    -- Turn a compiled function into a top-level definition, reusing the
    -- function's own source location.
    funcToDef func = C.FuncDef func loc
      where
        loc = case func of
          C.OldFunc _ _ _ _ _ _ l -> l
          C.Func _ _ _ _ _ l -> l

-- | Emit the common public library functions: size introspection
-- (@get_num_sizes@, @get_size_name@, @get_size_class@), @context_report@,
-- @context_get_error@, @context_set_logging_file@,
-- @context_pause_profiling@, @context_unpause_profiling@ and
-- @context_clear_caches@.
--
-- The argument is the memory-report code; it is spliced into
-- @context_report@ and runs when @ctx->detail_memory@, @ctx->profiling@ or
-- @ctx->logging@ is set.  The accumulated profiling items ('compProfileItems')
-- run only when @ctx->profiling@ is set, and the accumulated clear items
-- ('compClearItems') form the body of @context_clear_caches@, wrapped in the
-- backend's critical section.
commonLibFuns :: [C.BlockItem] -> CompilerM op s ()
commonLibFuns memreport = do
  generateAPITypes
  ctx <- contextType
  ops <- asks envOperations
  profilereport <- gets $ DL.toList . compProfileItems

  publicDef_ "get_num_sizes" InitDecl $ \s ->
    ( [C.cedecl|int $id:s(void);|],
      [C.cedecl|int $id:s(void) {
                return sizeof(size_names)/sizeof(size_names[0]);
              }|]
    )

  -- The two lookups below take an arbitrary public 'int'.  Indexes outside
  -- the table yield NULL instead of reading past the end of the array
  -- (undefined behaviour in C).  The bound is computed exactly as in
  -- get_num_sizes.
  publicDef_ "get_size_name" InitDecl $ \s ->
    ( [C.cedecl|const char* $id:s(int);|],
      [C.cedecl|const char* $id:s(int i) {
                if (i < 0 || i >= (int)(sizeof(size_names)/sizeof(size_names[0]))) {
                  return NULL;
                }
                return size_names[i];
              }|]
    )

  publicDef_ "get_size_class" InitDecl $ \s ->
    ( [C.cedecl|const char* $id:s(int);|],
      [C.cedecl|const char* $id:s(int i) {
                if (i < 0 || i >= (int)(sizeof(size_classes)/sizeof(size_classes[0]))) {
                  return NULL;
                }
                return size_classes[i];
              }|]
    )

  sync <- publicName "context_sync"
  publicDef_ "context_report" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx *ctx);|],
      [C.cedecl|char* $id:s($ty:ctx *ctx) {
                 if ($id:sync(ctx) != 0) {
                   return NULL;
                 }

                 struct str_builder builder;
                 str_builder_init(&builder);
                 if (ctx->detail_memory || ctx->profiling || ctx->logging) {
                   $items:memreport
                 }
                 if (ctx->profiling) {
                   $items:profilereport
                 }
                 return builder.str;
               }|]
    )

  publicDef_ "context_get_error" MiscDecl $ \s ->
    ( [C.cedecl|char* $id:s($ty:ctx* ctx);|],
      [C.cedecl|char* $id:s($ty:ctx* ctx) {
                         char* error = ctx->error;
                         ctx->error = NULL;
                         return error;
                       }|]
    )

  publicDef_ "context_set_logging_file" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx, typename FILE* f);|],
      [C.cedecl|void $id:s($ty:ctx* ctx, typename FILE* f) {
                  ctx->log = f;
                }|]
    )

  publicDef_ "context_pause_profiling" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx);|],
      [C.cedecl|void $id:s($ty:ctx* ctx) {
                 ctx->profiling_paused = 1;
               }|]
    )

  publicDef_ "context_unpause_profiling" MiscDecl $ \s ->
    ( [C.cedecl|void $id:s($ty:ctx* ctx);|],
      [C.cedecl|void $id:s($ty:ctx* ctx) {
                 ctx->profiling_paused = 0;
               }|]
    )

  clears <- gets $ DL.toList . compClearItems
  publicDef_ "context_clear_caches" MiscDecl $ \s ->
    ( [C.cedecl|int $id:s($ty:ctx* ctx);|],
      [C.cedecl|int $id:s($ty:ctx* ctx) {
                         $items:(criticalSection ops clears)
                         return ctx->error != NULL;
                       }|]
    )

-- | Compile the program constants.
--
-- Each constant becomes a field of a @constants@ struct stored in the
-- context.  The C functions @init_constants@ (which runs the constant
-- initialisation code) and @free_constants@ (which unreferences any
-- constant memory blocks) are emitted.  The result is one block item per
-- constant that declares a local copy read from @ctx->constants@.
compileConstants :: Constants op -> CompilerM op s [C.BlockItem]
compileConstants (Constants params initCode) = do
  ctxTy <- contextType
  fields <- traverse constField params
  -- Avoid an empty struct, as that is apparently undefined behaviour.
  let structFields =
        if null fields
          then [[C.csdecl|int dummy;|]]
          else fields
  contextField "constants" [C.cty|struct { $sdecls:structFields }|] Nothing
  earlyDecl [C.cedecl|static int init_constants($ty:ctxTy*);|]
  earlyDecl [C.cedecl|static int free_constants($ty:ctxTy*);|]

  -- We locally define macros for the constants, so that when we
  -- generate assignments to local variables, we actually assign into
  -- the constants struct.  This is not needed for functions, because
  -- they can only read constants, not write them.
  let (defines, undefines) = unzip (map constMacro params)
  initItems <- blockScope $ do
    mapM_ resetMemConst params
    compileCode initCode
  libDecl
    [C.cedecl|static int init_constants($ty:ctxTy *ctx) {
      (void)ctx;
      int err = 0;
      $items:defines
      $items:initItems
      $items:undefines
      cleanup:
      return err;
    }|]

  freeItems <- collect $ mapM_ freeConst params
  libDecl
    [C.cedecl|static int free_constants($ty:ctxTy *ctx) {
      (void)ctx;
      $items:freeItems
      return 0;
    }|]

  traverse loadConst params
  where
    -- The C type used to store a constant of the given kind.
    paramCType (ScalarParam _ bt) = pure $ primTypeToCType bt
    paramCType (MemParam name space) = memToCType name space

    -- A struct field for the constant.
    constField p = do
      ty <- paramCType p
      let name = paramName p
      pure [C.csdecl|$ty:ty $id:name;|]

    -- A local declaration copying the constant out of the context.
    loadConst p = do
      ty <- paramCType p
      let name = paramName p
      pure [C.citem|$ty:ty $id:name = ctx->constants.$id:name;|]

    -- A pair of preprocessor lines aliasing the constant's name to its
    -- struct field, and removing that alias again.
    constMacro p = ([C.citem|$escstm:defLine|], [C.citem|$escstm:undefLine|])
      where
        cname = pretty $ C.toIdent (paramName p) mempty
        defLine = concat ["#define ", cname, " (ctx->constants.", cname, ")"]
        undefLine = "#undef " ++ cname

    resetMemConst p = case p of
      ScalarParam {} -> pure ()
      MemParam name space -> resetMem name space

    freeConst p = case p of
      ScalarParam {} -> pure ()
      MemParam name space -> unRefMem [C.cexp|ctx->constants.$id:name|] space

-- | Run a continuation with lexically scoped memory blocks cached in C
-- locals.
--
-- A fresh @<mem>_cached_size@ name is generated for every 'DefaultSpace'
-- entry of the given map.  The continuation receives the block items that
-- declare the cache variables (a zeroed @size_t@ and a NULL memory pointer)
-- and the statements that @free@ the memory pointers.  It runs with
-- 'envCachedMem' extended to map each memory block to its size variable.
cachingMemory ::
  M.Map VName Space ->
  ([C.BlockItem] -> [C.Stm] -> CompilerM op s a) ->
  CompilerM op s a
cachingMemory lexical f = do
  -- We only consider lexical 'DefaultSpace' memory blocks to be
  -- cached.  This is not a deep technical restriction, but merely a
  -- heuristic based on GPU memory usually involving larger
  -- allocations, that do not suffer from the overhead of reference
  -- counting.
  let cachedMems = [mem | (mem, DefaultSpace) <- M.toList lexical]

  withSizes <- mapM pairWithSize cachedMems

  let cacheMap = M.fromList [(C.toExp mem noLoc, size) | (mem, size) <- withSizes]
      extendEnv env = env {envCachedMem = cacheMap <> envCachedMem env}
      declItems =
        concat
          [ [ [C.citem|size_t $id:size = 0;|],
              [C.citem|$ty:defaultMemBlockType $id:mem = NULL;|]
            ]
            | (mem, size) <- withSizes
          ]
      freeStms = [[C.cstm|free($id:mem);|] | (mem, _) <- withSizes]

  local extendEnv $ f declItems freeStms
  where
    pairWithSize mem = do
      size <- newVName $ pretty mem <> "_cached_size"
      pure (mem, size)

-- | Compile a function to a static C prototype and definition.
--
-- The parameters are laid out as: the boilerplate parameters in @extra@
-- first, then one pointer per output, then the inputs.  The C function
-- returns its @err@ code.  The @get_constants@ items are placed before the
-- compiled body.  Cached lexical memory is declared at the top and freed
-- after the @cleanup@ label.
compileFun :: [C.BlockItem] -> [C.Param] -> (Name, Function op) -> CompilerM op s (C.Definition, C.Func)
compileFun get_constants extra (fname, func@(Function _ outputs inputs body _ _)) = do
  outs <- mapM compileOutput outputs
  let outparams = map fst outs
      out_ptrs = map snd outs
  inparams <- mapM compileInput inputs

  cachingMemory (lexicalMemoryUsage func) $ \decl_cached free_cached -> do
    body' <- blockScope $ compileFunBody out_ptrs outputs body
    let cname = funName fname
        prototype =
          [C.cedecl|static int $id:cname($params:extra, $params:outparams, $params:inparams);|]
        definition =
          [C.cfun|static int $id:cname($params:extra, $params:outparams, $params:inparams) {
               $stms:ignores
               int err = 0;
               $items:decl_cached
               $items:get_constants
               $items:body'
              cleanup:
               {}
               $stms:free_cached
               return err;
  }|]
    pure (prototype, definition)
  where
    -- Ignore all the boilerplate parameters, just in case we don't
    -- actually need to use them.
    ignores = [[C.cstm|(void)$id:p;|] | C.Param (Just p) _ _ _ <- extra]

    -- The C type carried by a parameter of the given kind.
    paramCType (ScalarParam _ bt) = pure $ primTypeToCType bt
    paramCType (MemParam name space) = memToCType name space

    compileInput p = do
      ty <- paramCType p
      let name = paramName p
      pure [C.cparam|$ty:ty $id:name|]

    -- Outputs are passed by pointer; the pointer gets a fresh name whose
    -- stem depends on the parameter kind.
    compileOutput p = do
      ty <- paramCType p
      p_name <- newVName $ case p of
        ScalarParam name _ -> "out_" ++ baseString name
        MemParam name _ -> baseString name ++ "_p"
      pure ([C.cparam|$ty:ty *$id:p_name|], [C.cexp|$id:p_name|])

-- | Cast @base@ to the pointer type @ptrTy@ and index it at @idx@,
-- yielding the C expression @((ptrTy)base)[idx]@.
derefPointer :: C.Exp -> C.Exp -> C.Type -> C.Exp
derefPointer base idx ptrTy =
  [C.cexp|(($ty:ptrTy)$exp:base)[$exp:idx]|]

-- | The C type qualifiers implied by a 'Volatility': @volatile@ for
-- 'Volatile', nothing for 'Nonvolatile'.
volQuals :: Volatility -> [C.TypeQual]
volQuals vol = case vol of
  Volatile -> [C.ctyquals|volatile|]
  Nonvolatile -> []

-- | Build a 'WriteScalar' that assigns a value through a pointer.
--
-- The pointer is cast to @elemtype*@ qualified by the volatility
-- qualifiers followed by the qualifiers that @qualsFor@ produces for the
-- memory space, then indexed and assigned.
writeScalarPointerWithQuals :: PointerQuals op s -> WriteScalar op s
writeScalarPointerWithQuals qualsFor dest idx elemtype space vol val = do
  spaceQuals <- qualsFor space
  let ptrTy = [C.cty|$tyquals:(volQuals vol ++ spaceQuals) $ty:elemtype*|]
      target = derefPointer dest idx ptrTy
  stm [C.cstm|$exp:target = $exp:val;|]

-- | Build a 'ReadScalar' that reads a value through a pointer.
--
-- The pointer is cast to @elemtype*@ qualified by the volatility
-- qualifiers followed by the qualifiers that @qualsFor@ produces for the
-- memory space, then indexed.
readScalarPointerWithQuals :: PointerQuals op s -> ReadScalar op s
readScalarPointerWithQuals qualsFor dest idx elemtype space vol =
  toRead <$> qualsFor space
  where
    toRead spaceQuals =
      derefPointer dest idx [C.cty|$tyquals:(volQuals vol ++ spaceQuals) $ty:elemtype*|]

-- | Make an expression available under a variable name.
--
-- A bare scalar variable is returned as-is.  Any other expression is
-- compiled and bound to a fresh variable (derived from @desc@) of C type
-- @t@, and that variable's name is returned.
compileExpToName :: String -> PrimType -> Exp -> CompilerM op s VName
compileExpToName desc t e = case e of
  LeafExp (ScalarVar v) _ -> pure v
  _ -> do
    name <- newVName desc
    rhs <- compileExp e
    decl [C.cdecl|$ty:(primTypeToCType t) $id:name = $exp:rhs;|]
    pure name

-- | Compile an ImpCode expression to a C expression.
--
-- Leaves are handled as follows:
--
-- * scalar variables become C identifiers;
-- * unit-typed indexing yields the unit value;
-- * 'DefaultSpace' indexing is a direct pointer dereference;
-- * named-space indexing goes through the backend's 'envReadScalar';
-- * 'ScalarSpace' indexing is plain C array indexing.
compileExp :: Exp -> CompilerM op s C.Exp
compileExp = compilePrimExp compileLeaf
  where
    compileLeaf (ScalarVar src) =
      pure [C.cexp|$id:src|]
    compileLeaf (Index _ _ Unit _ _) =
      pure $ C.toExp UnitValue mempty
    compileLeaf (Index src (Count iexp) restype DefaultSpace vol) = do
      src' <- rawMem src
      iexp' <- compileExp $ untyped iexp
      let ptrTy = [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType restype)*|]
      pure $ fromStorage restype $ derefPointer src' iexp' ptrTy
    compileLeaf (Index src (Count iexp) restype (Space space) vol) = do
      readScalar <- asks envReadScalar
      src' <- rawMem src
      iexp' <- compileExp $ untyped iexp
      fromStorage restype
        <$> readScalar src' iexp' (primStorageType restype) space vol
    compileLeaf (Index src (Count iexp) _ ScalarSpace {} _) = do
      iexp' <- compileExp $ untyped iexp
      pure [C.cexp|$id:src[$exp:iexp']|]

-- | Compile a 'PrimExp' to a C expression, given a way to compile
-- its leaves.
--
-- The supplied function is applied only to the @v@ stored in each
-- 'LeafExp'.  Constants, operators, conversions, and calls to named
-- functions are translated here.
--
-- Operators are handled in two ways:
--
-- * Some are emitted inline as C operators.
-- * All others become a call to a C function named after the
--   operator's pretty-printed form.
--
-- Subexpressions are compiled left to right.
compilePrimExp :: Monad m => (v -> m C.Exp) -> PrimExp v -> m C.Exp
compilePrimExp leaf = go
  where
    go (ValueExp val) = pure $ C.toExp val mempty
    go (LeafExp v _) = leaf v
    go (UnOpExp op x) = unOp op <$> go x
    go (CmpOpExp cmp x y) = cmpOp cmp <$> go x <*> go y
    go (ConvOpExp conv x) = convOp conv <$> go x
    go (BinOpExp bop x y) = binOp bop <$> go x <*> go y
    go (FunExp h args _) = funCall h <$> traverse go args

    -- Unary operators.  Float absolute value goes through C's @fabs@
    -- (cast back to float for 'Float32').  Both signum variants are
    -- built from comparisons.  The unsigned one evaluates to 0 for a
    -- zero operand and to 1 otherwise.  NOTE(review): the @!= 0@
    -- presumably compensates for the operand being held in a signed
    -- C type, so that large unsigned values look negative. Confirm
    -- against intTypeToCType.
    unOp Complement {} x' = [C.cexp|~$exp:x'|]
    unOp Not {} x' = [C.cexp|!$exp:x'|]
    unOp (FAbs Float32) x' = [C.cexp|(float)fabs($exp:x')|]
    unOp (FAbs Float64) x' = [C.cexp|fabs($exp:x')|]
    unOp SSignum {} x' = [C.cexp|($exp:x' > 0) - ($exp:x' < 0)|]
    unOp USignum {} x' = [C.cexp|($exp:x' > 0) - ($exp:x' < 0) != 0|]
    unOp op x' = [C.cexp|$id:(pretty op)($exp:x')|]

    -- Comparisons.  Only equality, the float orderings, and the
    -- boolean orderings map directly onto C operators; the rest
    -- (e.g. signed/unsigned integer orderings) become function calls.
    cmpOp CmpEq {} x' y' = [C.cexp|$exp:x' == $exp:y'|]
    cmpOp FCmpLt {} x' y' = [C.cexp|$exp:x' < $exp:y'|]
    cmpOp FCmpLe {} x' y' = [C.cexp|$exp:x' <= $exp:y'|]
    cmpOp CmpLlt {} x' y' = [C.cexp|$exp:x' < $exp:y'|]
    cmpOp CmpLle {} x' y' = [C.cexp|$exp:x' <= $exp:y'|]
    cmpOp cmp x' y' = [C.cexp|$id:(pretty cmp)($exp:x', $exp:y')|]

    -- Conversions are always function calls.
    convOp conv x' = [C.cexp|$id:(pretty conv)($exp:x')|]

    -- Note that integer addition, subtraction, and multiplication with
    -- OverflowWrap are not handled by explicit operators, but rather by
    -- functions.  This is because we want to implicitly convert them to
    -- unsigned numbers, so we can do overflow without invoking
    -- undefined behaviour.
    binOp (Add _ OverflowUndef) x' y' = [C.cexp|$exp:x' + $exp:y'|]
    binOp (Sub _ OverflowUndef) x' y' = [C.cexp|$exp:x' - $exp:y'|]
    binOp (Mul _ OverflowUndef) x' y' = [C.cexp|$exp:x' * $exp:y'|]
    binOp FAdd {} x' y' = [C.cexp|$exp:x' + $exp:y'|]
    binOp FSub {} x' y' = [C.cexp|$exp:x' - $exp:y'|]
    binOp FMul {} x' y' = [C.cexp|$exp:x' * $exp:y'|]
    binOp FDiv {} x' y' = [C.cexp|$exp:x' / $exp:y'|]
    binOp Xor {} x' y' = [C.cexp|$exp:x' ^ $exp:y'|]
    binOp And {} x' y' = [C.cexp|$exp:x' & $exp:y'|]
    binOp Or {} x' y' = [C.cexp|$exp:x' | $exp:y'|]
    binOp LogAnd {} x' y' = [C.cexp|$exp:x' && $exp:y'|]
    binOp LogOr {} x' y' = [C.cexp|$exp:x' || $exp:y'|]
    binOp bop x' y' = [C.cexp|$id:(pretty bop)($exp:x', $exp:y')|]

    -- Calls to named primitive functions go through 'funName'.
    funCall h args' = [C.cexp|$id:(funName (nameFromString h))($args:args')|]

-- | Flatten a tree of sequenced code (built with ':>>:') into the
-- list of its pieces, in execution order.
--
-- Nested sequences are spliced in place, so no element of the result
-- is itself a ':>>:' node.  Every other constructor, including
-- 'Skip', is kept as one element.
linearCode :: Code op -> [Code op]
linearCode code = flatten code []
  where
    -- @flatten c rest@ puts the pieces of @c@ in front of @rest@.  The
    -- right subtree is flattened first (lazily) so the left subtree's
    -- pieces end up ahead of it.
    flatten :: Code a -> [Code a] -> [Code a]
    flatten (l :>>: r) rest = flatten l (flatten r rest)
    flatten c rest = c : rest

compileCode :: Code op -> CompilerM op s ()
compileCode :: Code op -> CompilerM op s ()
compileCode (Op op
op) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ (CompilerEnv op s -> OpCompiler op s)
-> CompilerM op s (OpCompiler op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> OpCompiler op s
forall op s. CompilerEnv op s -> OpCompiler op s
envOpCompiler CompilerM op s (OpCompiler op s)
-> CompilerM op s op -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> op -> CompilerM op s op
forall (f :: * -> *) a. Applicative f => a -> f a
pure op
op
compileCode Code op
Skip = () -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
compileCode (Comment String
s Code op
code) = do
  [BlockItem]
xs <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
code
  let comment :: String
comment = String
"// " String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
s
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|$comment:comment
              { $items:xs }
             |]
compileCode (TracePrint ErrorMsg Exp
msg) = do
  (String
formatstr, [Exp]
formatargs) <- ErrorMsg Exp -> CompilerM op s (String, [Exp])
forall op s. ErrorMsg Exp -> CompilerM op s (String, [Exp])
errorMsgString ErrorMsg Exp
msg
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|fprintf(ctx->log, $string:formatstr, $args:formatargs);|]
compileCode (DebugPrint String
s (Just Exp
e)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(ctx->log, $string:fmtstr, $exp:s, ($ty:ety)$exp:e', '\n');
       }|]
  where
    (String
fmt, Type
ety) = case Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
e of
      IntType IntType
_ -> (String
"llu", [C.cty|long long int|])
      FloatType FloatType
_ -> (String
"f", [C.cty|double|])
      PrimType
_ -> (String
"d", [C.cty|int|])
    fmtstr :: String
fmtstr = String
"%s: %" String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
fmt String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"%c"
compileCode (DebugPrint String
s Maybe Exp
Nothing) =
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(ctx->log, "%s\n", $exp:s);
       }|]
-- :>>: is treated in a special way to detect declare-set pairs in
-- order to generate prettier code.
compileCode (Code op
c1 :>>: Code op
c2) = [Code op] -> CompilerM op s ()
forall op s. [Code op] -> CompilerM op s ()
go (Code op -> [Code op]
forall op. Code op -> [Code op]
linearCode (Code op
c1 Code op -> Code op -> Code op
forall a. Code a -> Code a -> Code a
:>>: Code op
c2))
  where
    go :: [Code op] -> CompilerM op s ()
go (DeclareScalar VName
name Volatility
vol PrimType
t : SetScalar VName
dest Exp
e : [Code op]
code)
      | VName
name VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest = do
        let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
        Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
        BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e';|]
        [Code op] -> CompilerM op s ()
go [Code op]
code
    go (Code op
x : [Code op]
xs) = Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
x CompilerM op s () -> CompilerM op s () -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> [Code op] -> CompilerM op s ()
go [Code op]
xs
    go [] = () -> CompilerM op s ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Assert Exp
e ErrorMsg Exp
msg (SrcLoc
loc, [SrcLoc]
locs)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  [BlockItem]
err <-
    CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ())
-> CompilerM op s [BlockItem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s [BlockItem])
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$
      (CompilerEnv op s -> ErrorCompiler op s)
-> CompilerM op s (ErrorCompiler op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Operations op s -> ErrorCompiler op s
forall op s. Operations op s -> ErrorCompiler op s
opsError (Operations op s -> ErrorCompiler op s)
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> ErrorCompiler op s
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations) CompilerM op s (ErrorCompiler op s)
-> CompilerM op s (ErrorMsg Exp)
-> CompilerM op s (String -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ErrorMsg Exp -> CompilerM op s (ErrorMsg Exp)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ErrorMsg Exp
msg CompilerM op s (String -> CompilerM op s ())
-> CompilerM op s String -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String -> CompilerM op s String
forall (f :: * -> *) a. Applicative f => a -> f a
pure String
stacktrace
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|if (!$exp:e') { $items:err }|]
  where
    stacktrace :: String
stacktrace = Int -> [String] -> String
prettyStacktrace Int
0 ([String] -> String) -> [String] -> String
forall a b. (a -> b) -> a -> b
$ (SrcLoc -> String) -> [SrcLoc] -> [String]
forall a b. (a -> b) -> [a] -> [b]
map SrcLoc -> String
forall a. Located a => a -> String
locStr ([SrcLoc] -> [String]) -> [SrcLoc] -> [String]
forall a b. (a -> b) -> a -> b
$ SrcLoc
loc SrcLoc -> [SrcLoc] -> [SrcLoc]
forall a. a -> [a] -> [a]
: [SrcLoc]
locs
compileCode (Allocate VName
_ Count Bytes (TExp Int64)
_ ScalarSpace {}) =
  -- Handled by the declaration of the memory block, which is
  -- translated to an actual array.
  () -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => a -> m a
return ()
compileCode (Allocate VName
name (Count (TPrimExp Exp
e)) Space
space) = do
  Exp
size <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Maybe VName
cached <- VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  case Maybe VName
cached of
    Just VName
cur_size ->
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
        [C.cstm|if ($exp:cur_size < (size_t)$exp:size) {
                    $exp:name = realloc($exp:name, $exp:size);
                    $exp:cur_size = $exp:size;
                  }|]
    Maybe VName
_ ->
      VName -> Exp -> Space -> Stm -> CompilerM op s ()
forall a b op s.
(ToExp a, ToExp b) =>
a -> b -> Space -> Stm -> CompilerM op s ()
allocMem VName
name Exp
size Space
space [C.cstm|{err = 1; goto cleanup;}|]
compileCode (Free VName
name Space
space) = do
  Bool
cached <- Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Maybe VName -> Bool)
-> CompilerM op s (Maybe VName) -> CompilerM op s Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  Bool -> CompilerM op s () -> CompilerM op s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
cached (CompilerM op s () -> CompilerM op s ())
-> CompilerM op s () -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> CompilerM op s ()
forall a op s. ToExp a => a -> Space -> CompilerM op s ()
unRefMem VName
name Space
space
compileCode (For VName
i Exp
bound Code op
body) = do
  let i' :: SrcLoc -> Id
i' = VName -> SrcLoc -> Id
forall a. ToIdent a => a -> SrcLoc -> Id
C.toIdent VName
i
      t :: Type
t = PrimType -> Type
primTypeToCType (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
bound
  Exp
bound' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
bound
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|for ($ty:t $id:i' = 0; $id:i' < $exp:bound'; $id:i'++) {
            $items:body'
          }|]
compileCode (While TExp Bool
cond Code op
body) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|while ($exp:cond') {
            $items:body'
          }|]
compileCode (If TExp Bool
cond Code op
tbranch Code op
fbranch) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
tbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
tbranch
  [BlockItem]
fbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
blockScope (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
fbranch
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm (Stm -> CompilerM op s ()) -> Stm -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ case ([BlockItem]
tbranch', [BlockItem]
fbranch') of
    ([BlockItem]
_, []) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' }|]
    ([], [BlockItem]
_) ->
      [C.cstm|if (!($exp:cond')) { $items:fbranch' }|]
    ([BlockItem], [BlockItem])
_ ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else { $items:fbranch' }|]
compileCode (Copy VName
dest (Count TExp Int64
destoffset) Space
DefaultSpace VName
src (Count TExp Int64
srcoffset) Space
DefaultSpace (Count TExp Int64
size)) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ()
forall op s. Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ()
copyMemoryDefaultSpace
      (Exp -> Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
      CompilerM op s (Exp -> Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
destoffset)
      CompilerM op s (Exp -> Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
src
      CompilerM op s (Exp -> Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
srcoffset)
      CompilerM op s (Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
size)
compileCode (Copy VName
dest (Count TExp Int64
destoffset) Space
destspace VName
src (Count TExp Int64
srcoffset) Space
srcspace (Count TExp Int64
size)) = do
  Copy op s
copy <- (CompilerEnv op s -> Copy op s) -> CompilerM op s (Copy op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> Copy op s
forall op s. CompilerEnv op s -> Copy op s
envCopy
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    Copy op s
copy
      Copy op s
-> CompilerM op s Exp
-> CompilerM
     op
     s
     (Exp -> Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
      CompilerM
  op
  s
  (Exp -> Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM
     op s (Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
destoffset)
      CompilerM
  op s (Space -> Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Space
-> CompilerM op s (Exp -> Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Space -> CompilerM op s Space
forall (f :: * -> *) a. Applicative f => a -> f a
pure Space
destspace
      CompilerM op s (Exp -> Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Exp -> Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
src
      CompilerM op s (Exp -> Space -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM op s (Space -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
srcoffset)
      CompilerM op s (Space -> Exp -> CompilerM op s ())
-> CompilerM op s Space
-> CompilerM op s (Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Space -> CompilerM op s Space
forall (f :: * -> *) a. Applicative f => a -> f a
pure Space
srcspace
      CompilerM op s (Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
size)
compileCode (Write VName
_ Count Elements (TExp Int64)
_ PrimType
Unit Space
_ Volatility
_ Exp
_) = () -> CompilerM op s ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Write VName
dest (Count TExp Int64
idx) PrimType
elemtype Space
DefaultSpace Volatility
vol Exp
elemexp) = do
  Exp
dest' <- VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
  Exp
deref <-
    Exp -> Exp -> Type -> Exp
derefPointer Exp
dest'
      (Exp -> Type -> Exp)
-> CompilerM op s Exp -> CompilerM op s (Type -> Exp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
idx)
      CompilerM op s (Type -> Exp)
-> CompilerM op s Type -> CompilerM op s Exp
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> CompilerM op s Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType elemtype)*|]
  Exp
elemexp' <- PrimType -> Exp -> Exp
toStorage PrimType
elemtype (Exp -> Exp) -> CompilerM op s Exp -> CompilerM op s Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
elemexp
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$exp:deref = $exp:elemexp';|]
compileCode (Write VName
dest (Count TExp Int64
idx) PrimType
_ ScalarSpace {} Volatility
_ Exp
elemexp) = do
  Exp
idx' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
idx)
  Exp
elemexp' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
elemexp
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$id:dest[$exp:idx'] = $exp:elemexp';|]
compileCode (Write VName
dest (Count TExp Int64
idx) PrimType
elemtype (Space String
space) Volatility
vol Exp
elemexp) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    (CompilerEnv op s -> WriteScalar op s)
-> CompilerM op s (WriteScalar op s)
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> WriteScalar op s
forall op s. CompilerEnv op s -> WriteScalar op s
envWriteScalar
      CompilerM op s (WriteScalar op s)
-> CompilerM op s Exp
-> CompilerM
     op
     s
     (Exp -> Type -> String -> Volatility -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dest
      CompilerM
  op
  s
  (Exp -> Type -> String -> Volatility -> Exp -> CompilerM op s ())
-> CompilerM op s Exp
-> CompilerM
     op s (Type -> String -> Volatility -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall t v. TPrimExp t v -> PrimExp v
untyped TExp Int64
idx)
      CompilerM
  op s (Type -> String -> Volatility -> Exp -> CompilerM op s ())
-> CompilerM op s Type
-> CompilerM
     op s (String -> Volatility -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> CompilerM op s Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (PrimType -> Type
primStorageType PrimType
elemtype)
      CompilerM op s (String -> Volatility -> Exp -> CompilerM op s ())
-> CompilerM op s String
-> CompilerM op s (Volatility -> Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String -> CompilerM op s String
forall (f :: * -> *) a. Applicative f => a -> f a
pure String
space
      CompilerM op s (Volatility -> Exp -> CompilerM op s ())
-> CompilerM op s Volatility
-> CompilerM op s (Exp -> CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Volatility -> CompilerM op s Volatility
forall (f :: * -> *) a. Applicative f => a -> f a
pure Volatility
vol
      CompilerM op s (Exp -> CompilerM op s ())
-> CompilerM op s Exp -> CompilerM op s (CompilerM op s ())
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (PrimType -> Exp -> Exp
toStorage PrimType
elemtype (Exp -> Exp) -> CompilerM op s Exp -> CompilerM op s Exp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
elemexp)
compileCode (DeclareMem VName
name Space
space) =
  VName -> Space -> CompilerM op s ()
forall op s. VName -> Space -> CompilerM op s ()
declMem VName
name Space
space
compileCode (DeclareScalar VName
name Volatility
vol PrimType
t) = do
  let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
  InitGroup -> CompilerM op s ()
forall op s. InitGroup -> CompilerM op s ()
decl [C.cdecl|$tyquals:(volQuals vol) $ty:ct $id:name;|]
compileCode (DeclareArray VName
name ScalarSpace {} PrimType
_ ArrayContents
_) =
  String -> CompilerM op s ()
forall a. HasCallStack => String -> a
error (String -> CompilerM op s ()) -> String -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ String
"Cannot declare array " String -> ShowS
forall a. [a] -> [a] -> [a]
++ VName -> String
forall a. Pretty a => a -> String
pretty VName
name String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
" in scalar space."
compileCode (DeclareArray VName
name Space
DefaultSpace PrimType
t ArrayContents
vs) = do
  VName
name_realtype <- String -> CompilerM op s VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> CompilerM op s VName) -> String -> CompilerM op s VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
name String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_realtype"
  let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
  case ArrayContents
vs of
    ArrayValues [PrimValue]
vs' -> do
      let vs'' :: [Initializer]
vs'' = [[C.cinit|$exp:v|] | PrimValue
v <- [PrimValue]
vs']
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs')] = {$inits:vs''};|]
    ArrayZeros Int
n ->
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
  -- Fake a memory block.
  Id -> Type -> Maybe Exp -> CompilerM op s ()
forall op s. Id -> Type -> Maybe Exp -> CompilerM op s ()
contextField
    (VName -> SrcLoc -> Id
forall a. ToIdent a => a -> SrcLoc -> Id
C.toIdent VName
name SrcLoc
forall a. IsLocation a => a
noLoc)
    [C.cty|struct memblock|]
    (Maybe Exp -> CompilerM op s ()) -> Maybe Exp -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ Exp -> Maybe Exp
forall a. a -> Maybe a
Just [C.cexp|(struct memblock){NULL, (char*)$id:name_realtype, 0}|]
  BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|struct memblock $id:name = ctx->$id:name;|]
compileCode (DeclareArray name (Space space) t vs) = do
  -- Arrays in a named memory space are not declared here; the whole job
  -- is handed to the 'envStaticArray' operation from the environment.
  staticArray <- asks envStaticArray
  staticArray name space t vs
-- When a scalar is updated in the form 'x = x OP e' and OP has a C
-- compound-assignment spelling (see 'assignmentOperator'), emit that
-- form (e.g. 'x += e') instead.  This only makes the generated C nicer
-- to read; it has no effect on performance.  If the guards fail, the
-- next 'SetScalar' equation handles the assignment in the plain way.
compileCode (SetScalar dest (BinOpExp op (LeafExp (ScalarVar x) _) y))
  | x == dest,
    Just assignWith <- assignmentOperator op = do
      rhs <- compileExp y
      stm [C.cstm|$exp:(assignWith dest rhs);|]
-- General scalar assignment: compile the right-hand side and assign it.
compileCode (SetScalar dest src) =
  compileExp src >>= \src' -> stm [C.cstm|$id:dest = $exp:src';|]
-- Memory-block assignment is delegated entirely to 'setMem'.
compileCode (SetMem dest src space) = setMem dest src space
-- Function calls are emitted by the 'opsCall' operation.  Memory
-- arguments are passed by name; expression arguments are compiled first.
compileCode (Call dests fname args) = do
  emitCall <- asks (opsCall . envOperations)
  cargs <- mapM toCArg args
  emitCall dests fname cargs
  where
    toCArg (MemArg mem) = pure [C.cexp|$exp:mem|]
    toCArg (ExpArg e) = compileExp e

-- | Run an action as a nested C block and return the block items it
-- generated.  'unRefMem' calls for any memory blocks declared inside the
-- action are appended at the end (see 'blockScope'').
blockScope :: CompilerM op s () -> CompilerM op s [C.BlockItem]
blockScope action = snd <$> blockScope' action

-- | Like 'blockScope', but also return the action's result.
--
-- Memory blocks that appear in 'compDeclaredMem' after the action but
-- were not there before it are treated as local to this scope. They are
-- dropped from the tracked set again, and 'unRefMem' statements for them
-- are appended after the action's own block items.
blockScope' :: CompilerM op s a -> CompilerM op s (a, [C.BlockItem])
blockScope' action = do
  outer <- gets compDeclaredMem
  (result, body) <- collect' action
  declaredHere <- filter (`notElem` outer) <$> gets compDeclaredMem
  -- Restore the outer scope's view of declared memory.
  modify $ \st -> st {compDeclaredMem = outer}
  cleanup <- collect $ mapM_ (\(mem, space) -> unRefMem mem space) declaredHere
  pure (result, body ++ cleanup)

-- | Compile a function body.
--
-- First, a local variable is declared for every output parameter
-- (memory via 'declMem', scalars via a plain C declaration). Next, the
-- body code is compiled. Finally, each output is written through the
-- matching pointer in @output_ptrs@. Memory outputs first reset the
-- target ('resetMem') and then assign it ('setMem'); scalar outputs are
-- stored directly. Pointers and outputs are paired positionally, so any
-- surplus on either side is ignored.
compileFunBody :: [C.Exp] -> [Param] -> Code op -> CompilerM op s ()
compileFunBody output_ptrs outputs code = do
  mapM_ declareLocal outputs
  compileCode code
  sequence_ $ zipWith writeBack output_ptrs outputs
  where
    declareLocal (MemParam name space) = declMem name space
    declareLocal (ScalarParam name pt) = do
      let scalarTy = primTypeToCType pt
      decl [C.cdecl|$ty:scalarTy $id:name;|]

    writeBack ptr (MemParam name space) = do
      resetMem [C.cexp|*$exp:ptr|] space
      setMem [C.cexp|*$exp:ptr|] name space
    writeBack ptr (ScalarParam name _) =
      stm [C.cstm|*$exp:ptr = $id:name;|]

-- | The C compound-assignment form of a binary operator, if one is used.
--
-- Given the destination variable and the already-compiled right-hand side,
-- the returned function builds @d += e@, @d -= e@ or @d *= e@ for 'Add',
-- 'Sub' and 'Mul' respectively. Every other operator yields 'Nothing'.
--
-- NOTE(review): this emits the raw C operator rather than whatever
-- 'compileExp' would produce for the same 'BinOp'.  Confirm that this is
-- intended with respect to the operator's overflow semantics.
assignmentOperator :: BinOp -> Maybe (VName -> C.Exp -> C.Exp)
assignmentOperator op =
  case op of
    Add {} -> Just $ \var rhs -> [C.cexp|$id:var += $exp:rhs|]
    Sub {} -> Just $ \var rhs -> [C.cexp|$id:var -= $exp:rhs|]
    Mul {} -> Just $ \var rhs -> [C.cexp|$id:var *= $exp:rhs|]
    _ -> Nothing