{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.Halide.Kernel
( compile
, compileForTarget
, compileToCallable
, compileToLoweredStmt
, StmtOutputFormat (..)
, LoweredSignature
)
where
import Control.Exception (bracket)
import Control.Monad.Primitive (touch)
import Control.Monad.ST (RealWorld)
import Data.IORef
import Data.Kind (Type)
import Data.Primitive.PrimArray (MutablePrimArray)
import Data.Primitive.PrimArray qualified as P
import Data.Primitive.Ptr qualified as P
import Data.Proxy
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import Data.Text.IO qualified as T
import Foreign.C.Types (CUIntPtr (..))
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import Foreign.Marshal.Array (allocaArray, peekArray)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.RedundantConstraints
import Language.Halide.Target
import Language.Halide.Type
import System.IO.Temp (withSystemTempDirectory)
-- | Opaque Haskell-side tag for the C++ type @Halide::Argument@.
--
-- It has no constructors; it only ever appears behind a 'Ptr', e.g. as the
-- element type of a @std::vector<Halide::Argument>@.
data CxxArgument

-- NOTE(review): presumably sets up the inline-C++ context (includes and type
-- mappings) used by the quasi-quoters in this module — confirm in
-- "Language.Halide.Context".
$(importHalide)

-- Provides the vector support that 'newCxxVector' / 'deleteCxxVector' need for
-- @std::vector<Halide::Argument>@ (used with 'CxxArgument' below).
$(instanceHasCxxVector "Halide::Argument")
-- | Reusable storage for the @argv@ array handed to
-- @Halide::Callable::call_argv_fast@.
data ArgvStorage s
  = ArgvStorage
      {-# UNPACK #-} !(MutablePrimArray s (Ptr ()))
      -- ^ One pointer per argument slot; its contents are passed as @argv@.
      {-# UNPACK #-} !(MutablePrimArray s CUIntPtr)
      -- ^ Word-sized backing cells for by-value arguments: pointer slot @i@
      -- may point at cell @i@ (buffer arguments leave their cell unused).
-- | Allocate 'ArgvStorage' with room for @n@ argument slots.
--
-- Both arrays are pinned, so the raw addresses obtained with
-- 'P.mutablePrimArrayContents' remain stable while C++ reads them.
newArgvStorage :: Int -> IO (ArgvStorage RealWorld)
newArgvStorage n = do
  pointerSlots <- P.newPinnedPrimArray n
  scalarCells <- P.newPinnedPrimArray n
  pure (ArgvStorage pointerSlots scalarCells)
-- | Fill 'ArgvStorage' for a single kernel call.
--
-- Slot 0 receives the JIT user context; slots @1, 2, ...@ receive the kernel
-- arguments in order. Each value is written by its 'fillSlot' instance, which
-- decides what the pointer slot refers to.
setArgvStorage
  :: All ValidArgument args
  => ArgvStorage RealWorld
  -> Ptr CxxUserContext
  -> Arguments args
  -> IO ()
setArgvStorage (ArgvStorage argv scalarStorage) context inputs = do
  writeSlots 0 (context ::: inputs)
  -- Keep both arrays reachable until every raw-pointer write has happened.
  touch argv
  touch scalarStorage
  where
    argvBase = P.mutablePrimArrayContents argv
    scalarBase = P.mutablePrimArrayContents scalarStorage

    -- Walk the heterogeneous argument list, writing element @k@ into slot @k@.
    writeSlots :: All ValidArgument args' => Int -> Arguments args' -> IO ()
    writeSlots !_ Nil = pure ()
    writeSlots !slot (headArg ::: tailArgs) = do
      let pointerSlot = castPtr (argvBase `P.advancePtr` slot)
          scalarSlot = castPtr (scalarBase `P.advancePtr` slot)
      fillSlot pointerSlot scalarSlot headArg
      writeSlots (slot + 1) tailArgs
-- | Types that can occupy one slot of the @argv@ array passed to a compiled
-- Halide kernel.
class ValidArgument (t :: Type) where
  -- | @fillSlot argvSlot scalarSlot x@ stores @x@ for the upcoming call.
  --
  -- @argvSlot@ points at this argument's entry in the @argv@ array, and
  -- @scalarSlot@ points at a word-sized cell reserved for this argument that
  -- an instance may use to hold a by-value scalar.
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
-- | Scalars are passed by pointer: the value is written into its scalar cell,
-- and the @argv@ entry is set to point at that cell.
instance IsHalideType t => ValidArgument t where
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
  fillSlot argvSlot scalarSlot x = do
    let valueCell = castPtr scalarSlot :: Ptr t
        entry = castPtr argvSlot :: Ptr (Ptr ())
    poke valueCell x
    poke entry scalarSlot
  {-# INLINE fillSlot #-}
-- | The JIT user context is passed like a scalar: the context pointer is
-- stored in the scalar cell, and the @argv@ entry points at that cell.
instance {-# OVERLAPPING #-} ValidArgument (Ptr CxxUserContext) where
  fillSlot :: Ptr () -> Ptr () -> Ptr CxxUserContext -> IO ()
  fillSlot argvSlot scalarSlot ctx = do
    let contextCell = castPtr scalarSlot :: Ptr (Ptr CxxUserContext)
        entry = castPtr argvSlot :: Ptr (Ptr ())
    poke contextCell ctx
    poke entry scalarSlot
  {-# INLINE fillSlot #-}
-- | Buffers are passed directly: the @argv@ entry holds the buffer pointer
-- itself, and the scalar cell is left untouched.
instance {-# OVERLAPPING #-} ValidArgument (Ptr (HalideBuffer n a)) where
  fillSlot :: Ptr () -> Ptr () -> Ptr (HalideBuffer n a) -> IO ()
  fillSlot argvSlot _unusedScalarSlot buffer =
    poke (castPtr argvSlot :: Ptr (Ptr (HalideBuffer n a))) buffer
  {-# INLINE fillSlot #-}
-- | Haskell functions of 'Func' / 'Expr' parameters that build a Halide
-- pipeline producing a @'Func' ''FuncTy' n a@.
--
-- 'buildFunc' creates a fresh parameter for every argument of @f@, runs the
-- resulting action to obtain the output 'Func', and appends one
-- @Halide::Argument@ per parameter to the given vector. Each argument is
-- appended after the rest have been built, so the vector ends up in reverse
-- argument order (last argument first).
class KnownNat n => FuncBuilder f (n :: Nat) (a :: Type) | f -> n a where
  buildFunc :: Ptr (CxxVector CxxArgument) -> f -> IO (Func 'FuncTy n a)
-- | A 'Func' argument is fed a fresh buffer parameter ('Param'), which is
-- recorded as a Halide @ImageParam@ argument.
instance
  (k ~ 'ParamTy, KnownNat m, IsHalideType t, FuncBuilder r n a)
  => FuncBuilder (Func k m (Expr t) -> r) n a
  where
  buildFunc :: Ptr (CxxVector CxxArgument) -> (Func k m (Expr t) -> r) -> IO (Func 'FuncTy n a)
  buildFunc argsVec f = do
    bufferParam <- fmap Param (newIORef Nothing)
    -- The remaining arguments are built (and pushed) before this one.
    output <- buildFunc argsVec (f bufferParam)
    withBufferParam bufferParam $ \imageParam ->
      [CU.exp| void { $(std::vector<Halide::Argument>* argsVec)->push_back(*$(Halide::ImageParam const* imageParam)) } |]
    pure output
-- | An 'Expr' argument is fed a fresh scalar parameter ('ScalarParam'), which
-- is recorded as a @Halide::Argument::InputScalar@ argument.
instance (IsHalideType t, FuncBuilder r n a) => FuncBuilder (Expr t -> r) n a where
  buildFunc :: Ptr (CxxVector CxxArgument) -> (Expr t -> r) -> IO (Func 'FuncTy n a)
  buildFunc argsVec f = do
    scalar <- fmap ScalarParam (newIORef Nothing)
    -- The remaining arguments are built (and pushed) before this one.
    output <- buildFunc argsVec (f scalar)
    asScalarParam scalar $ \parameter ->
      [CU.block| void {
        auto const& p = *$(Halide::Internal::Parameter const* parameter);
        $(std::vector<Halide::Argument>* argsVec)->emplace_back(
          p.name(),
          Halide::Argument::InputScalar,
          p.type(),
          p.dimensions(),
          p.get_argument_estimates());
      } |]
    pure output
-- | Base case: once every argument has been supplied, the remaining action
-- already produces the output 'Func'; it is handed back unchanged and no
-- Halide argument is recorded.
instance (KnownNat n, t ~ 'FuncTy, n' ~ n, a' ~ a) => FuncBuilder (IO (Func t n' a')) n a where
  buildFunc :: Ptr (CxxVector CxxArgument) -> IO (Func t n' a') -> IO (Func 'FuncTy n a)
  buildFunc _ = id
-- | The low-level calling signature of a kernel compiled from builder @f@.
--
-- 'Expr' inputs become plain values and 'Func' inputs become buffer pointers.
-- The output 'Func' (with one or two outputs) contributes one trailing buffer
-- pointer per output, and the whole thing returns @IO ()@.
type family LoweredSignature f where
  LoweredSignature (Expr a -> r) =
    a -> LoweredSignature r
  LoweredSignature (Func t n (Expr a) -> r) =
    Ptr (HalideBuffer n a) -> LoweredSignature r
  LoweredSignature (IO (Func t n (Expr a))) =
    Ptr (HalideBuffer n a) -> IO ()
  LoweredSignature (IO (Func t n (Expr a1, Expr a2))) =
    Ptr (HalideBuffer n a1) -> Ptr (HalideBuffer n a2) -> IO ()
-- | Constraints needed to turn a 'Callable' with signature @f@ back into a
-- curried Haskell function.
--
-- The argument count must be known statically, every argument must be a
-- 'ValidArgument', and @f@ must be the curried form of its argument list
-- returning @IO ()@.
type IsHalideKernel f =
  ( KnownNat (Length (FunctionArguments f))
  , All ValidArgument (FunctionArguments f)
  , Curry (FunctionArguments f) (IO ()) f
  )
-- | A JIT-compiled Halide pipeline.
--
-- The wrapped 'ForeignPtr' owns a heap-allocated @Halide::Callable@. The
-- phantom @signature@ records the low-level calling convention, typically a
-- 'LoweredSignature'.
newtype Callable (signature :: Type)
  = Callable (ForeignPtr CxxCallable)
-- | Build the pipeline described by @builder@ and JIT-compile it for @target@.
--
-- 'buildFunc' collects the Halide arguments in reverse order, so the vector is
-- reversed on the C++ side before @Func::compile_to_callable@ is invoked.
-- C++ exceptions are converted into Haskell exceptions by 'C.throwBlock'.
compileToCallable
  :: forall n a f
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => Target
  -> f
  -> IO (Callable (LoweredSignature f))
compileToCallable target builder =
  bracket (newCxxVector Nothing) deleteCxxVector compileWith
  where
    compileWith argsVec = do
      func <- buildFunc @f @n @a argsVec builder
      withCxxFunc func $ \cxxFunc ->
        withCxxTarget target $ \cxxTarget -> do
          rawCallable <-
            [C.throwBlock| Halide::Callable* {
              return handle_halide_exceptions([=]() {
                auto& func = *$(Halide::Func* cxxFunc);
                auto& args = *$(std::vector<Halide::Argument>* argsVec);
                auto const& target = *$(const Halide::Target* cxxTarget);
                std::reverse(std::begin(args), std::end(args));
                return new Halide::Callable{func.compile_to_callable(args, target)};
              });
            } |]
          wrapCxxCallable rawCallable
    -- The kernel constraint is only needed by callers; keep GHC from
    -- flagging it as redundant.
    _ = keepRedundantConstraint @(IsHalideKernel (LoweredSignature f))
-- | Turn a compiled 'Callable' into a curried Haskell function.
--
-- A single JIT user context and a single 'ArgvStorage' are allocated up front
-- and reused by every invocation of the returned function. The storage has one
-- slot for the context plus one per kernel argument.
--
-- Each invocation throws an 'IOError' if the pipeline reports a non-zero exit
-- status.
--
-- NOTE(review): because the storage is shared, concurrent invocations of the
-- returned function from several threads race on it — confirm callers never do
-- this, or allocate per call.
-- NOTE(review): the call uses 'CU.exp' (unsafe, no C++ exception translation),
-- unlike the 'C.throwBlock' in 'compileToCallable'. A C++ exception escaping
-- @handle_halide_exceptions@ here would not be caught — confirm that is
-- impossible.
callableToFunction :: forall f. IsHalideKernel f => Callable f -> IO f
callableToFunction (Callable callable) = do
  context <- newEmptyCxxUserContext
  -- One slot for the JIT user context plus one per kernel argument.
  let argc = 1 + fromIntegral (natVal (Proxy @(Length (FunctionArguments f))))
  storage@(ArgvStorage argv scalarStorage) <- newArgvStorage (fromIntegral argc)
  let argvPtr = P.mutablePrimArrayContents argv
      contextPtr = unsafeForeignPtrToPtr context
      callablePtr = unsafeForeignPtrToPtr callable
      kernel :: Arguments (FunctionArguments f) -> IO ()
      kernel args = do
        setArgvStorage storage contextPtr args
        status <-
          [CU.exp| int {
            handle_halide_exceptions([=]() {
              return $(Halide::Callable* callablePtr)->call_argv_fast(
                $(int argc), $(const void* const* argvPtr));
            })
          } |]
        -- The raw pointers used above are only valid while their owners are
        -- alive, so keep them reachable until the call has returned.
        touch argv
        touch scalarStorage
        touch context
        touch callable
        -- call_argv_fast returns the pipeline's exit code (0 on success).
        -- Raise an exception on failure instead of silently returning.
        if status /= 0
          then ioError (userError ("Halide pipeline failed with exit status " <> show status))
          else pure ()
  pure $ curryG @(FunctionArguments f) @(IO ()) kernel
-- | Take ownership of a @new@-allocated @Halide::JITUserContext@.
--
-- The returned 'ForeignPtr' @delete@s the context when it is finalized.
wrapCxxUserContext :: Ptr CxxUserContext -> IO (ForeignPtr CxxUserContext)
wrapCxxUserContext rawContext = newForeignPtr finalizer rawContext
  where
    finalizer :: FinalizerPtr CxxUserContext
    finalizer = [C.funPtr| void deleteUserContext(Halide::JITUserContext* p) { delete p; } |]
-- | Allocate a default-constructed @Halide::JITUserContext@ whose lifetime is
-- managed by a 'ForeignPtr'.
newEmptyCxxUserContext :: IO (ForeignPtr CxxUserContext)
newEmptyCxxUserContext = do
  rawContext <- [CU.exp| Halide::JITUserContext* { new Halide::JITUserContext{} } |]
  wrapCxxUserContext rawContext
-- | Take ownership of a @new@-allocated @Halide::Callable@.
--
-- The result is wrapped as a 'Callable' whose 'ForeignPtr' @delete@s it on
-- finalization.
wrapCxxCallable :: Ptr CxxCallable -> IO (Callable signature)
wrapCxxCallable rawCallable = Callable <$> newForeignPtr finalizer rawCallable
  where
    finalizer :: FinalizerPtr CxxCallable
    finalizer = [C.funPtr| void deleteCallable(Halide::Callable* p) { delete p; } |]
-- | Compile a pipeline for the host machine ('hostTarget') and return it as a
-- directly callable Haskell function.
compile
  :: forall f n a
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => f
  -> IO (LoweredSignature f)
compile builder = compileForTarget hostTarget builder
-- | Compile a pipeline for the given 'Target' and return it as a directly
-- callable Haskell function with the 'LoweredSignature' of @f@.
compileForTarget
  :: forall f n a
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => Target
  -> f
  -> IO (LoweredSignature f)
compileForTarget target builder = do
  callable <- compileToCallable target builder
  callableToFunction callable
-- | Output format accepted by 'compileToLoweredStmt'.
data StmtOutputFormat
  = -- | Corresponds to @Halide::StmtOutputFormat::Text@.
    StmtText
  | -- | Corresponds to @Halide::StmtOutputFormat::HTML@.
    StmtHTML
  deriving stock (Show, Eq)
-- | Conversion to and from the integer values of the C++ enum
-- @Halide::StmtOutputFormat@. These are the values 'compileToLoweredStmt'
-- passes to Halide.
instance Enum StmtOutputFormat where
  fromEnum :: StmtOutputFormat -> Int
  fromEnum format = fromIntegral $ case format of
    StmtText -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |]
    StmtHTML -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |]

  toEnum :: Int -> StmtOutputFormat
  toEnum k
    -- Compare in 'Int' (widening the C constants via 'fromEnum') rather than
    -- narrowing @k@ to a C @int@. Narrowing would wrap out-of-range values such
    -- as @2^32@ onto a valid tag instead of reporting them.
    | k == fromEnum StmtText = StmtText
    | k == fromEnum StmtHTML = StmtHTML
    | otherwise = error $ "invalid StmtOutputFormat " <> show k
-- | Lower the pipeline built by @builder@ for @target@ and return Halide's
-- lowered statement in the requested format.
--
-- Halide writes the statement to @code.stmt@ inside a fresh temporary
-- directory. That file is read back as 'Text' before the directory is removed.
compileToLoweredStmt
  :: forall n a f. (FuncBuilder f n a) => StmtOutputFormat -> Target -> f -> IO Text
compileToLoweredStmt format target builder =
  withSystemTempDirectory "halide-haskell" $ \dir -> do
    let outputPath = dir <> "/code.stmt"
        outputPathBytes = encodeUtf8 (pack outputPath)
        formatCode = fromIntegral (fromEnum format)
    bracket (newCxxVector Nothing) deleteCxxVector $ \argsVec -> do
      func <- buildFunc @f @n @a argsVec builder
      withCxxFunc func $ \cxxFunc ->
        withCxxTarget target $ \cxxTarget ->
          [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              auto& func = *$(Halide::Func* cxxFunc);
              auto& args = *$(std::vector<Halide::Argument>* argsVec);
              auto const& target = *$(const Halide::Target* cxxTarget);
              std::reverse(std::begin(args), std::end(args));
              func.compile_to_lowered_stmt(
                std::string{$bs-ptr:outputPathBytes, static_cast<size_t>($bs-len:outputPathBytes)},
                args,
                static_cast<Halide::StmtOutputFormat>($(int formatCode)),
                target);
            });
          } |]
    T.readFile outputPath