{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      : Language.Halide.Kernel
-- Description : Compiling functions to kernels
-- Copyright   : (c) Tom Westerhout, 2023
module Language.Halide.Kernel
  ( compile
  , compileForTarget
  , compileToCallable
  , compileToLoweredStmt
  , StmtOutputFormat (..)
  , LoweredSignature
  )
where

import Control.Exception (bracket)
import Control.Monad.Primitive (touch)
import Control.Monad.ST (RealWorld)
import Data.IORef
import Data.Kind (Type)
import Data.Primitive.PrimArray (MutablePrimArray)
import Data.Primitive.PrimArray qualified as P
import Data.Primitive.Ptr qualified as P
import Data.Proxy
import Data.Text (Text, pack)
import Data.Text.Encoding (encodeUtf8)
import Data.Text.IO qualified as T
import Foreign.C.Types (CUIntPtr (..))
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import Foreign.Marshal.Array (allocaArray, peekArray)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable
import GHC.TypeNats
import Language.C.Inline qualified as C
import Language.C.Inline.Cpp.Exception qualified as C
import Language.C.Inline.Unsafe qualified as CU
import Language.Halide.Buffer
import Language.Halide.Context
import Language.Halide.Expr
import Language.Halide.Func
import Language.Halide.RedundantConstraints
import Language.Halide.Target
import Language.Halide.Type
import System.IO.Temp (withSystemTempDirectory)

-- | Haskell counterpart of @Halide::Argument@.
--
-- The type has no constructors; in this module it only appears as a tag in
-- pointer types such as @'Ptr' ('CxxVector' 'CxxArgument')@.
data CxxArgument

-- Template Haskell splice defined outside this module.
-- NOTE(review): presumably sets up the inline-c context that the C++
-- quasi-quotes below rely on -- confirm against its definition.
importHalide

-- Template Haskell splice defined outside this module.
-- NOTE(review): presumably generates the 'HasCxxVector' instance for
-- 'CxxArgument' that 'newCxxVector' / 'deleteCxxVector' need -- confirm.
instanceHasCxxVector "Halide::Argument"

-- | Scratch memory used to assemble the @argv@ array for a kernel call.
--
-- Both arrays are indexed by argument position (see 'setArgvStorage').
data ArgvStorage s
  = ArgvStorage
      -- One pointer-sized slot per argument: the @argv@ array itself.
      {-# UNPACK #-} !(MutablePrimArray s (Ptr ()))
      -- One slot per argument holding the value of a scalar argument; the
      -- matching @argv@ entry then points into this array.
      {-# UNPACK #-} !(MutablePrimArray s CUIntPtr)

-- | Allocate an 'ArgvStorage' with room for @n@ arguments.
--
-- Both arrays are pinned, so the raw pointers later obtained with
-- 'P.mutablePrimArrayContents' stay valid while the arrays are alive.
newArgvStorage :: Int -> IO (ArgvStorage RealWorld)
newArgvStorage n = ArgvStorage <$> P.newPinnedPrimArray n <*> P.newPinnedPrimArray n

-- | Fill the @argv@ array with the user context followed by @inputs@.
--
-- Slot 0 always receives the user context; the i-th element of @inputs@
-- goes into slot @i + 1@. Each slot is written by the 'fillSlot' method of
-- the argument's 'ValidArgument' instance.
--
-- The storage must have at least @1 + length inputs@ slots; no bounds check
-- is performed here.
setArgvStorage
  :: All ValidArgument args
  => ArgvStorage RealWorld
  -> Ptr CxxUserContext
  -> Arguments args
  -> IO ()
setArgvStorage (ArgvStorage argv scalarStorage) context inputs = do
  let argvPtr = P.mutablePrimArrayContents argv
      scalarStoragePtr = P.mutablePrimArrayContents scalarStorage
      go :: All ValidArgument args' => Int -> Arguments args' -> IO ()
      go !_ Nil = pure ()
      go !i (x ::: xs) = do
        fillSlot
          (castPtr $ argvPtr `P.advancePtr` i)
          (castPtr $ scalarStoragePtr `P.advancePtr` i)
          x
        go (i + 1) xs
  go 0 (context ::: inputs)
  -- Keep the arrays alive until all writes through the raw pointers are done.
  touch argv
  touch scalarStorage

-- | Specifies that the type can be used as an argument to a kernel.
class ValidArgument (t :: Type) where
  -- | @fillSlot argvSlot scalarSlot x@ stores the argument @x@ for a call.
  --
  -- The first pointer is this argument's entry in the @argv@ array; the
  -- second is this argument's entry in the scalar scratch array (see
  -- 'setArgvStorage'). Instances decide whether the scratch entry is used.
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()

-- | Scalars are passed by address: the value is copied into the scratch
-- slot and the @argv@ entry is set to point at that slot.
--
-- NOTE(review): the scratch slot is one 'CUIntPtr' wide; this assumes
-- @sizeOf t@ never exceeds that -- confirm for all 'IsHalideType' instances.
instance IsHalideType t => ValidArgument t where
  fillSlot :: Ptr () -> Ptr () -> t -> IO ()
  fillSlot argv scalarStorage x = do
    poke (castPtr scalarStorage :: Ptr t) x
    poke (castPtr argv :: Ptr (Ptr ())) scalarStorage
  {-# INLINE fillSlot #-}

-- | The user context pointer is passed like a scalar: the pointer value is
-- copied into the scratch slot and the @argv@ entry points at that slot.
instance {-# OVERLAPPING #-} ValidArgument (Ptr CxxUserContext) where
  fillSlot :: Ptr () -> Ptr () -> Ptr CxxUserContext -> IO ()
  fillSlot argv scalarStorage x = do
    poke (castPtr scalarStorage :: Ptr (Ptr CxxUserContext)) x
    poke (castPtr argv :: Ptr (Ptr ())) scalarStorage
  {-# INLINE fillSlot #-}

-- | Buffers are passed directly: the buffer pointer itself is stored in the
-- @argv@ entry and the scratch slot is left untouched.
instance {-# OVERLAPPING #-} ValidArgument (Ptr (HalideBuffer n a)) where
  fillSlot :: Ptr () -> Ptr () -> Ptr (HalideBuffer n a) -> IO ()
  fillSlot argv _ x = do
    poke (castPtr argv :: Ptr (Ptr (HalideBuffer n a))) x
  {-# INLINE fillSlot #-}

-- class ValidArgument (Lowered t) => ValidParameter (t :: Type) where
--   appendToArgList :: Ptr (CxxVector CxxArgument) -> t -> IO ()
--   prepareParameter :: IO t

-- instance IsHalideType a => ValidParameter (Expr a) where
--   appendToArgList :: Ptr (CxxVector CxxArgument) -> Expr a -> IO ()
--   appendToArgList v expr =
--     asScalarParam expr $ \p ->
--       [CU.exp| void { $(std::vector<Halide::Argument>* v)->emplace_back(
--         $(Halide::Internal::Parameter const* p)->name(),
--         Halide::Argument::InputScalar,
--         $(Halide::Internal::Parameter const* p)->type(),
--         $(Halide::Internal::Parameter const* p)->dimensions(),
--         $(Halide::Internal::Parameter const* p)->get_argument_estimates()) } |]
--   prepareParameter :: IO (Expr a)
--   prepareParameter = ScalarParam <$> newIORef Nothing

-- instance (KnownNat n, IsHalideType a, t ~ 'ParamTy) => ValidParameter (Func t n (Expr a)) where
--   appendToArgList :: Ptr (CxxVector CxxArgument) -> Func 'ParamTy n (Expr a) -> IO ()
--   appendToArgList v func@(Param _) =
--     withBufferParam func $ \p ->
--       [CU.exp| void { $(std::vector<Halide::Argument>* v)->push_back(
--         *$(Halide::ImageParam const* p)) } |]
--   prepareParameter = Param <$> newIORef Nothing

-- | Functions that, once supplied with fresh parameters, produce a
-- @'Func' ''FuncTy' n a@.
--
-- The functional dependency means the builder type @f@ determines both the
-- dimension @n@ and the element type @a@ of the resulting 'Func'.
class KnownNat n => FuncBuilder f (n :: Nat) (a :: Type) | f -> n a where
  -- | Run the builder @f@, recording a @Halide::Argument@ in the given
  -- vector for every parameter that the instances create along the way.
  buildFunc :: Ptr (CxxVector CxxArgument) -> f -> IO (Func 'FuncTy n a)

-- | Builder taking a buffer parameter: a fresh 'Param' is created, passed to
-- the builder, and then registered as a @Halide::Argument@.
--
-- The argument is appended /after/ the recursive call, so the vector ends up
-- holding the arguments in reverse order; consumers of the vector reverse it
-- before handing it to Halide.
instance (k ~ 'ParamTy, KnownNat m, IsHalideType t, FuncBuilder r n a) => FuncBuilder (Func k m (Expr t) -> r) n a where
  buildFunc v f = do
    param <- Param <$> newIORef Nothing
    func <- buildFunc v (f param)
    withBufferParam param $ \p ->
      [CU.exp| void { $(std::vector<Halide::Argument>* v)->push_back(*$(Halide::ImageParam const* p)) } |]
    pure func

-- | Builder taking a scalar parameter: a fresh 'ScalarParam' is created,
-- passed to the builder, and then registered as an @InputScalar@
-- @Halide::Argument@.
--
-- The argument is appended /after/ the recursive call, so the vector ends up
-- holding the arguments in reverse order; consumers of the vector reverse it
-- before handing it to Halide.
instance (IsHalideType t, FuncBuilder r n a) => FuncBuilder (Expr t -> r) n a where
  buildFunc v f = do
    param <- ScalarParam <$> newIORef Nothing
    func <- buildFunc v (f param)
    asScalarParam param $ \p ->
      [CU.block| void {
        auto const& p = *$(Halide::Internal::Parameter const* p);
        $(std::vector<Halide::Argument>* v)->emplace_back(
          p.name(),
          Halide::Argument::InputScalar,
          p.type(),
          p.dimensions(),
          p.get_argument_estimates());
      } |]
    pure func

-- | Base case: the builder has received all of its parameters and is just an
-- 'IO' action producing the 'Func'. No argument is recorded.
instance (KnownNat n, t ~ 'FuncTy, n' ~ n, a' ~ a) => FuncBuilder (IO (Func t n' a')) n a where
  buildFunc _ action = action

-- | Type of the compiled kernel that corresponds to a builder function @f@.
--
-- * An @'Expr' a@ parameter becomes a plain scalar @a@.
-- * A @'Func' t n ('Expr' a)@ parameter becomes a @'Ptr' ('HalideBuffer' n a)@.
-- * The resulting 'Func' becomes one trailing output buffer argument (two
--   for a 'Func' of a pair of expressions), and the kernel returns @'IO' ()@.
type family LoweredSignature f where
  LoweredSignature (Expr a -> r) = a -> LoweredSignature r
  LoweredSignature (Func t n (Expr a) -> r) = Ptr (HalideBuffer n a) -> LoweredSignature r
  LoweredSignature (IO (Func t n (Expr a))) = Ptr (HalideBuffer n a) -> IO ()
  LoweredSignature (IO (Func t n (Expr a1, Expr a2))) = Ptr (HalideBuffer n a1) -> Ptr (HalideBuffer n a2) -> IO ()

-- | Constraints a kernel type @f@ must satisfy to be called through
-- 'callableToFunction': its number of arguments is known at the type level,
-- every argument is a 'ValidArgument', and @f@ is the curried form of a
-- function from @'Arguments' ('FunctionArguments' f)@ to @'IO' ()@.
type IsHalideKernel f = (KnownNat (Length (FunctionArguments f)), All ValidArgument (FunctionArguments f), Curry (FunctionArguments f) (IO ()) f)

-- | Owning handle to a @Halide::Callable@.
--
-- The @signature@ parameter is phantom: it only records, at the type level,
-- the Haskell function type that the callable implements.
newtype Callable (signature :: Type) = Callable (ForeignPtr CxxCallable)

-- | Run the builder, then compile the resulting 'Func' for @target@ into a
-- 'Callable'.
--
-- The temporary @std::vector\<Halide::Argument\>@ is released by 'bracket'
-- even when building or compilation throws. 'buildFunc' records arguments in
-- reverse order, so the vector is reversed before being passed to Halide.
compileToCallable
  :: forall n a f
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => Target
  -> f
  -> IO (Callable (LoweredSignature f))
compileToCallable target builder =
  bracket (newCxxVector Nothing) deleteCxxVector $ \v -> do
    func <- buildFunc @f @n @a v builder
    withCxxFunc func $ \func' ->
      withCxxTarget target $ \target' ->
        wrapCxxCallable
          =<< [C.throwBlock| Halide::Callable* {
                return handle_halide_exceptions([=]() {
                  auto& func = *$(Halide::Func* func');
                  auto& args = *$(std::vector<Halide::Argument>* v);
                  auto const& target = *$(const Halide::Target* target');
                  std::reverse(std::begin(args), std::end(args));
                  return new Halide::Callable{func.compile_to_callable(args, target)};
                });
              } |]
  where
    -- The constraint is not needed by the body; this keeps GHC from
    -- reporting it as redundant.
    _ = keepRedundantConstraint @(IsHalideKernel (LoweredSignature f))

-- | Turn a 'Callable' into an ordinary curried Haskell function.
--
-- A user context and the argument scratch storage are allocated once, here,
-- and reused by every invocation of the returned function.
--
-- NOTE(review): because the scratch storage is shared between invocations,
-- calling the returned function concurrently from several threads looks
-- unsafe -- confirm before relying on it.
callableToFunction :: forall f. IsHalideKernel f => Callable f -> IO f
callableToFunction (Callable callable) = do
  context <- newEmptyCxxUserContext
  -- +1 comes from CxxUserContext
  let argc = 1 + fromIntegral (natVal (Proxy @(Length (FunctionArguments f))))
  storage@(ArgvStorage argv scalarStorage) <- newArgvStorage (fromIntegral argc)
  let argvPtr = P.mutablePrimArrayContents argv
      contextPtr = unsafeForeignPtrToPtr context
      callablePtr = unsafeForeignPtrToPtr callable
      kernel args = do
        setArgvStorage storage contextPtr args
        [CU.exp| void {
          handle_halide_exceptions([=]() {
            return $(Halide::Callable* callablePtr)->call_argv_fast(
              $(int argc), $(const void* const* argvPtr));
          })
        } |]
        -- The raw pointers above do not keep their owners alive; touch them
        -- so nothing is finalized or collected before the call returns.
        touch argv
        touch scalarStorage
        touch context
        touch callable
  pure $ curryG @(FunctionArguments f) @(IO ()) kernel

-- class PrepareParameters ts where
--   prepareParameters :: IO (Arguments ts)
--
-- instance PrepareParameters '[] where
--   prepareParameters :: IO (Arguments '[])
--   prepareParameters = pure Nil
--
-- instance (ValidParameter t, PrepareParameters ts) => PrepareParameters (t ': ts) where
--   prepareParameters :: IO (Arguments (t : ts))
--   prepareParameters = do
--     t <- prepareParameter @t
--     ts <- prepareParameters @ts
--     pure $ t ::: ts

-- prepareCxxArguments
--   :: forall ts b
--    . (ValidParameters' ts, All ValidParameter ts, KnownNat (Length ts))
--   => Arguments ts
--   -> (Ptr (CxxVector CxxArgument) -> IO b)
--   -> IO b
-- prepareCxxArguments args action = do
--   let count = fromIntegral (natVal (Proxy @(Length ts)))
--       allocate =
--         [CU.block| std::vector<Halide::Argument>* {
--           auto p = new std::vector<Halide::Argument>{};
--           p->reserve($(size_t count));
--           return p;
--         } |]
--       destroy p = [CU.exp| void { delete $(std::vector<Halide::Argument>* p) } |]
--   bracket allocate destroy $ \v -> do
--     let go :: (All ValidParameter ts') => Arguments ts' -> IO ()
--         go Nil = pure ()
--         go (x ::: xs) = appendToArgList v x >> go xs
--     go args
--     action v

-- | Take ownership of a heap-allocated @Halide::JITUserContext@.
--
-- The returned 'ForeignPtr' carries a finalizer that @delete@s the object.
wrapCxxUserContext :: Ptr CxxUserContext -> IO (ForeignPtr CxxUserContext)
wrapCxxUserContext = newForeignPtr deleter
  where
    deleter = [C.funPtr| void deleteUserContext(Halide::JITUserContext* p) { delete p; } |]

-- | Allocate a default-constructed @Halide::JITUserContext@ managed by a
-- 'ForeignPtr'.
newEmptyCxxUserContext :: IO (ForeignPtr CxxUserContext)
newEmptyCxxUserContext =
  wrapCxxUserContext =<< [CU.exp| Halide::JITUserContext* { new Halide::JITUserContext{} } |]

-- wrapCxxCallable :: Ptr CxxCallable -> IO (Callable inputs outputs)
-- wrapCxxCallable = fmap Callable . newForeignPtr deleter
--   where
--     deleter = [C.funPtr| void deleteCallable(Halide::Callable* p) { delete p; } |]

-- | Take ownership of a heap-allocated @Halide::Callable@.
--
-- The resulting 'Callable' carries a finalizer that @delete@s the object.
-- The @signature@ tag is chosen by the caller and is not checked here.
wrapCxxCallable :: Ptr CxxCallable -> IO (Callable signature)
wrapCxxCallable = fmap Callable . newForeignPtr deleter
  where
    deleter = [C.funPtr| void deleteCallable(Halide::Callable* p) { delete p; } |]

-- class All ValidArgument (LoweredOutputs t) => IsOutput t

-- type Lowered :: forall k. k -> k

-- | Specifies how t'Expr' and t'Func' parameters become scalar and buffer arguments in compiled kernels.
-- type family Lowered (t :: k) :: k where
--   Lowered (Expr a) = a
--   Lowered (Func t n (Expr a)) = Ptr (HalideBuffer n a)
--   Lowered '[] = '[]
--   Lowered (t ': ts) = (Lowered t ': Lowered ts)

-- | A constraint that specifies that the function @f@ returns @'IO' ('Func' t n a)@.
-- class (FunctionReturn f ~ IO (Func 'FuncTy n a), KnownNat n) => ReturnsFunc f n a | f -> n a

-- instance (FunctionReturn f ~ IO (Func 'FuncTy n a), KnownNat n) => ReturnsFunc f n a

-- type family ValidParameters' (p :: [Type]) :: Constraint where
--   ValidParameters' (Expr a ': rest) = (IsHalideType a, ValidParameter (Expr a), ValidParameters' rest)
--   ValidParameters' (Func t n (Expr a) ': rest) = (t ~ 'ParamTy, IsHalideType a, ValidParameter (Func 'ParamTy n (Expr a)), ValidParameters' rest)
--   ValidParameters' (a ': rest) = (Bottom, ValidParameters' rest)
--   ValidParameters' '[] = ()

-- type IsFuncBuilder f n a =
--   ( ValidParameters' (FunctionArguments f)
--   , All ValidParameter (FunctionArguments f)
--   , All ValidArgument (Concat (Lowered (FunctionArguments f)) (LoweredOutputs (Func 'FuncTy n a)))
--   , UnCurry f (FunctionArguments f) (FunctionReturn f)
--   , PrepareParameters (FunctionArguments f)
--   , ReturnsFunc f n a
--   , KnownNat (Length (FunctionArguments f))
--   , KnownNat (Length (Lowered (FunctionArguments f)))
--   , KnownNat (Length (LoweredOutputs (Func 'FuncTy n a)))
--   )

-- buildFunc :: (IsFuncBuilder f n a) => f -> IO (Arguments (FunctionArguments f), Func 'FuncTy n a)
-- buildFunc builder = do
--   parameters <- prepareParameters
--   func <- uncurryG builder parameters
--   pure (parameters, func)

-- newtype Callable (inputs :: [Type]) (outputs :: [Type]) = Callable (ForeignPtr CxxCallable)

-- compileToCallable
--   :: forall n a f inputs outputs
--    . ( IsFuncBuilder f n a
--      , Lowered (FunctionArguments f) ~ inputs
--      , LoweredOutputs (Func 'FuncTy n a) ~ outputs
--      )
--   => Target
--   -> f
--   -> IO (Callable inputs outputs)
-- compileToCallable target builder = do
--   (args, func) <- buildFunc builder
--   prepareCxxArguments args $ \args' ->
--     case func of
--       Func fp ->
--         withForeignPtr fp $ \func' ->
--           withCxxTarget target $ \target' ->
--             wrapCxxCallable
--               =<< [C.throwBlock| Halide::Callable* {
--                     return handle_halide_exceptions([=]() {
--                       return new Halide::Callable{
--                         $(Halide::Func* func')->compile_to_callable(
--                           *$(const std::vector<Halide::Argument>* args'),
--                           *$(const Halide::Target* target'))};
--                     });
--                   } |]
--   where
--     _ = keepRedundantConstraint (Proxy @(LoweredOutputs (Func 'FuncTy n a) ~ outputs))

-- callableToFunction
--   :: forall inputs outputs kernel
--    . ( Curry (Concat inputs outputs) (IO ()) kernel
--      , KnownNat (Length inputs)
--      , KnownNat (Length outputs)
--      , All ValidArgument (Concat inputs outputs)
--      )
--   => Callable inputs outputs
--   -> IO kernel
-- callableToFunction (Callable callable) = do
--   context <- newEmptyCxxUserContext
--   -- +1 comes from CxxUserContext
--   let argc =
--         1
--           + fromIntegral (natVal (Proxy @(Length inputs)))
--           + fromIntegral (natVal (Proxy @(Length outputs)))
--   storage@(ArgvStorage argv scalarStorage) <- newArgvStorage (fromIntegral argc)
--   let argvPtr = P.mutablePrimArrayContents argv
--       contextPtr = unsafeForeignPtrToPtr context
--       callablePtr = unsafeForeignPtrToPtr callable
--       kernel args = do
--         setArgvStorage storage (contextPtr ::: args)
--         [CU.exp| void {
--           handle_halide_exceptions([=]() {
--             return $(Halide::Callable* callablePtr)->call_argv_fast(
--               $(int argc), $(const void* const* argvPtr));
--           })
--         } |]
--         touch argv
--         touch scalarStorage
--         touch context
--         touch callable
--   pure $ curryG @(Concat inputs outputs) @(IO ()) kernel

-- | Convert a function that builds a Halide 'Func' into a normal Haskell function accepting scalars and
-- 'HalideBuffer's.
--
-- For example:
--
-- @
-- builder :: Expr Float -> Func 'ParamTy 1 (Expr Float) -> IO (Func 'FuncTy 1 (Expr Float))
-- builder scale inputVector = do
--   i <- 'mkVar' "i"
--   scaledVector <- 'define' "scaledVector" i $ scale * inputVector '!' i
--   pure scaledVector
-- @
--
-- The @builder@ function accepts a scalar parameter and a vector and scales the vector by the given factor.
-- We can now pass @builder@ to 'compile':
--
-- @
-- scaler <- 'compile' builder
-- 'withHalideBuffer' @1 @Float [1, 1, 1] $ \inputVector ->
--   'allocaCpuBuffer' [3] $ \outputVector -> do
--     -- invoke the kernel
--     scaler 2.0 inputVector outputVector
--     -- print the result
--     print =<< 'peekToList' outputVector
-- @
compile
  :: forall f n a
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => f
  -- ^ Function to compile
  -> IO (LoweredSignature f)
  -- ^ Compiled kernel
compile = compileForTarget hostTarget

-- | Similar to 'compile', but the first argument lets you explicitly specify the compilation target.
--
-- This is 'compileToCallable' followed by 'callableToFunction'.
compileForTarget
  :: forall f n a
   . (FuncBuilder f n a, IsHalideKernel (LoweredSignature f))
  => Target
  -> f
  -> IO (LoweredSignature f)
compileForTarget target builder = compileToCallable target builder >>= callableToFunction

-- | Format in which to return the lowered code.
data StmtOutputFormat
  = -- | plain text
    StmtText
  | -- | HTML
    StmtHTML
  deriving stock (Show, Eq)

-- | Conversion to and from the integer values of the C++
-- @Halide::StmtOutputFormat@ enumeration.
--
-- 'toEnum' calls 'error' for an integer that matches neither enumerator.
instance Enum StmtOutputFormat where
  fromEnum =
    fromIntegral . \case
      StmtText -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |]
      StmtHTML -> [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |]
  toEnum k
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::Text) } |] = StmtText
    | fromIntegral k == [CU.pure| int { static_cast<int>(Halide::StmtOutputFormat::HTML) } |] = StmtHTML
    | otherwise = error $ "invalid StmtOutputFormat " <> show k

-- | Get the internal representation of lowered code.
--
-- Useful for analyzing and debugging scheduling. Can emit HTML or plain text.
--
-- Halide writes its output to a file, so the code is emitted to
-- @code.stmt@ inside a fresh temporary directory and read back before the
-- directory is removed.
compileToLoweredStmt
  :: forall n a f. (FuncBuilder f n a) => StmtOutputFormat -> Target -> f -> IO Text
compileToLoweredStmt format target builder = do
  withSystemTempDirectory "halide-haskell" $ \dir -> do
    let path = dir <> "/code.stmt"
        -- @s@ and @o@ are referenced by name from the C++ block below.
        s = encodeUtf8 (pack path)
        o = fromIntegral (fromEnum format)
    bracket (newCxxVector Nothing) deleteCxxVector $ \v -> do
      func <- buildFunc @f @n @a v builder
      withCxxFunc func $ \func' ->
        withCxxTarget target $ \target' ->
          [C.throwBlock| void {
            handle_halide_exceptions([=]() {
              auto& func = *$(Halide::Func* func');
              auto& args = *$(std::vector<Halide::Argument>* v);
              auto const& target = *$(const Halide::Target* target');
              std::reverse(std::begin(args), std::end(args));

              func.compile_to_lowered_stmt(
                std::string{$bs-ptr:s, static_cast<size_t>($bs-len:s)},
                args,
                static_cast<Halide::StmtOutputFormat>($(int o)),
                target);
            });
          } |]
    T.readFile path