{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
module Language.Halide.Kernel
( compile
, compileForTarget
, compileToCallable
, compileToLoweredStmt
, StmtOutputFormat (..)
, IsFuncBuilder
, ReturnsFunc
, Lowered
)
where
import Control.Exception (bracket)
import Control.Monad.Primitive (touch)
import Control.Monad.ST (RealWorld)
import Data.IORef
import Data.Kind (Type)
import Data.Primitive.PrimArray (MutablePrimArray)
import Data.Primitive.PrimArray qualified as P
import Data.Primitive.Ptr qualified as P
import Data.Proxy
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import Data.Text.IO qualified as T
import Foreign.C.Types (CUIntPtr (..))
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import Foreign.Ptr (FunPtr, Ptr, castPtr)
import Foreign.Storable
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.RedundantConstraints
import Language.Halide.Target
import Language.Halide.Type
import System.IO.Temp (withSystemTempDirectory)
import Unsafe.Coerce (unsafeCoerce)
importHalide
-- | Scratch memory used to build the argument vector for a compiled kernel.
--
-- The first array is the @argv@ itself (one untyped pointer slot per
-- argument). The second array provides one 'CUIntPtr'-sized cell per
-- argument, used by 'fillSlot' to hold scalar values that @argv@ entries
-- point to.
data ArgvStorage s
  = ArgvStorage
      {-# UNPACK #-} !(MutablePrimArray s (Ptr ()))
      {-# UNPACK #-} !(MutablePrimArray s CUIntPtr)
-- | Allocate an 'ArgvStorage' with @n@ slots in each of its two arrays.
--
-- Both arrays are allocated pinned because raw pointers into them are later
-- obtained with 'P.mutablePrimArrayContents'.
newArgvStorage :: Int -> IO (ArgvStorage RealWorld)
newArgvStorage n =
  ArgvStorage <$> P.newPinnedPrimArray n <*> P.newPinnedPrimArray n
-- | Write all @inputs@ followed by all @outputs@ into the given storage.
--
-- Argument number @i@ (counting inputs first, then outputs) is written via
-- 'fillSlot' using the @i@-th @argv@ slot and the @i@-th scalar cell.
-- No bounds checking is performed here; the storage must have been created
-- with at least as many slots as there are arguments in total.
setArgvStorage
  :: (All ValidArgument inputs, All ValidArgument outputs)
  => ArgvStorage RealWorld
  -> Arguments inputs
  -> Arguments outputs
  -> IO ()
setArgvStorage (ArgvStorage argv scalarStorage) inputs outputs = do
  let argvPtr = P.mutablePrimArrayContents argv
      scalarStoragePtr = P.mutablePrimArrayContents scalarStorage
      -- Fill slots starting at index @i@; returns the first unused index.
      go :: All ValidArgument ts' => Int -> Arguments ts' -> IO Int
      go !i Nil = pure i
      go !i ((x :: t) ::: xs) = do
        fillSlot
          (castPtr $ argvPtr `P.advancePtr` i)
          (castPtr $ scalarStoragePtr `P.advancePtr` i)
          x
        go (i + 1) xs
  i <- go 0 inputs
  _ <- go i outputs
  -- Keep the arrays alive until every write through the raw pointers is done.
  touch argv
  touch scalarStorage
-- | Types whose values can be placed into a kernel's argument vector.
class ValidArgument (t :: Type) where
  -- | @fillSlot argv scalarStorage x@ stores @x@ into the argument slot
  -- pointed to by @argv@. @scalarStorage@ points to a scratch cell that an
  -- instance may use to hold the value itself when the slot must contain a
  -- pointer to it.
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
-- | Scalars: the value is written into the scratch cell, and the argument
-- slot receives the address of that cell.
instance IsHalideType t => ValidArgument t where
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
  fillSlot argv scalarStorage x = do
    poke (castPtr scalarStorage :: Ptr t) x
    poke (castPtr argv :: Ptr (Ptr ())) scalarStorage
  {-# INLINE fillSlot #-}
-- | User context: the context pointer itself is written into the scratch
-- cell, and the argument slot receives the address of that cell (i.e. the
-- slot holds a pointer to the pointer).
instance {-# OVERLAPPING #-} ValidArgument (Ptr CxxUserContext) where
  fillSlot :: Ptr () -> Ptr () -> Ptr CxxUserContext -> IO ()
  fillSlot argv scalarStorage x = do
    poke (castPtr scalarStorage :: Ptr (Ptr CxxUserContext)) x
    poke (castPtr argv :: Ptr (Ptr ())) scalarStorage
  {-# INLINE fillSlot #-}
-- | Buffers: the buffer pointer is stored directly in the argument slot;
-- the scratch cell is not used.
instance {-# OVERLAPPING #-} ValidArgument (Ptr (HalideBuffer n a)) where
  fillSlot :: Ptr () -> Ptr () -> Ptr (HalideBuffer n a) -> IO ()
  fillSlot argv _ x = do
    poke (castPtr argv :: Ptr (Ptr (HalideBuffer n a))) x
  {-# INLINE fillSlot #-}
-- | Symbolic pipeline parameters. The superclass constraint requires that the
-- lowered (runtime) counterpart of @t@ is a 'ValidArgument'.
class ValidArgument (Lowered t) => ValidParameter (t :: Type) where
  -- | Append the @Halide::Argument@ describing this parameter to the vector.
  appendToArgList :: Ptr (CxxVector CxxArgument) -> t -> IO ()

  -- | Create a fresh parameter of type @t@.
  prepareParameter :: IO t
-- | Scalar parameters.
instance IsHalideType a => ValidParameter (Expr a) where
  -- Builds an @InputScalar@ argument from the underlying
  -- @Halide::Internal::Parameter@ (name, type, dimensions, estimates).
  appendToArgList :: Ptr (CxxVector CxxArgument) -> Expr a -> IO ()
  appendToArgList v expr =
    asScalarParam expr $ \p ->
      [CU.exp| void { $(std::vector<Halide::Argument>* v)->emplace_back(
        $(Halide::Internal::Parameter const* p)->name(),
        Halide::Argument::InputScalar,
        $(Halide::Internal::Parameter const* p)->type(),
        $(Halide::Internal::Parameter const* p)->dimensions(),
        $(Halide::Internal::Parameter const* p)->get_argument_estimates()) } |]

  -- A new 'ScalarParam' whose reference initially holds 'Nothing'.
  prepareParameter :: IO (Expr a)
  prepareParameter = ScalarParam <$> newIORef Nothing
-- | Buffer (image) parameters.
instance (KnownNat n, IsHalideType a) => ValidParameter (Func t n a) where
  -- Only the 'Param' constructor can be turned into a @Halide::Argument@;
  -- any other constructor is treated as a programming error.
  appendToArgList :: Ptr (CxxVector CxxArgument) -> Func t n a -> IO ()
  appendToArgList v func@(Param _) =
    withBufferParam func $ \p ->
      [CU.exp| void { $(std::vector<Halide::Argument>* v)->push_back(
        *$(Halide::ImageParam const* p)) } |]
  appendToArgList _ _ =
    error "appendToArgList called on Func; this should never happen"

  -- 'Param' yields a @Func 'ParamTy n a@, so 'unsafeCoerce' is used to
  -- relabel the type index to the instance's @t@. The runtime value is
  -- always a 'Param'.
  -- NOTE(review): this is only sound if @t@ is instantiated to @'ParamTy@ at
  -- every use site — confirm against callers.
  prepareParameter :: IO (Func t n a)
  prepareParameter = unsafeCoerce $ Param <$> newIORef Nothing
-- | Type-level lists for which a fresh parameter can be created for every
-- element.
class PrepareParameters ts where
  -- | Create one fresh parameter per element of @ts@.
  prepareParameters :: IO (Arguments ts)
-- | Base case: the empty list has no parameters to create.
instance PrepareParameters '[] where
  prepareParameters :: IO (Arguments '[])
  prepareParameters = pure Nil
-- | Inductive case: create the head parameter, then the remaining ones.
instance (ValidParameter t, PrepareParameters ts) => PrepareParameters (t ': ts) where
  prepareParameters :: IO (Arguments (t : ts))
  prepareParameters = do
    t <- prepareParameter @t
    ts <- prepareParameters @ts
    pure $ t ::: ts
-- | Run an action with a temporary @std::vector<Halide::Argument>@ that
-- describes the given parameters.
--
-- The vector is heap-allocated with capacity reserved for @Length ts@
-- entries, filled in list order via 'appendToArgList', passed to @action@,
-- and deleted afterwards. 'bracket' ensures deletion also happens when an
-- exception is thrown.
prepareCxxArguments
  :: forall ts b
   . (All ValidParameter ts, KnownNat (Length ts))
  => Arguments ts
  -> (Ptr (CxxVector CxxArgument) -> IO b)
  -> IO b
prepareCxxArguments args action = do
  let count = fromIntegral (natVal (Proxy @(Length ts)))
      allocate =
        [CU.block| std::vector<Halide::Argument>* {
          auto p = new std::vector<Halide::Argument>{};
          p->reserve($(size_t count));
          return p;
        } |]
      destroy p = [CU.exp| void { delete $(std::vector<Halide::Argument>* p) } |]
  bracket allocate destroy $ \v -> do
    let go :: All ValidParameter ts' => Arguments ts' -> IO ()
        go Nil = pure ()
        go (x ::: xs) = appendToArgList v x >> go xs
    go args
    action v
-- | Pointer to a C++ function that @delete@s a @Halide::JITUserContext@;
-- usable as a 'ForeignPtr' finalizer.
deleteCxxUserContext :: FunPtr (Ptr CxxUserContext -> IO ())
deleteCxxUserContext = [C.funPtr| void deleteUserContext(Halide::JITUserContext* p) { delete p; } |]
-- | Take ownership of a raw user-context pointer: the resulting 'ForeignPtr'
-- has 'deleteCxxUserContext' attached as its finalizer.
wrapCxxUserContext :: Ptr CxxUserContext -> IO (ForeignPtr CxxUserContext)
wrapCxxUserContext = newForeignPtr deleteCxxUserContext
-- | Allocate a default-constructed @Halide::JITUserContext@ and wrap it with
-- 'wrapCxxUserContext' so that it is deleted when no longer referenced.
newEmptyCxxUserContext :: IO (ForeignPtr CxxUserContext)
newEmptyCxxUserContext =
  wrapCxxUserContext =<< [CU.exp| Halide::JITUserContext* { new Halide::JITUserContext{} } |]
-- | Take ownership of a raw @Halide::Callable@ pointer. The returned
-- 'Callable' holds a 'ForeignPtr' whose finalizer @delete@s the C++ object.
wrapCxxCallable :: Ptr CxxCallable -> IO (Callable inputs outputs)
wrapCxxCallable = fmap Callable . newForeignPtr deleter
  where
    deleter = [C.funPtr| void deleteCallable(Halide::Callable* p) { delete p; } |]
-- | Maps symbolic parameter types to the runtime argument types of the
-- compiled kernel: @'Expr' a@ becomes @a@, and @'Func' t n a@ becomes
-- @'Ptr' ('HalideBuffer' n a)@. The family is poly-kinded so that it can
-- also be applied element-wise to type-level lists.
type Lowered :: forall k. k -> k
type family Lowered (t :: k) :: k where
  Lowered (Expr a) = a
  Lowered (Func t n a) = Ptr (HalideBuffer n a)
  Lowered '[] = '[]
  Lowered (Expr a ': ts) = (a ': Lowered ts)
  Lowered (Func t n a ': ts) = (Ptr (HalideBuffer n a) ': Lowered ts)
-- | @ReturnsFunc f t n a@ holds when the return type of @f@ is
-- @IO (Func t n a)@, with @a@ a Halide type and @n@ a known natural.
-- The functional dependency lets @t@, @n@ and @a@ be determined from @f@.
class (FunctionReturn f ~ IO (Func t n a), IsHalideType a, KnownNat n) => ReturnsFunc f t n a | f -> t n a
-- | The single catch-all instance: the class is just a named bundle of its
-- superclass constraints.
instance (FunctionReturn f ~ IO (Func t n a), IsHalideType a, KnownNat n) => ReturnsFunc f t n a
-- | Constraint bundle satisfied by functions @f@ that can be compiled into
-- a kernel: every argument of @f@ is a 'ValidParameter', every lowered
-- argument is a 'ValidArgument', @f@ can be uncurried and its parameters
-- created automatically, it returns @IO (Func t n a)@, and the lengths of
-- both argument lists are known.
type IsFuncBuilder f t n a =
  ( All ValidParameter (FunctionArguments f)
  , All ValidArgument (Lowered (FunctionArguments f))
  , UnCurry f (FunctionArguments f) (FunctionReturn f)
  , PrepareParameters (FunctionArguments f)
  , ReturnsFunc f t n a
  , KnownNat (Length (FunctionArguments f))
  , KnownNat (Length (Lowered (FunctionArguments f)))
  )
-- | Create fresh parameters for every argument of @builder@, apply the
-- builder to them, and return both the parameters and the resulting 'Func'.
buildFunc :: (IsFuncBuilder f t n a) => f -> IO (Arguments (FunctionArguments f), Func t n a)
buildFunc builder = do
  parameters <- prepareParameters
  func <- uncurryG builder parameters
  pure (parameters, func)
-- | An owned handle to a C++ @Halide::Callable@. @inputs@ and @output@ are
-- phantom type parameters recording the argument types the callable expects.
newtype Callable (inputs :: [Type]) (output :: Type) = Callable (ForeignPtr CxxCallable)
-- | Build the pipeline described by @builder@ and compile it for @target@
-- via @Halide::Func::compile_to_callable@.
--
-- Errors reported through @handle_halide_exceptions@ are rethrown on the
-- Haskell side by inline-c's @throwBlock@.
compileToCallable
  :: forall n a t f inputs output
   . ( IsFuncBuilder f t n a
     , Lowered (FunctionArguments f) ~ inputs
     , Ptr (HalideBuffer n a) ~ output
     )
  => Target
  -> f
  -> IO (Callable inputs output)
compileToCallable target builder = do
  (args, func) <- buildFunc builder
  prepareCxxArguments args $ \args' ->
    withFunc func $ \func' ->
      withCxxTarget target $ \target' ->
        wrapCxxCallable
          =<< [C.throwBlock| Halide::Callable* {
                return handle_halide_exceptions([=]() {
                  return new Halide::Callable{
                    $(Halide::Func* func')->compile_to_callable(
                      *$(const std::vector<Halide::Argument>* args'),
                      *$(const Halide::Target* target'))};
                });
              } |]
  where
    -- The @output@ equality is not needed by the body; mark it as
    -- intentionally kept so it is not reported as redundant.
    _ = keepRedundantConstraint (Proxy @(Ptr (HalideBuffer n a) ~ output))
-- | Turn a 'Callable' into a curried Haskell function taking the inputs
-- followed by the output and returning @IO ()@.
--
-- A user context and the argument storage are allocated once, here, and are
-- reused by every invocation of the returned function.
-- NOTE(review): because the storage is shared, concurrent invocations of the
-- same returned function would write to the same argv — confirm whether
-- callers may do that.
callableToFunction
  :: forall inputs output kernel
   . ( Curry inputs (output -> IO ()) kernel
     , KnownNat (Length inputs)
     , All ValidArgument inputs
     , ValidArgument output
     )
  => Callable inputs output
  -> IO kernel
callableToFunction (Callable callable) = do
  context <- newEmptyCxxUserContext
  -- Two extra slots: one for the user context that is prepended to the
  -- inputs and one for the output.
  let argc = 2 + fromIntegral (natVal (Proxy @(Length inputs)))
  storage@(ArgvStorage argv scalarStorage) <- newArgvStorage (fromIntegral argc)
  let argvPtr = P.mutablePrimArrayContents argv
      contextPtr = unsafeForeignPtrToPtr context
      callablePtr = unsafeForeignPtrToPtr callable
      kernel args out = do
        setArgvStorage storage (contextPtr ::: args) (out ::: Nil)
        -- NOTE(review): this is an unsafe-FFI @exp@ quote rather than a
        -- @throwBlock@; presumably chosen to keep call overhead low —
        -- confirm how @handle_halide_exceptions@ reports failures here.
        [CU.exp| void {
          handle_halide_exceptions([=]() {
            return $(Halide::Callable* callablePtr)->call_argv_fast(
              $(int argc), $(const void* const* argvPtr));
          })
        } |]
        -- The raw pointers above do not keep their owners alive; touch the
        -- owners after the call so they outlive it.
        touch argv
        touch scalarStorage
        touch context
        touch callable
  pure $ curryG @inputs @(output -> IO ()) kernel
-- | Compile a pipeline builder into a Haskell function, using 'hostTarget'
-- as the compilation target. See 'compileForTarget'.
compile
  :: forall n a t f kernel
   . ( IsFuncBuilder f t n a
     , Curry (Lowered (FunctionArguments f)) (Ptr (HalideBuffer n a) -> IO ()) kernel
     )
  => f
  -> IO kernel
compile = compileForTarget hostTarget
-- | Compile a pipeline builder for the given 'Target' and expose the result
-- as a curried Haskell function: first the lowered inputs, then the output
-- buffer pointer, returning @IO ()@.
compileForTarget
  :: forall n a t f kernel
   . ( IsFuncBuilder f t n a
     , Curry (Lowered (FunctionArguments f)) (Ptr (HalideBuffer n a) -> IO ()) kernel
     )
  => Target
  -> f
  -> IO kernel
compileForTarget target builder =
  compileToCallable target builder >>= callableToFunction
-- | Output format selector for 'compileToLoweredStmt'; the constructors
-- correspond to @Halide::StmtOutputFormat::Text@ and
-- @Halide::StmtOutputFormat::HTML@ respectively.
data StmtOutputFormat
  = -- | Plain-text output.
    StmtText
  | -- | HTML output.
    StmtHTML
  deriving stock (Show, Eq)
-- | Conversion to and from the integer values of the C++
-- @Halide::StmtOutputFormat@ enumeration. The numeric values are taken from
-- the C++ side rather than hard-coded.
instance Enum StmtOutputFormat where
  fromEnum =
    fromIntegral . \case
      StmtText -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |]
      StmtHTML -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |]

  -- Calls 'error' for any integer that matches neither C++ enumerator.
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |] = StmtText
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |] = StmtHTML
    | otherwise = error $ "invalid StmtOutputFormat " <> show k
-- | Build the pipeline described by @builder@ and return its lowered
-- statement, rendered in the requested format, as 'Text'.
--
-- Halide writes the statement to a file, so this compiles into
-- @code.stmt@ inside a temporary directory, reads the file back, and lets
-- 'withSystemTempDirectory' remove the directory afterwards.
compileToLoweredStmt
  :: forall n a t f. (IsFuncBuilder f t n a) => StmtOutputFormat -> Target -> f -> IO Text
compileToLoweredStmt format target builder = do
  withSystemTempDirectory "halide-haskell" $ \dir -> do
    -- Single definition of the output path, used for both writing and reading.
    let path = dir <> "/code.stmt"
        s = encodeUtf8 (pack path)
        o = fromIntegral (fromEnum format)
    (parameters, func) <- buildFunc builder
    prepareCxxArguments parameters $ \v ->
      withFunc func $ \f ->
        withCxxTarget target $ \t ->
          [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              $(Halide::Func* f)->compile_to_lowered_stmt(
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)},
                *$(const std::vector<Halide::Argument>* v),
                static_cast<Halide::StmtOutputFormat>($(int o)),
                *$(Halide::Target* t));
            });
          } |]
    T.readFile path