{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Language.Halide.Kernel
-- Description : Compiling functions to kernels
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Kernel
  ( compile
  , compileForTarget
  , compileToCallable
  , compileToLoweredStmt
  , StmtOutputFormat (..)
  , IsFuncBuilder
  , ReturnsFunc
  , Lowered
  )
where

import Control.Exception (bracket)
import Control.Monad.Primitive (touch)
import Control.Monad.ST (RealWorld)
import Data.IORef
import Data.Kind (Type)
import Data.Primitive.PrimArray (MutablePrimArray)
import Data.Primitive.PrimArray qualified as P
import Data.Primitive.Ptr qualified as P
import Data.Proxy
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import Data.Text.IO qualified as T
import Foreign.C.Types (CUIntPtr (..))
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import Foreign.Ptr (FunPtr, Ptr, castPtr)
import Foreign.Storable
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.RedundantConstraints
import Language.Halide.Target
import Language.Halide.Type
import System.IO qualified as IO
import System.IO.Temp (withSystemTempDirectory)
import Unsafe.Coerce (unsafeCoerce)

-- Top-level Template Haskell declaration splice.
-- NOTE(review): presumably this sets up the inline-c C++ context (Halide
-- headers and the Haskell <-> C++ type mapping used by the quasi-quotes in this
-- module); confirm where importHalide is defined.
importHalide

-- | Scratch memory used to pass arguments to a compiled Halide callable.
--
-- The first array holds one @void*@ per argument slot; its contents are the
-- @argv@ array handed to C++. The second array holds one 'CUIntPtr'-sized cell
-- per slot. Scalar (and user-context) arguments are written into these cells so
-- that the matching @argv@ entry can point at them. 'newArgvStorage' allocates
-- both arrays pinned.
data ArgvStorage s
  = ArgvStorage
      {-# UNPACK #-} !(MutablePrimArray s (Ptr ()))
      {-# UNPACK #-} !(MutablePrimArray s CUIntPtr)

-- | Allocate argument storage with room for @slotCount@ slots.
--
-- Both arrays are pinned. Their contents are accessed through raw addresses
-- (@P.mutablePrimArrayContents@) that end up in C++, so the GC must never move
-- them.
newArgvStorage :: Int -> IO (ArgvStorage RealWorld)
newArgvStorage slotCount = do
  pointerSlots <- P.newPinnedPrimArray slotCount
  scalarSlots <- P.newPinnedPrimArray slotCount
  pure (ArgvStorage pointerSlots scalarSlots)

-- | Write kernel arguments into an 'ArgvStorage'.
--
-- Slots are numbered consecutively. The @inputs@ occupy slots @0 .. k-1@ and
-- the @outputs@ continue from slot @k@. For slot @i@, 'fillSlot' receives the
-- address of @argv[i]@ and the address of the @i@-th scalar cell.
setArgvStorage
  :: (All ValidArgument inputs, All ValidArgument outputs)
  => ArgvStorage RealWorld
  -> Arguments inputs
  -> Arguments outputs
  -> IO ()
setArgvStorage (ArgvStorage argv scalarStorage) inputs outputs = do
  firstOutputSlot <- fillFrom 0 inputs
  _ <- fillFrom firstOutputSlot outputs
  -- The slots were written through raw addresses into the arrays; keep the
  -- arrays alive until all writes are done.
  touch argv
  touch scalarStorage
  where
    argvBase = P.mutablePrimArrayContents argv
    scalarBase = P.mutablePrimArrayContents scalarStorage

    -- Fill consecutive slots starting at @slot@; returns the next free slot.
    fillFrom :: All ValidArgument xs => Int -> Arguments xs -> IO Int
    fillFrom !slot = \case
      Nil -> pure slot
      x ::: rest -> do
        fillSlot
          (castPtr (P.advancePtr argvBase slot))
          (castPtr (P.advancePtr scalarBase slot))
          x
        fillFrom (slot + 1) rest

-- | Specifies that the type can be used as an argument to a kernel.
class ValidArgument (t :: Type) where
  -- | @fillSlot argvSlot scratchCell x@ stores @x@ for a kernel call.
  -- @argvSlot@ points at this argument's entry in the @argv@ array passed to
  -- C++. @scratchCell@ points at a pointer-sized scratch cell that the
  -- instance may use to hold the value itself.
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()

-- | Scalars: the value is written into its scratch cell, and the @argv@ slot
-- is pointed at that cell.
--
-- NOTE(review): this writes @sizeOf value@ bytes into a 'CUIntPtr'-sized cell,
-- so it assumes every 'IsHalideType' fits in a pointer-sized word — confirm.
instance IsHalideType t => ValidArgument t where
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
  fillSlot slot cell value = do
    let valueCell = castPtr cell :: Ptr t
        slotAsPtr = castPtr slot :: Ptr (Ptr ())
    poke valueCell value
    poke slotAsPtr cell
  {-# INLINE fillSlot #-}

-- | The JIT user context is passed indirectly, like a scalar. The context
-- pointer is stored in the scratch cell, and the @argv@ slot points at that
-- cell.
instance {-# OVERLAPPING #-} ValidArgument (Ptr CxxUserContext) where
  fillSlot :: Ptr () -> Ptr () -> Ptr CxxUserContext -> IO ()
  fillSlot slot cell context = do
    poke (castPtr cell) context
    poke (castPtr slot) cell
  {-# INLINE fillSlot #-}

-- | Buffers are passed directly. The @argv@ slot holds the buffer pointer
-- itself, and the scratch cell is left untouched.
instance {-# OVERLAPPING #-} ValidArgument (Ptr (HalideBuffer n a)) where
  fillSlot :: Ptr () -> Ptr () -> Ptr (HalideBuffer n a) -> IO ()
  fillSlot slot _unusedCell buffer = poke (castPtr slot) buffer
  {-# INLINE fillSlot #-}

-- | A type that may appear as a parameter of a builder function passed to
-- 'compile' ('Expr' for scalars, 'Func' for buffers in this module). The
-- superclass requires that the parameter's 'Lowered' form can be passed to the
-- compiled kernel.
class ValidArgument (Lowered t) => ValidParameter (t :: Type) where
  -- | Append the Halide argument description of this parameter to a
  -- @std::vector<Halide::Argument>@.
  appendToArgList :: Ptr (CxxVector CxxArgument) -> t -> IO ()
  -- | Create a fresh placeholder parameter. The instances in this module start
  -- from an @IORef@ holding 'Nothing'.
  prepareParameter :: IO t

-- | Scalar parameters become @Halide::Argument::InputScalar@ entries built
-- from the underlying @Halide::Internal::Parameter@.
instance IsHalideType a => ValidParameter (Expr a) where
  appendToArgList :: Ptr (CxxVector CxxArgument) -> Expr a -> IO ()
  appendToArgList argList expr = asScalarParam expr appendScalar
    where
      appendScalar param =
        [CU.exp| void { $(std::vector<Halide::Argument>* argList)->emplace_back(
          $(Halide::Internal::Parameter const* param)->name(),
          Halide::Argument::InputScalar,
          $(Halide::Internal::Parameter const* param)->type(),
          $(Halide::Internal::Parameter const* param)->dimensions(),
          $(Halide::Internal::Parameter const* param)->get_argument_estimates()) } |]

  prepareParameter :: IO (Expr a)
  prepareParameter = do
    ref <- newIORef Nothing
    pure (ScalarParam ref)

-- | Buffer parameters. Only 'Param' funcs can be turned into a
-- @Halide::Argument@, by copying their @Halide::ImageParam@. Any other
-- constructor is treated as a programming error.
instance (KnownNat n, IsHalideType a) => ValidParameter (Func t n a) where
  appendToArgList :: Ptr (CxxVector CxxArgument) -> Func t n a -> IO ()
  appendToArgList argList = \case
    param@(Param _) ->
      withBufferParam param $ \imageParam ->
        [CU.exp| void { $(std::vector<Halide::Argument>* argList)->push_back(
          *$(Halide::ImageParam const* imageParam)) } |]
    _ -> error "appendToArgList called on Func; this should never happen"

  prepareParameter :: IO (Func t n a)
  -- 'Param' always builds a @Func 'ParamTy n a@, but this method must return a
  -- @Func t n a@ for whatever @t@ the instance was selected at, so the value
  -- is coerced.
  -- NOTE(review): this is only sound when used at @t ~ 'ParamTy@; confirm
  -- that callers never instantiate it at another @t@.
  prepareParameter = unsafeCoerce . Param <$> newIORef Nothing

-- | Build an 'Arguments' list of fresh parameters, one per element of @ts@,
-- by calling 'prepareParameter' for each element from left to right.
class PrepareParameters ts where
  prepareParameters :: IO (Arguments ts)

-- | No parameters: nothing to allocate.
instance PrepareParameters '[] where
  prepareParameters :: IO (Arguments '[])
  prepareParameters = pure Nil

-- | Prepare the head parameter first, then the rest of the list.
instance (ValidParameter t, PrepareParameters ts) => PrepareParameters (t ': ts) where
  prepareParameters :: IO (Arguments (t ': ts))
  prepareParameters = (:::) <$> prepareParameter @t <*> prepareParameters @ts

-- | Run an action with a freshly allocated @std::vector<Halide::Argument>@
-- describing the given parameters, in order.
--
-- The vector reserves @Length ts@ entries up front. 'bracket' deletes it when
-- the action returns or throws.
prepareCxxArguments
  :: forall ts b
   . (All ValidParameter ts, KnownNat (Length ts))
  => Arguments ts
  -> (Ptr (CxxVector CxxArgument) -> IO b)
  -> IO b
prepareCxxArguments args action =
  bracket newArgList deleteArgList $ \argList -> do
    appendAll argList args
    action argList
  where
    capacity = fromIntegral (natVal (Proxy @(Length ts)))

    newArgList =
      [CU.block| std::vector<Halide::Argument>* {
        auto p = new std::vector<Halide::Argument>{};
        p->reserve($(size_t capacity));
        return p;
      } |]

    deleteArgList argList =
      [CU.exp| void { delete $(std::vector<Halide::Argument>* argList) } |]

    -- Append one Halide::Argument per parameter, preserving their order.
    appendAll :: All ValidParameter xs => Ptr (CxxVector CxxArgument) -> Arguments xs -> IO ()
    appendAll _ Nil = pure ()
    appendAll argList (x ::: rest) = do
      appendToArgList argList x
      appendAll argList rest

-- | C++ finalizer that @delete@s a heap-allocated @Halide::JITUserContext@.
deleteCxxUserContext :: FunPtr (Ptr CxxUserContext -> IO ())
deleteCxxUserContext =
  [C.funPtr| void deleteUserContext(Halide::JITUserContext* p) { delete p; } |]

-- | Take ownership of a @Halide::JITUserContext*@. 'deleteCxxUserContext'
-- deletes it once the returned 'ForeignPtr' is garbage collected.
wrapCxxUserContext :: Ptr CxxUserContext -> IO (ForeignPtr CxxUserContext)
wrapCxxUserContext = newForeignPtr deleteCxxUserContext

-- | Allocate a default-constructed @Halide::JITUserContext@ owned by the GC.
newEmptyCxxUserContext :: IO (ForeignPtr CxxUserContext)
newEmptyCxxUserContext = do
  rawContext <- [CU.exp| Halide::JITUserContext* { new Halide::JITUserContext{} } |]
  wrapCxxUserContext rawContext

-- | Take ownership of a heap-allocated @Halide::Callable@. It is @delete@d
-- when the resulting 'Callable' is garbage collected.
wrapCxxCallable :: Ptr CxxCallable -> IO (Callable inputs outputs)
wrapCxxCallable rawCallable = do
  owned <- newForeignPtr deleteCallable rawCallable
  pure (Callable owned)
  where
    deleteCallable :: FunPtr (Ptr CxxCallable -> IO ())
    deleteCallable = [C.funPtr| void deleteCallable(Halide::Callable* p) { delete p; } |]

type Lowered :: forall k. k -> k

-- | Specifies how t'Expr' and t'Func' parameters become scalar and buffer arguments in compiled kernels.
--
-- Works both on a single parameter type and, element-wise, on a type-level
-- list of parameter types (as used with @FunctionArguments@).
type family Lowered (t :: k) :: k where
  -- a scalar parameter is passed as its plain element value
  Lowered (Expr a) = a
  -- a buffer parameter is passed as a pointer to a HalideBuffer
  Lowered (Func t n a) = Ptr (HalideBuffer n a)
  -- lists of parameters are lowered one element at a time
  Lowered '[] = '[]
  Lowered (Expr a ': ts) = (a ': Lowered ts)
  Lowered (Func t n a ': ts) = (Ptr (HalideBuffer n a) ': Lowered ts)

-- | A constraint that specifies that the function @f@ returns @'IO' ('Func' t n a)@.
--
-- The functional dependency @f -> t n a@ means @t@, @n@ and @a@ are determined
-- by @f@. The single catch-all instance below makes the class hold exactly when
-- its superclass constraints hold.
class (FunctionReturn f ~ IO (Func t n a), IsHalideType a, KnownNat n) => ReturnsFunc f t n a | f -> t n a

instance (FunctionReturn f ~ IO (Func t n a), IsHalideType a, KnownNat n) => ReturnsFunc f t n a

-- | Constraints required of a builder function @f@ that returns
-- @'IO' ('Func' t n a)@:
--
-- * every parameter is a 'ValidParameter', and its 'Lowered' form is a 'ValidArgument';
-- * @f@ can be applied to an 'Arguments' list of its parameters;
-- * fresh parameters can be created for it;
-- * the number of (lowered) parameters is known at compile time.
type IsFuncBuilder f t n a =
  ( All ValidParameter (FunctionArguments f)
  , All ValidArgument (Lowered (FunctionArguments f))
  , UnCurry f (FunctionArguments f) (FunctionReturn f)
  , PrepareParameters (FunctionArguments f)
  , ReturnsFunc f t n a
  , KnownNat (Length (FunctionArguments f))
  , KnownNat (Length (Lowered (FunctionArguments f)))
  )

-- | Run a builder function on freshly prepared parameters.
--
-- Returns the parameters together with the 'Func' the builder produced. The
-- parameters are needed later to build the Halide argument list.
buildFunc :: (IsFuncBuilder f t n a) => f -> IO (Arguments (FunctionArguments f), Func t n a)
buildFunc builder = do
  params <- prepareParameters
  (,) params <$> uncurryG builder params

-- | A compiled Halide pipeline that owns a @Halide::Callable@. The finalizer
-- installed by 'wrapCxxCallable' frees it. The phantom parameters record the
-- lowered input types and the output type of the kernel.
newtype Callable (inputs :: [Type]) (output :: Type) = Callable (ForeignPtr CxxCallable)

-- | Compile a builder function into a Halide @Callable@ for the given target.
--
-- Steps:
--
-- 1. Run the builder on fresh parameters.
-- 2. Convert the parameters to a @std::vector<Halide::Argument>@.
-- 3. Call @Halide::Func::compile_to_callable@ inside
--    @handle_halide_exceptions@, run with @C.throwBlock@ so that C++ exceptions
--    propagate to Haskell.
compileToCallable
  :: forall n a t f inputs output
   . ( IsFuncBuilder f t n a
     , Lowered (FunctionArguments f) ~ inputs
     , Ptr (HalideBuffer n a) ~ output
     )
  => Target
  -> f
  -> IO (Callable inputs output)
compileToCallable target builder = do
  (params, func) <- buildFunc builder
  prepareCxxArguments params $ \cxxArgs ->
    withFunc func $ \cxxFunc ->
      withCxxTarget target $ \cxxTarget -> do
        compiled <-
          [C.throwBlock| Halide::Callable* {
            return handle_halide_exceptions([=]() {
              return new Halide::Callable{
                $(Halide::Func* cxxFunc)->compile_to_callable(
                  *$(const std::vector<Halide::Argument>* cxxArgs),
                  *$(const Halide::Target* cxxTarget))};
            });
          } |]
        wrapCxxCallable compiled
  where
    -- GHC considers the output equality redundant for the body. It is part of
    -- the public signature, so mark it as intentionally kept.
    _ = keepRedundantConstraint (Proxy @(Ptr (HalideBuffer n a) ~ output))

-- | Turn a compiled 'Callable' into a curried Haskell function.
--
-- The returned function takes the lowered inputs followed by the output
-- buffer. It invokes the callable via @call_argv_fast@. The JIT user context
-- created here is passed as the first argument of every call.
callableToFunction
  :: forall inputs output kernel
   . ( Curry inputs (output -> IO ()) kernel
     , KnownNat (Length inputs)
     , All ValidArgument inputs
     , ValidArgument output
     )
  => Callable inputs output
  -> IO kernel
callableToFunction (Callable callable) = do
  context <- newEmptyCxxUserContext
  -- NOTE(review): this storage is allocated once and reused by every call of
  -- the returned kernel, so concurrent calls would overwrite each other's
  -- argument slots. Serialize calls if the kernel may be shared across threads.
  storage@(ArgvStorage argvArray scalarArray) <- newArgvStorage (fromIntegral argc)
  let argvPtr = P.mutablePrimArrayContents argvArray
      contextPtr = unsafeForeignPtrToPtr context
      callablePtr = unsafeForeignPtrToPtr callable
      invoke args out = do
        setArgvStorage storage (contextPtr ::: args) (out ::: Nil)
        [CU.exp| void {
          handle_halide_exceptions([=]() {
            return $(Halide::Callable* callablePtr)->call_argv_fast(
              $(int argc), $(const void* const* argvPtr));
          })
        } |]
        -- The raw pointers above do not keep their owners alive; touch every
        -- owner after the C++ call has returned.
        touch argvArray
        touch scalarArray
        touch context
        touch callable
  pure (curryG @inputs @(output -> IO ()) invoke)
  where
    -- One slot per input, plus one for the user context and one for the output.
    argc = fromIntegral (natVal (Proxy @(Length inputs))) + 2

-- | Convert a function that builds a Halide 'Func' into a normal Haskell function accepting scalars and
-- 'HalideBuffer's.
--
-- For example:
--
-- @
-- builder :: Expr Float -> Func 'ParamTy 1 Float -> IO (Func 'FuncTy 1 Float)
-- builder scale inputVector = do
--   i <- 'mkVar' "i"
--   scaledVector <- 'define' "scaledVector" i $ scale * inputVector '!' i
--   pure scaledVector
-- @
--
-- The @builder@ function accepts a scalar parameter and a vector and scales the vector by the given factor.
-- We can now pass @builder@ to 'compile':
--
-- @
-- scaler <- 'compile' builder
-- 'withHalideBuffer' @1 @Float [1, 1, 1] $ \inputVector ->
--   'allocaCpuBuffer' [3] $ \outputVector -> do
--     -- invoke the kernel
--     scaler 2.0 inputVector outputVector
--     -- print the result
--     print =<< 'peekToList' outputVector
-- @
compile
  :: forall n a t f kernel
   . ( IsFuncBuilder f t n a
     , Curry (Lowered (FunctionArguments f)) (Ptr (HalideBuffer n a) -> IO ()) kernel
     )
  => f
  -- ^ Function to compile
  -> IO kernel
  -- ^ Compiled kernel
compile builder =
  -- Same as 'compileForTarget', specialised to 'hostTarget'.
  compileForTarget @n @a @t @f @kernel hostTarget builder

-- | Like 'compile', but compiles for an explicitly given 'Target' instead of
-- 'hostTarget'.
compileForTarget
  :: forall n a t f kernel
   . ( IsFuncBuilder f t n a
     , Curry (Lowered (FunctionArguments f)) (Ptr (HalideBuffer n a) -> IO ()) kernel
     )
  => Target
  -> f
  -> IO kernel
compileForTarget target builder = do
  -- Compile once; the resulting kernel can then be called many times.
  callable <- compileToCallable @n @a @t @f target builder
  callableToFunction callable

-- | Output format accepted by 'compileToLoweredStmt'.
data StmtOutputFormat
  = StmtText
  -- ^ plain text
  | StmtHTML
  -- ^ HTML
  deriving stock (Eq, Show)

-- | Maps each format to the integer value of the corresponding
-- @Halide::StmtOutputFormat@ enumerator and back, so values can be passed
-- straight to C++.
instance Enum StmtOutputFormat where
  fromEnum :: StmtOutputFormat -> Int
  fromEnum format = fromIntegral $ case format of
    StmtText -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |]
    StmtHTML -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |]

  toEnum :: Int -> StmtOutputFormat
  toEnum k
    | code == textCode = StmtText
    | code == htmlCode = StmtHTML
    | otherwise = error $ "invalid StmtOutputFormat " <> show k
    where
      -- compare in the C integer type, as the enumerator values are C ints
      code = fromIntegral k
      textCode = [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |]
      htmlCode = [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |]

-- | Get the internal representation of lowered code.
--
-- Useful for analyzing and debugging scheduling. Can emit HTML or plain text.
--
-- The builder is compiled with @Halide::Func::compile_to_lowered_stmt@, which
-- writes its output to a file in a temporary directory. That file is read back
-- as UTF-8, and the directory is removed afterwards.
compileToLoweredStmt
  :: forall n a t f. (IsFuncBuilder f t n a) => StmtOutputFormat -> Target -> f -> IO Text
compileToLoweredStmt format target builder =
  withSystemTempDirectory "halide-haskell" $ \dir -> do
    -- Compute the path once, so the file Halide writes and the file we read
    -- back cannot diverge.
    let path = dir <> "/code.stmt"
        pathBytes = encodeUtf8 (pack path)
        formatCode = fromIntegral (fromEnum format)
    (parameters, func) <- buildFunc builder
    prepareCxxArguments parameters $ \cxxArgs ->
      withFunc func $ \cxxFunc ->
        withCxxTarget target $ \cxxTarget ->
          [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              $(Halide::Func* cxxFunc)->compile_to_lowered_stmt(
                std::string{$bs-ptr:pathBytes, static_cast<size_t>($bs-len:pathBytes)},
                *$(const std::vector<Halide::Argument>* cxxArgs),
                static_cast<Halide::StmtOutputFormat>($(int formatCode)),
                *$(Halide::Target* cxxTarget));
            });
          } |]
    -- 'T.readFile' decodes with the process locale, which may be ASCII (e.g.
    -- LANG=C). In that case it fails on any non-ASCII byte in the output;
    -- user-chosen names presumably end up in the stmt text. Strings handed to
    -- C++ here are UTF-8, so read the result back as UTF-8 regardless of the
    -- locale.
    IO.withFile path IO.ReadMode $ \h -> do
      IO.hSetEncoding h IO.utf8
      T.hGetContents h